import re
import uuid
from typing import Any, Dict, List, Union

import requests
from vanna.base import VannaBase

from Config import *

# Markdown code fences: "```" optionally followed by a language tag (e.g. "```sql").
# The tag is only consumed when it ends the line, so "```SELECT ..." keeps its SELECT.
_CODE_FENCE_RE = re.compile(r"```(?:[ \t]*[A-Za-z]+[ \t]*(?=\r?\n))?")
# Start of the first SQL statement: a CTE header ("WITH name AS (") or a SELECT.
# Matching the CTE header keeps "WITH ..." queries intact instead of cutting them
# at their first inner SELECT.
_SQL_START_RE = re.compile(
    r"\bWITH\s+(?:RECURSIVE\s+)?[\w\"]+\s+AS\s*\(|\bSELECT\s",
    re.IGNORECASE,
)
# One statement: quoted literals/identifiers are consumed whole so a ';' inside them
# does not end the statement; an unbalanced quote falls through to the plain-character
# branch. Stops after the first unquoted ';' or at end of text.
_SQL_STATEMENT_RE = re.compile(r"(?:'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"|[^;])+;?")
_WHITESPACE_RE = re.compile(r"\s+")
_REPEATED_SELECT_RE = re.compile(r"\b(?:SELECT\s+)+", re.IGNORECASE)


class VannaUtil(VannaBase):
    """Vanna adapter that generates SQL through an HTTP text-generation endpoint.

    Training data (DDL, documentation, question/SQL pairs) is kept in an in-memory
    list; retrieval methods return everything of the requested type and ignore the
    question. Endpoint, API key and model name come from the ``Config`` module.
    """

    # (connect, read) timeout in seconds for the HTTP call; a long read timeout
    # leaves room for large generations (max_tokens is 5000).
    request_timeout = (10, 300)

    def __init__(self):
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.base_url = MODEL_GENERATION_TEXT_URL  # Alibaba Cloud-specific API URL
        self.model = QWEN_MODEL_NAME  # adjust to the actual model name
        self.training_data = []
        self.chat_history = []

    # ---------- Abstract methods that must be implemented ----------

    def _add_training_item(self, item_type: str, **fields: str) -> str:
        """Append a training item tagged with a fresh id and return that id."""
        item_id = f"{uuid.uuid4()}-{item_type}"
        self.training_data.append({"id": item_id, "type": item_type, **fields})
        return item_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; return its training-data id."""
        return self._add_training_item("ddl", content=ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet; return its training-data id."""
        return self._add_training_item("documentation", content=doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL example pair; return its training-data id."""
        return self._add_training_item("qa", question=question, sql=sql)

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement joined by newlines (question is ignored).

        NOTE(review): VannaBase appears to declare a list return here; this class
        returns a string, which callers in this project may rely on.
        """
        return "\n".join(
            item["content"] for item in self.training_data if item["type"] == "ddl"
        )

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet joined by newlines."""
        return "\n".join(
            item["content"]
            for item in self.training_data
            if item["type"] == "documentation"
        )

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return the (live) list of stored training items."""
        return self.training_data

    # ---------- Chat methods ----------
    # Each returns the message dict it records so message logs assembled from these
    # calls contain real messages rather than None.

    def system_message(self, message: str) -> Dict[str, str]:
        """Start a new chat history with a system message and return that message."""
        msg = {"role": "system", "content": message}
        self.chat_history = [msg]
        return msg

    def user_message(self, message: str) -> Dict[str, str]:
        """Append a user message to the chat history and return it."""
        msg = {"role": "user", "content": message}
        self.chat_history.append(msg)
        return msg

    def assistant_message(self, message: str) -> Dict[str, str]:
        """Append an assistant message to the chat history and return it."""
        msg = {"role": "assistant", "content": message}
        self.chat_history.append(msg)
        return msg

    def submit_prompt(
        self, prompt: Union[str, List[Dict[str, str]]], **kwargs
    ) -> str:
        """Send *prompt* to the model as-is and return the raw reply text.

        *prompt* may be a plain string (sent as a single user message) or a list of
        ``{"role", "content"}`` message dicts; falsy entries are dropped. The reply
        is NOT run through SQL cleaning, so non-SQL requests get the model's answer.
        Returns "" after printing the error if the request fails.
        """
        try:
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]
            else:
                messages = [message for message in prompt if message]
            return self._request_completion(messages)
        except Exception as e:
            print(f"\nAPI请求错误: {str(e)}")
            return ""

    # ---------- Other methods ----------

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Remove the training item whose id is *id*; return True if one was removed."""
        for index, item in enumerate(self.training_data):
            if item.get("id") == id:
                del self.training_data[index]
                return True
        return False

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Embeddings are not used by this in-memory store; always returns []."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Return every stored question/SQL pair (question is ignored)."""
        return [
            {"question": item["question"], "sql": item["sql"]}
            for item in self.training_data
            if item["type"] == "qa"
        ]

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Extract the first SQL statement from model output.

        Removes markdown code fences, drops any text before the first ``WITH ... AS (``
        or ``SELECT``, and keeps everything up to and including the first semicolon
        that is not inside a quoted literal (or to the end of the text when there is
        none). Whitespace runs are collapsed to single spaces and repeated ``SELECT``
        keywords are merged. If no statement start is found, *raw_sql* is returned
        unchanged; empty/None input yields "".
        """
        if not raw_sql:
            return ""
        text = _CODE_FENCE_RE.sub(" ", raw_sql)
        start = _SQL_START_RE.search(text)
        if start is None:
            return raw_sql
        # Always matches: the character at start.start() is "W"/"S", never ';'.
        statement = _SQL_STATEMENT_RE.match(text, start.start())
        clean_sql = _WHITESPACE_RE.sub(" ", statement.group(0)).strip()
        return _REPEATED_SELECT_RE.sub("SELECT ", clean_sql)

    def _build_sql_prompt(self, question: str) -> str:
        """Build the PostgreSQL-generation prompt from all stored DDL and *question*."""
        return (
            "严格按以下要求生成Postgresql查询:\n"
            "表结构:\n"
            f"{self.get_related_ddl(question)}\n"
            f"问题:{question}\n"
            "生成规则:\n"
            "1. 只输出一个标准的SELECT语句\n"
            "2. 绝对不要使用任何代码块标记\n"
            "3. 语句以分号结尾\n"
            "4. 不要包含任何解释或注释\n"
            "5. 若需要多表查询,使用显式JOIN语法\n"
            "6. 确保没有重复的SELECT关键字\n"
        )

    def _request_completion(self, messages: List[Dict[str, str]]) -> str:
        """POST *messages* to ``self.base_url`` and return ``output.text`` of the reply.

        Raises:
            requests.RequestException: on connection errors, timeouts or HTTP errors.
            KeyError / TypeError / ValueError: if the body is not the expected JSON.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model,
            "input": {"messages": messages},
            "parameters": {
                "temperature": 0.1,
                "max_tokens": 5000,
                "result_format": "text",
            },
        }
        # Explicit timeout: without one, a stalled connection blocks forever.
        response = requests.post(
            self.base_url, headers=headers, json=data, timeout=self.request_timeout
        )
        response.raise_for_status()
        return response.json()["output"]["text"]

    def generate_sql(self, question: str, **kwargs) -> str:
        """Synchronously generate one PostgreSQL query for *question*.

        Returns the cleaned SQL, or "" after printing the error if the request or
        response parsing fails.
        """
        try:
            raw_sql = self._request_completion(
                [{"role": "user", "content": self._build_sql_prompt(question)}]
            )
            return self._clean_sql_output(raw_sql)
        except Exception as e:
            print(f"\nAPI请求错误: {str(e)}")
            return ""