You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
7.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import re
from typing import List, Dict, Any
import requests
import sqlite3
from vanna.base import VannaBase
from Config import *
class VannaUtil(VannaBase):
    """Vanna adapter backed by a SQL training-data table and an HTTP LLM endpoint.

    Training data (DDL, documentation, question/SQL pairs) is stored in a
    single ``training_data`` table in either SQLite or PostgreSQL. SQL is
    generated by posting a prompt built from the stored DDL to the model
    endpoint configured in ``Config``.
    """

    # Default number of seconds to wait for the model API. Without a timeout
    # ``requests`` waits indefinitely on a stalled connection.
    _REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, db_type='sqlite', db_uri=None):
        """Set up model credentials and open the training-data store.

        Args:
            db_type: ``'sqlite'`` or ``'postgres'``.
            db_uri: SQLite file path or PostgreSQL connection string;
                falls back to ``'Db/vanna.db'`` when falsy.

        Raises:
            ValueError: If ``db_type`` is not a supported backend.
        """
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.base_url = MODEL_GENERATION_TEXT_URL  # Alibaba Cloud specific API endpoint
        self.model = QWEN_MODEL_NAME  # adjust to match the actual model name
        self.training_data = []
        self.chat_history = []
        self.db_type = db_type
        self.db_uri = db_uri or 'Db/vanna.db'  # defaults to SQLite
        self._init_db()

    def _init_db(self):
        """Open the database connection and make sure the table exists.

        Raises:
            ValueError: If ``self.db_type`` is neither ``'sqlite'`` nor
                ``'postgres'``.
        """
        if self.db_type == 'sqlite':
            self.conn = sqlite3.connect(self.db_uri)
        elif self.db_type == 'postgres':
            # Imported lazily so SQLite-only deployments do not need psycopg2.
            import psycopg2
            self.conn = psycopg2.connect(self.db_uri)
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")
        self._create_tables()

    def _create_tables(self):
        """Create the ``training_data`` table if it does not exist yet."""
        # Only the auto-increment primary key differs between the backends.
        id_columns = {
            'sqlite': 'id INTEGER PRIMARY KEY AUTOINCREMENT',
            'postgres': 'id SERIAL PRIMARY KEY',
        }
        id_column = id_columns.get(self.db_type)
        cursor = self.conn.cursor()
        if id_column is not None:
            cursor.execute(f'''
                CREATE TABLE IF NOT EXISTS training_data (
                    {id_column},
                    type TEXT,
                    question TEXT,
                    sql TEXT,
                    content TEXT
                )
            ''')
        self.conn.commit()

    def _execute_statement(self, statement: str, params: tuple = ()):
        """Run one of this class's statements and return the cursor.

        Statements are written with sqlite's ``?`` placeholders. psycopg2
        only understands ``%s``, so the placeholders are rewritten for the
        postgres backend. The rewrite is a plain text replacement and is
        only meant for the fixed statements used inside this class.

        Args:
            statement: SQL text using ``?`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            The cursor on which the statement was executed.
        """
        if self.db_type == 'postgres':
            statement = statement.replace('?', '%s')
        cursor = self.conn.cursor()
        cursor.execute(statement, params)
        return cursor

    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as training data."""
        self._execute_statement(
            'INSERT INTO training_data (type, content) VALUES (?, ?)',
            ('ddl', ddl),
        )
        self.conn.commit()

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a documentation snippet as training data."""
        self._execute_statement(
            'INSERT INTO training_data (type, content) VALUES (?, ?)',
            ('documentation', doc),
        )
        self.conn.commit()

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL pair as training data."""
        self._execute_statement(
            'INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)',
            ('qa', question, sql),
        )
        self.conn.commit()

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement joined by newlines.

        ``question`` is accepted for interface compatibility but is not
        used: all DDL rows are returned regardless of the question.
        """
        cursor = self._execute_statement(
            'SELECT content FROM training_data WHERE type = ?', ('ddl',)
        )
        return "\n".join(row[0] for row in cursor.fetchall())

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet joined by newlines.

        ``question`` is accepted for interface compatibility but is not
        used: all documentation rows are returned.
        """
        cursor = self._execute_statement(
            'SELECT content FROM training_data WHERE type = ?',
            ('documentation',),
        )
        return "\n".join(row[0] for row in cursor.fetchall())

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return all training rows as dicts keyed by column name."""
        cursor = self._execute_statement('SELECT * FROM training_data')
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training row with the given id.

        Returns:
            True if a row was deleted, False if no row matched.
        """
        cursor = self._execute_statement(
            'DELETE FROM training_data WHERE id = ?', (id,)
        )
        self.conn.commit()
        return cursor.rowcount > 0

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Stub: embeddings are not implemented, so an empty list is returned."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Stub: similarity search is not implemented, so an empty list is returned."""
        return []

    def system_message(self, message: str) -> None:
        """Reset the chat history to a single system message."""
        self.chat_history = [{"role": "system", "content": message}]

    def user_message(self, message: str) -> None:
        """Append a user message to the chat history."""
        self.chat_history.append({"role": "user", "content": message})

    def assistant_message(self, message: str) -> None:
        """Append an assistant message to the chat history."""
        self.chat_history.append({"role": "assistant", "content": message})

    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt by delegating to :meth:`generate_sql`.

        NOTE(review): the prompt is passed as the *question*, so it gets
        wrapped in the SQL-generation template instead of being sent to the
        model verbatim - confirm this is the intended behaviour.
        """
        return self.generate_sql(question=prompt)

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Extract a single normalized SELECT statement from model output.

        The first ``SELECT`` statement is taken up to its terminating
        semicolon, or up to the end of the text when the model omitted the
        semicolon. Markdown code fences are dropped, whitespace is collapsed
        to single spaces, accidentally repeated ``SELECT`` keywords are
        merged, and the result always ends with ``;``.

        Args:
            raw_sql: Raw text returned by the model.

        Returns:
            The cleaned statement, ``""`` for empty input, or ``raw_sql``
            unchanged when it contains no SELECT statement.
        """
        if not raw_sql:
            return ""
        # Models sometimes ignore the "no code fences" rule; fences are never
        # part of valid SQL, so drop them before searching.
        text = raw_sql.replace('```', ' ')
        # \b keeps identifiers such as "preselect" from being mistaken for
        # the keyword; \Z lets a statement without ';' still be extracted.
        match = re.search(
            r'\bSELECT\s.+?(?:;|\Z)', text, re.IGNORECASE | re.DOTALL
        )
        if match is None:
            return raw_sql
        # Normalize spaces and line breaks.
        statement = re.sub(r'\s+', ' ', match.group(0)).strip()
        # Collapse accidentally repeated SELECT keywords.
        statement = re.sub(
            r'\b(?:SELECT\s+)+', 'SELECT ', statement, flags=re.IGNORECASE
        )
        if not statement.endswith(';'):
            statement = statement.rstrip() + ';'
        return statement

    def _build_sql_prompt(self, question: str) -> str:
        """Build the model prompt from the stored DDL and the question.

        The template text is runtime data sent to the model and is kept
        verbatim (in Chinese). It asks for a single PostgreSQL SELECT
        statement ending in a semicolon, with no code fences or commentary.
        """
        return f"""严格按以下要求生成Postgresql查询
表结构:
{self.get_related_ddl(question)}
问题:{question}
生成规则:
1. 只输出一个标准的SELECT语句
2. 绝对不要使用任何代码块标记
3. 语句以分号结尾
4. 不要包含任何解释或注释
5. 若需要多表查询使用显式JOIN语法
6. 确保没有重复的SELECT关键字
"""

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL for ``question`` with a synchronous API call.

        Args:
            question: Natural-language question to translate into SQL.
            **kwargs: Optional ``timeout`` (seconds) for the HTTP request;
                other keys are ignored.

        Returns:
            The cleaned SQL statement, or ``""`` if anything goes wrong
            (the error is printed, not raised).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        timeout = kwargs.get('timeout', self._REQUEST_TIMEOUT_SECONDS)
        # Deliberately best-effort: any failure (building the prompt, the
        # network call, an unexpected response shape) yields an empty string.
        try:
            data = {
                "model": self.model,
                "input": {
                    "messages": [{
                        "role": "user",
                        "content": self._build_sql_prompt(question),
                    }]
                },
                "parameters": {
                    "temperature": 0.1,
                    "max_tokens": 5000,
                    "result_format": "text",
                },
            }
            response = requests.post(
                self.base_url, headers=headers, json=data, timeout=timeout
            )
            response.raise_for_status()
            # The code expects the generated text under output.text.
            raw_sql = response.json()['output']['text']
            return self._clean_sql_output(raw_sql)
        except Exception as e:
            print(f"\nAPI请求错误: {str(e)}")
            return ""