import os
from typing import List, Dict, Any

import requests
from vanna.base import VannaBase

from Config import *  # expected to provide MODEL_API_KEY (used in QwenChat.__init__)


class QwenChat(VannaBase):
    """Minimal vanna backend that asks Alibaba DashScope (Qwen) to write SQL.

    Training data (DDL, documentation, question/SQL pairs) is kept in a plain
    in-memory list; nothing is persisted and no embeddings are computed.
    """

    # DashScope text-generation endpoint used by generate_sql().
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

    # Seconds to wait for the HTTP call before giving up. Without a timeout
    # `requests` waits indefinitely on a stalled connection.
    DEFAULT_TIMEOUT = 60

    def __init__(self):
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.model = "qwen-max"
        self.training_data = []
        self.chat_history = []  # Added: stores the conversation history

    # ---------- Abstract methods that must be implemented ----------
    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as training data."""
        self.training_data.append({"type": "ddl", "content": ddl})

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a free-text documentation snippet as training data."""
        self.training_data.append({"type": "documentation", "content": doc})

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL example pair as training data."""
        self.training_data.append({"type": "qa", "question": question, "sql": sql})

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement joined by newlines.

        `question` is not used for filtering: all DDL is always returned.
        """
        return "\n".join(
            [item["content"] for item in self.training_data if item["type"] == "ddl"]
        )

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet joined by newlines.

        `question` is not used for filtering: all documentation is returned.
        """
        return "\n".join(
            [
                item["content"]
                for item in self.training_data
                if item["type"] == "documentation"
            ]
        )

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return the internal training-data list itself (not a copy)."""
        return self.training_data

    # ---------- Additional required methods ----------
    # NOTE(review): these helpers record the message in chat_history and
    # return None. Confirm whether VannaBase callers expect the message dict
    # to be returned instead.
    def system_message(self, message: str) -> None:
        """Record a system message (used to initialise the conversation)."""
        self.chat_history.append({"role": "system", "content": message})

    def user_message(self, message: str) -> None:
        """Record a user message."""
        self.chat_history.append({"role": "user", "content": message})

    def assistant_message(self, message: str) -> None:
        """Record an assistant message."""
        self.chat_history.append({"role": "assistant", "content": message})

    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt by delegating to generate_sql().

        The prompt is passed as the `question`, so it is wrapped in the
        SQL-generation instructions before being sent to the Qwen API.
        """
        return self.generate_sql(question=prompt)

    # ---------- Other methods that need an implementation ----------
    # NOTE(review): the three methods below are placeholders; they do not
    # touch training_data or call any service.
    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Placeholder: always reports success without removing anything."""
        return True

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Placeholder: always returns an empty embedding."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Placeholder: always returns no similar question/SQL pairs."""
        return []

    # ---------- Original SQL-generation method ----------
    def generate_sql(self, question: str, **kwargs) -> str:
        """Ask the Qwen model to write SQL for `question`.

        The prompt contains all stored DDL plus the question; stored
        documentation and question/SQL pairs are not included.

        Args:
            question: Natural-language question to translate into SQL.
            **kwargs: Optional ``timeout`` (seconds, or a ``(connect, read)``
                tuple) overriding ``DEFAULT_TIMEOUT``. Other keys are ignored.

        Returns:
            The text returned by the model.

        Raises:
            Exception: If the API answers with a non-200 status or the
                response body contains no generated text.
            requests.exceptions.RequestException: On network failures,
                including a timeout.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        prompt = (
            "基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。\n\n"
            f"表结构:\n{self.get_related_ddl(question)}\n\n问题:{question}"
        )
        data = {
            "model": self.model,
            "input": {"messages": [{"role": "user", "content": prompt}]},
            "parameters": {},
        }

        timeout = kwargs.get("timeout", self.DEFAULT_TIMEOUT)
        response = requests.post(
            self.API_URL, headers=headers, json=data, timeout=timeout
        )
        if response.status_code != 200:
            raise Exception(f"Qwen API 错误: {response.text}")

        output = response.json().get("output") or {}
        text = output.get("text")
        if text is None:
            # Fall back to the "message" result format, where the reply is
            # under output.choices[0].message.content.
            choices = output.get("choices") or []
            if choices:
                text = (choices[0].get("message") or {}).get("content")
        if text is None:
            # A 200 response without generated text is still a failure;
            # report it like any other API error instead of a bare KeyError.
            raise Exception(f"Qwen API 错误: {response.text}")
        return text


# ----------------------------
# Create and use an instance
# ----------------------------
# NOTE(review): everything below runs at import time (including a network
# call). Kept at module level so that `vn` stays importable as before;
# consider an `if __name__ == "__main__":` guard if no importer relies on it.
vn = QwenChat()

# Add training data
vn.train(
    ddl=(
        " CREATE TABLE employees ( id INT PRIMARY KEY, name VARCHAR(100),"
        " department VARCHAR(50), salary DECIMAL(10, 2), hire_date DATE ); "
    )
)
vn.train(documentation="部门 'Sales' 的员工工资超过 50000")

# Generate SQL
question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000"
try:
    sql = vn.generate_sql(question=question)
    print("生成的 SQL:\n", sql)
except Exception as e:
    print("错误:", str(e))