|
|
import re
|
|
|
from typing import List, Dict, Any
|
|
|
import requests
|
|
|
import sqlite3
|
|
|
from vanna.base import VannaBase
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
class VannaUtil(VannaBase):
    """Vanna backend storing training data in SQLite/PostgreSQL, generating SQL via DashScope.

    DDL statements, documentation snippets and question/SQL pairs are kept in a
    single ``training_data`` table. Retrieval is not similarity based: every
    stored DDL / documentation row is returned for any question, and
    ``generate_embedding`` / ``get_similar_question_sql`` are stubs.
    SQL generation posts a prompt to ``self.base_url`` (``MODEL_GENERATION_TEXT_URL``
    from ``Config``) and post-processes the model's text into one SELECT statement.
    """

    # Seconds to wait for the text-generation API. Without a timeout,
    # requests.post can block indefinitely on a stalled connection.
    request_timeout = 120

    def __init__(self, db_type='sqlite', db_uri=None):
        """Configure the API client and open the training-data store.

        Args:
            db_type: ``'sqlite'`` or ``'postgres'``.
            db_uri: SQLite file path or psycopg2 connection string; falls back
                to ``'Db/vanna.db'`` when not given.

        Raises:
            ValueError: if ``db_type`` is not supported.
        """
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.base_url = MODEL_GENERATION_TEXT_URL  # Alibaba Cloud-specific API endpoint
        self.model = QWEN_MODEL_NAME  # adjust to the actual model name
        self.training_data = []  # NOTE(review): not read anywhere in this class
        self.chat_history = []  # maintained by system/user/assistant_message
        self.db_type = db_type
        self.db_uri = db_uri or 'Db/vanna.db'  # default: local SQLite file
        self._init_db()

    def _init_db(self):
        """Open the database connection and make sure the schema exists."""
        if self.db_type == 'sqlite':
            import os
            # sqlite3 does not create missing parent directories ("unable to
            # open database file"), so create them for the database file.
            parent_dir = os.path.dirname(self.db_uri)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            self.conn = sqlite3.connect(self.db_uri)
        elif self.db_type == 'postgres':
            import psycopg2  # imported lazily: only needed for PostgreSQL
            self.conn = psycopg2.connect(self.db_uri)
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")
        self._create_tables()

    def close(self) -> None:
        """Close the training-data database connection."""
        self.conn.close()

    def _execute(self, query, params=None, commit=False):
        """Execute ``query`` on a new cursor and return that cursor.

        Queries are written with sqlite-style ``?`` placeholders. For
        PostgreSQL they are rewritten to psycopg2's ``%s`` style, so queries
        passed here must not contain a literal ``?``. If execution or commit
        fails, the transaction is rolled back (PostgreSQL otherwise rejects
        further statements in the aborted transaction) and the error re-raised.
        """
        if self.db_type == 'postgres':
            query = query.replace('?', '%s')
        cursor = self.conn.cursor()
        try:
            if params is None:
                cursor.execute(query)
            else:
                cursor.execute(query, params)
            if commit:
                self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise
        return cursor

    def _create_tables(self):
        """Create the ``training_data`` table if it does not exist yet."""
        # Only the auto-increment primary key syntax differs between backends;
        # _init_db has already rejected any other db_type.
        id_column = {
            'sqlite': 'id INTEGER PRIMARY KEY AUTOINCREMENT',
            'postgres': 'id SERIAL PRIMARY KEY',
        }[self.db_type]
        self._execute(f'''
            CREATE TABLE IF NOT EXISTS training_data (
                {id_column},
                type TEXT,
                question TEXT,
                sql TEXT,
                content TEXT
            )
        ''', commit=True)

    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement (row type ``'ddl'``)."""
        self._execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl), commit=True)

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a documentation snippet (row type ``'documentation'``)."""
        self._execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc), commit=True)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL pair (row type ``'qa'``)."""
        self._execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql), commit=True)

    def _joined_content(self, row_type: str) -> str:
        """Return the ``content`` of all rows of ``row_type``, in insertion order, newline-joined."""
        cursor = self._execute('SELECT content FROM training_data WHERE type = ? ORDER BY id', (row_type,))
        # NULL content would make str.join raise TypeError; skip such rows.
        return "\n".join(row[0] for row in cursor.fetchall() if row[0] is not None)

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement; ``question`` is not used for filtering."""
        return self._joined_content('ddl')

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet; ``question`` is not used for filtering."""
        return self._joined_content('documentation')

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return all training rows as dicts keyed by column name."""
        cursor = self._execute('SELECT * FROM training_data')
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training row with the given id; return True if a row was removed."""
        cursor = self._execute('DELETE FROM training_data WHERE id = ?', (id,), commit=True)
        return cursor.rowcount > 0

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Stub: embeddings are not supported; always returns an empty list."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Stub: similarity search is not supported; always returns an empty list."""
        return []

    def system_message(self, message: str) -> Dict[str, str]:
        """Reset the chat history to a single system message and return that message.

        The dict is returned because VannaBase's prompt builders (e.g.
        ``get_sql_prompt``) assemble message lists from these return values.
        """
        entry = {"role": "system", "content": message}
        self.chat_history = [entry]
        return entry

    def user_message(self, message: str) -> Dict[str, str]:
        """Append a user message to the chat history and return it."""
        entry = {"role": "user", "content": message}
        self.chat_history.append(entry)
        return entry

    def assistant_message(self, message: str) -> Dict[str, str]:
        """Append an assistant message to the chat history and return it."""
        entry = {"role": "assistant", "content": message}
        self.chat_history.append(entry)
        return entry

    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt by treating it as a question for ``generate_sql``.

        NOTE(review): VannaBase helpers may pass a list of message dicts here;
        it is interpolated into the SQL prompt as-is and the result is always
        post-processed as SQL. Confirm this is the intended behaviour.
        """
        return self.generate_sql(question=prompt)

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Extract a single SELECT statement from the model's raw text.

        Steps: drop markdown code-fence markers, drop everything before the
        first ``SELECT``, keep the first statement up to ``;`` (or up to the
        end of the text when the model omitted the semicolon), collapse all
        whitespace runs to single spaces and merge repeated ``SELECT`` keywords.
        Returns ``raw_sql`` unchanged when no SELECT statement is found, and
        ``""`` for empty/None input.

        Known limitation: whitespace inside string literals is collapsed too.
        """
        if not raw_sql:
            return ""
        # The prompt forbids code fences, but models still emit them; without
        # this, a fence would survive whenever the semicolon is missing.
        text = re.sub(r'```(?:sql)?', '', raw_sql, flags=re.IGNORECASE)
        # Remove any explanation preceding the SQL.
        cleaned = re.sub(r'^.*?(?=SELECT)', '', text, flags=re.IGNORECASE | re.DOTALL)
        # Prefer the first complete, semicolon-terminated statement.
        match = re.search(r'(SELECT\s.+?;)', cleaned, re.IGNORECASE | re.DOTALL)
        if match is None:
            # No terminating semicolon: take everything from SELECT onwards.
            match = re.search(r'(SELECT\s.+)', cleaned, re.IGNORECASE | re.DOTALL)
        if match:
            # Normalise whitespace and newlines.
            clean_sql = re.sub(r'\s+', ' ', match.group(1)).strip()
            # Guard against "SELECT SELECT ..." duplicates.
            clean_sql = re.sub(r'(SELECT\s+)+', 'SELECT ', clean_sql, flags=re.IGNORECASE)
            return clean_sql
        return raw_sql

    def _build_sql_prompt(self, question: str) -> str:
        """Build the strict SQL-generation prompt: stored DDL, the question and output rules."""
        ddl = self.get_related_ddl(question)
        return (
            "严格按以下要求生成Postgresql查询:\n"
            "\n"
            "表结构:\n"
            f"{ddl}\n"
            "\n"
            f"问题:{question}\n"
            "\n"
            "生成规则:\n"
            "1. 只输出一个标准的SELECT语句\n"
            "2. 绝对不要使用任何代码块标记\n"
            "3. 语句以分号结尾\n"
            "4. 不要包含任何解释或注释\n"
            "5. 若需要多表查询,使用显式JOIN语法\n"
            "6. 确保没有重复的SELECT关键字\n"
        )

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL for ``question`` with one synchronous call to the text-generation API.

        Returns:
            The cleaned SQL string, or ``""`` when the request, HTTP status,
            response parsing or cleaning fails (the error is printed).
        """
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            data = {
                "model": self.model,
                "input": {
                    "messages": [{
                        "role": "user",
                        "content": self._build_sql_prompt(question),
                    }],
                },
                "parameters": {
                    "temperature": 0.1,
                    "max_tokens": 5000,
                    "result_format": "text",
                },
            }
            response = requests.post(
                self.base_url,
                headers=headers,
                json=data,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            # With result_format "text" the generated text is read from output.text.
            raw_sql = response.json()['output']['text']
            return self._clean_sql_output(raw_sql)
        except Exception as e:
            # Deliberate best-effort boundary: report the error and return an empty result.
            print(f"\nAPI请求错误: {str(e)}")
            return ""