You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

170 lines
6.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import re
from typing import List, Dict, Any
import requests
import psycopg2
from vanna.base import VannaBase
from Config import *
class VannaUtil(VannaBase):
    """Vanna adapter backed by PostgreSQL storage and an HTTP text-generation API.

    Training data (DDL, documentation, question/SQL pairs) is stored in a
    single ``training_data`` table. SQL is generated by POSTing a prompt to
    ``MODEL_GENERATION_TEXT_URL`` with the model named by ``QWEN_MODEL_NAME``
    and cleaning the returned text down to one ``SELECT`` statement.

    Every database statement runs in its own transaction: it is committed on
    success and rolled back on failure, and its cursor is always closed.
    """

    # (connect, read) timeout in seconds for the generation request, so a
    # stalled endpoint cannot block the caller indefinitely.
    REQUEST_TIMEOUT = (10, 120)

    # First "SELECT ... ;" statement. The word boundary keeps words that
    # merely contain "select" (e.g. "preselect") from being treated as SQL.
    _SELECT_STATEMENT_RE = re.compile(r'\bSELECT\s.+?;', re.IGNORECASE | re.DOTALL)
    _WHITESPACE_RE = re.compile(r'\s+')
    # One or more consecutive SELECT keywords ("SELECT SELECT ...").
    _REPEATED_SELECT_RE = re.compile(r'(?:\bSELECT\s+)+', re.IGNORECASE)

    def __init__(self, db_uri=None):
        """Configure the model endpoint and open the PostgreSQL connection.

        Args:
            db_uri: PostgreSQL connection string. Falls back to
                ``VANNA_POSTGRESQL_URI`` when omitted or empty.
        """
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.base_url = MODEL_GENERATION_TEXT_URL  # Alibaba Cloud specific API endpoint
        self.model = QWEN_MODEL_NAME  # adjust to the actual model name
        self.training_data = []
        self.chat_history = []
        self.db_uri = db_uri or VANNA_POSTGRESQL_URI  # default PostgreSQL connection string
        self._init_db()

    def _init_db(self):
        """Open the PostgreSQL connection and make sure the table exists."""
        self.conn = psycopg2.connect(self.db_uri)
        self._create_tables()

    def close(self) -> None:
        """Close the PostgreSQL connection if it is still open."""
        conn = getattr(self, "conn", None)
        if conn is not None and not conn.closed:
            conn.close()

    def _execute_write(self, query: str, params=None) -> int:
        """Run one modifying statement in its own transaction.

        Returns:
            The cursor's ``rowcount`` after execution.
        """
        # ``with self.conn`` commits on success and rolls back on error;
        # ``with ...cursor()`` closes the cursor in both cases.
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute(query, params)
            return cursor.rowcount

    def _fetch_all(self, query: str, params=None) -> List[Dict[str, Any]]:
        """Run one query and return every row as a ``{column: value}`` dict."""
        # The transaction is ended here too, so reads do not leave the
        # connection idle inside an open transaction.
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute(query, params)
            columns = [column[0] for column in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def _get_content_by_type(self, data_type: str) -> str:
        """Join the ``content`` of all rows of ``data_type`` with newlines."""
        rows = self._fetch_all(
            'SELECT content FROM training_data WHERE type = %s', (data_type,)
        )
        # Skip NULL content so str.join cannot fail on None.
        return "\n".join(
            row["content"] for row in rows if row["content"] is not None
        )

    def _create_tables(self):
        """Create the training-data table if it does not exist yet."""
        self._execute_write('''
            CREATE TABLE IF NOT EXISTS training_data (
                id SERIAL PRIMARY KEY,
                type TEXT,
                question TEXT,
                sql TEXT,
                content TEXT
            )
        ''')

    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as a training row of type ``'ddl'``."""
        self._execute_write(
            'INSERT INTO training_data (type, content) VALUES (%s, %s)',
            ('ddl', ddl),
        )

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a documentation text as a row of type ``'documentation'``."""
        self._execute_write(
            'INSERT INTO training_data (type, content) VALUES (%s, %s)',
            ('documentation', doc),
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL pair as a training row of type ``'qa'``."""
        self._execute_write(
            'INSERT INTO training_data (type, question, sql) VALUES (%s, %s, %s)',
            ('qa', question, sql),
        )

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return all stored DDL joined by newlines.

        ``question`` is accepted for interface compatibility but is not used
        to filter: every DDL row is returned.
        """
        return self._get_content_by_type('ddl')

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return all stored documentation joined by newlines.

        ``question`` is accepted for interface compatibility but is not used
        to filter: every documentation row is returned.
        """
        return self._get_content_by_type('documentation')

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return every training row as a dict keyed by column name."""
        return self._fetch_all('SELECT * FROM training_data')

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training row with the given id.

        Returns:
            True if a row was deleted, False if no row had that id.
        """
        deleted = self._execute_write(
            'DELETE FROM training_data WHERE id = %s', (id,)
        )
        return deleted > 0

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Return an empty embedding; embeddings are not implemented here."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Return no similar pairs; similarity search is not implemented here."""
        return []

    def system_message(self, message: str) -> None:
        """Reset the chat history to a single system message."""
        self.chat_history = [{"role": "system", "content": message}]

    def user_message(self, message: str) -> None:
        """Append a user message to the chat history."""
        self.chat_history.append({"role": "user", "content": message})

    def assistant_message(self, message: str) -> None:
        """Append an assistant message to the chat history."""
        self.chat_history.append({"role": "assistant", "content": message})

    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt by treating it as a question for ``generate_sql``.

        NOTE(review): the prompt is wrapped in the SQL-generation template and
        the reply is cleaned as SQL, so non-SQL prompts are presumably not
        handled well by this path -- confirm against the callers.
        """
        return self.generate_sql(question=prompt)

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Extract and normalise the first ``SELECT ...;`` statement.

        Args:
            raw_sql: Raw model output, possibly with surrounding text.

        Returns:
            The first statement from a standalone ``SELECT`` keyword up to the
            first semicolon, with whitespace runs collapsed to single spaces
            and repeated ``SELECT`` keywords collapsed to one. If no such
            statement is found, ``raw_sql`` is returned unchanged; empty or
            ``None`` input yields ``""``.
        """
        if not raw_sql:
            return ""
        match = self._SELECT_STATEMENT_RE.search(raw_sql)
        if match is None:
            return raw_sql
        # Normalise spaces and newlines.
        clean_sql = self._WHITESPACE_RE.sub(' ', match.group(0)).strip()
        # Make sure the SELECT keyword is not duplicated.
        return self._REPEATED_SELECT_RE.sub('SELECT ', clean_sql)

    def _build_sql_prompt(self, question: str) -> str:
        """Build the generation prompt from the stored DDL and the question.

        The prompt text is sent to the model verbatim, so it stays in its
        original language.
        """
        return (
            "严格按以下要求生成Postgresql查询\n"
            "表结构:\n"
            f"{self.get_related_ddl(question)}\n"
            f"问题:{question}\n"
            "生成规则:\n"
            "1. 只输出一个标准的SELECT语句\n"
            "2. 绝对不要使用任何代码块标记\n"
            "3. 语句以分号结尾\n"
            "4. 不要包含任何解释或注释\n"
            "5. 若需要多表查询使用显式JOIN语法\n"
            "6. 确保没有重复的SELECT关键字\n"
        )

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL for ``question`` with one synchronous API request.

        Args:
            question: Natural-language question to translate into SQL.

        Returns:
            The cleaned SQL statement, or ``""`` if anything fails (building
            the prompt, the HTTP request, or parsing the response). The error
            is printed rather than raised.
        """
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            data = {
                "model": self.model,
                "input": {
                    "messages": [{
                        "role": "user",
                        "content": self._build_sql_prompt(question),
                    }]
                },
                "parameters": {
                    "temperature": 0.1,
                    "max_tokens": 5000,
                    "result_format": "text",
                },
            }
            # Without a timeout, requests waits forever on a stalled server.
            response = requests.post(
                self.base_url,
                headers=headers,
                json=data,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            raw_sql = response.json()['output']['text']
            return self._clean_sql_output(raw_sql)
        except Exception as e:
            # Deliberately best-effort: callers receive "" on any failure.
            print(f"\nAPI请求错误: {e}")
            return ""