import re
import sqlite3
from typing import Any, Dict, List

import requests
from vanna.base import VannaBase

from Config import *


class VannaUtil(VannaBase):
    """Vanna adapter backed by a SQL training-data table and an HTTP LLM API.

    Training data (DDL, documentation, question/SQL pairs) is stored in a
    ``training_data`` table in either SQLite or PostgreSQL. SQL generation
    posts a prompt built from the stored DDL to ``MODEL_GENERATION_TEXT_URL``
    and cleans the returned text down to a single SELECT statement.

    ``MODEL_API_KEY``, ``MODEL_GENERATION_TEXT_URL`` and ``QWEN_MODEL_NAME``
    are expected to come from the ``Config`` star import.
    """

    # (connect, read) timeouts in seconds for the generation request, so a
    # stalled endpoint cannot block the caller forever.
    REQUEST_TIMEOUT = (10, 120)

    def __init__(self, db_type='sqlite', db_uri=None):
        """Configure the API client and open the training-data store.

        Args:
            db_type: ``'sqlite'`` or ``'postgres'``.
            db_uri: SQLite file path or PostgreSQL connection string;
                defaults to ``'Db/vanna.db'``.

        Raises:
            ValueError: If ``db_type`` is not supported.
        """
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.base_url = MODEL_GENERATION_TEXT_URL  # Alibaba Cloud-specific API endpoint
        self.model = QWEN_MODEL_NAME  # adjust to the actual model name
        self.training_data = []
        self.chat_history = []
        self.db_type = db_type
        self.db_uri = db_uri or 'Db/vanna.db'  # defaults to SQLite
        self._init_db()

    # ------------------------------------------------------------------
    # Storage helpers
    # ------------------------------------------------------------------
    def _init_db(self):
        """Open the database connection and make sure the table exists."""
        if self.db_type == 'sqlite':
            self.conn = sqlite3.connect(self.db_uri)
        elif self.db_type == 'postgres':
            # Imported lazily so SQLite-only deployments do not need psycopg2.
            import psycopg2
            self.conn = psycopg2.connect(self.db_uri)
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")
        self._create_tables()

    def _create_tables(self):
        """Create the ``training_data`` table if it does not exist yet."""
        # Only the auto-increment primary key differs between the backends.
        if self.db_type == 'postgres':
            id_column = 'id SERIAL PRIMARY KEY'
        else:
            id_column = 'id INTEGER PRIMARY KEY AUTOINCREMENT'
        self._execute_write(f'''
            CREATE TABLE IF NOT EXISTS training_data (
                {id_column},
                type TEXT,
                question TEXT,
                sql TEXT,
                content TEXT
            )
        ''')

    def _adapt_placeholders(self, query: str) -> str:
        """Convert ``?`` placeholders to the active driver's parameter style.

        Queries in this class are written with sqlite3's ``?`` style;
        psycopg2 only understands ``%s``.
        """
        if self.db_type == 'postgres':
            return query.replace('?', '%s')
        return query

    def _execute_write(self, query: str, params: tuple = ()) -> int:
        """Run a modifying statement, commit it, and return ``rowcount``.

        The transaction is rolled back on failure so the connection stays
        usable, and the cursor is always closed.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute(self._adapt_placeholders(query), params)
            self.conn.commit()
            return cursor.rowcount
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def _fetch_all(self, query: str, params: tuple = ()):
        """Run a query and return ``(column_names, rows)``."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(self._adapt_placeholders(query), params)
            columns = [column[0] for column in cursor.description]
            return columns, cursor.fetchall()
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()

    def _get_content_by_type(self, data_type: str) -> str:
        """Return all ``content`` values of one type, joined by newlines."""
        _, rows = self._fetch_all(
            'SELECT content FROM training_data WHERE type = ?', (data_type,)
        )
        # Skip NULL content so str.join cannot fail on None.
        return "\n".join(row[0] for row in rows if row[0] is not None)

    # ------------------------------------------------------------------
    # Training data API
    # ------------------------------------------------------------------
    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as training data."""
        self._execute_write(
            'INSERT INTO training_data (type, content) VALUES (?, ?)',
            ('ddl', ddl),
        )

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a documentation snippet as training data."""
        self._execute_write(
            'INSERT INTO training_data (type, content) VALUES (?, ?)',
            ('documentation', doc),
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL pair as training data."""
        self._execute_write(
            'INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)',
            ('qa', question, sql),
        )

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement, newline-joined.

        ``question`` is accepted for interface compatibility but is not used
        to filter: all DDL rows are returned.
        """
        return self._get_content_by_type('ddl')

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet, newline-joined.

        ``question`` is accepted for interface compatibility but is not used
        to filter: all documentation rows are returned.
        """
        return self._get_content_by_type('documentation')

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return all training rows as dicts keyed by column name."""
        columns, rows = self._fetch_all('SELECT * FROM training_data')
        return [dict(zip(columns, row)) for row in rows]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete one training row by id; return True if a row was removed."""
        deleted = self._execute_write(
            'DELETE FROM training_data WHERE id = ?', (id,)
        )
        return deleted > 0

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Return an empty embedding; embeddings are not implemented."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Return no similar pairs; similarity search is not implemented."""
        return []

    # ------------------------------------------------------------------
    # Chat history
    # ------------------------------------------------------------------
    def system_message(self, message: str) -> None:
        """Reset the chat history to a single system message."""
        self.chat_history = [{"role": "system", "content": message}]

    def user_message(self, message: str) -> None:
        """Append a user message to the chat history."""
        self.chat_history.append({"role": "user", "content": message})

    def assistant_message(self, message: str) -> None:
        """Append an assistant message to the chat history."""
        self.chat_history.append({"role": "assistant", "content": message})

    # ------------------------------------------------------------------
    # SQL generation
    # ------------------------------------------------------------------
    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt by treating it as the question for ``generate_sql``."""
        return self.generate_sql(question=prompt)

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Reduce raw model output to a single normalized SELECT statement.

        Takes the text from the first ``SELECT`` keyword up to the first
        semicolon (or the end of the text when the model omitted it),
        collapses whitespace, merges accidentally repeated ``SELECT``
        keywords and guarantees a trailing semicolon.

        Returns:
            The cleaned statement; ``raw_sql`` unchanged when it contains no
            SELECT statement; ``""`` for empty input.
        """
        if not raw_sql:
            return ""

        # Drop Markdown fence markers so a trailing ``` cannot end up inside
        # a statement that has no terminating semicolon.
        unfenced = raw_sql.replace('```', '')

        # \b keeps identifiers such as "preselect" from matching as keyword.
        match = re.search(
            r'\bSELECT\s.+?(?:;|\Z)', unfenced, re.IGNORECASE | re.DOTALL
        )
        if not match:
            return raw_sql

        # Normalize whitespace and line breaks.
        clean_sql = re.sub(r'\s+', ' ', match.group(0)).strip()
        # Make sure there is no repeated SELECT keyword.
        clean_sql = re.sub(
            r'\b(?:SELECT\s+)+', 'SELECT ', clean_sql, flags=re.IGNORECASE
        )
        if not clean_sql.endswith(';'):
            clean_sql += ';'
        return clean_sql

    def _build_sql_prompt(self, question: str) -> str:
        """Build the generation prompt from the stored DDL and the question."""
        return f"""严格按以下要求生成Postgresql查询:
表结构:
{self.get_related_ddl(question)}

问题:{question}

生成规则:
1. 只输出一个标准的SELECT语句
2. 绝对不要使用任何代码块标记
3. 语句以分号结尾
4. 不要包含任何解释或注释
5. 若需要多表查询,使用显式JOIN语法
6. 确保没有重复的SELECT关键字
"""

    def generate_sql(self, question: str, **kwargs) -> str:
        """Synchronously generate SQL for ``question``.

        Returns:
            The cleaned SQL statement, or ``""`` if building the prompt, the
            HTTP request, or parsing the response fails (the error is
            printed rather than raised).
        """
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }
            payload = {
                "model": self.model,
                "input": {
                    "messages": [{
                        "role": "user",
                        "content": self._build_sql_prompt(question),
                    }]
                },
                "parameters": {
                    "temperature": 0.1,
                    "max_tokens": 5000,
                    "result_format": "text",
                },
            }
            response = requests.post(
                self.base_url,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            raw_sql = response.json()['output']['text']
            return self._clean_sql_output(raw_sql)
        except Exception as e:
            # Deliberate best-effort boundary: callers receive "" on any
            # failure instead of an exception.
            print(f"\nAPI请求错误: {str(e)}")
            return ""