|
|
|
@ -0,0 +1,110 @@
|
|
|
|
|
import os
|
|
|
|
|
import requests
|
|
|
|
|
from vanna.base import VannaBase
|
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QwenChat(VannaBase):
    """Vanna backend that keeps training data in memory and asks Qwen for SQL.

    Training data lives in a plain list on the instance: nothing is persisted
    and no similarity search is performed, so every stored DDL statement or
    documentation entry is treated as related to every question.
    """

    # DashScope text-generation endpoint used for every model call.
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

    # Seconds to wait for the HTTP call; without a timeout `requests` can
    # block forever on a stalled connection.
    REQUEST_TIMEOUT = 60

    def __init__(self):
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.model = "qwen-max"
        self.training_data = []
        self.chat_history = []  # every message built via the *_message helpers, in order

    # ---------- Required abstract methods ----------
    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as training data."""
        self.training_data.append({"type": "ddl", "content": ddl})

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a free-text documentation entry as training data."""
        self.training_data.append({"type": "documentation", "content": doc})

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL example pair as training data."""
        self.training_data.append({"type": "qa", "question": question, "sql": sql})

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement, newline-joined.

        `question` is not used for filtering.
        """
        return "\n".join(
            [item["content"] for item in self.training_data if item["type"] == "ddl"]
        )

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation entry, newline-joined.

        `question` is not used for filtering.
        """
        return "\n".join(
            [item["content"] for item in self.training_data if item["type"] == "documentation"]
        )

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return the live list of training entries (not a copy)."""
        return self.training_data

    # ---------- Message helpers ----------
    def _record_message(self, role: str, message: str) -> Dict[str, str]:
        """Build a chat message, append it to `chat_history`, and return it."""
        entry = {"role": role, "content": message}
        self.chat_history.append(entry)
        return entry

    def system_message(self, message: str) -> Dict[str, str]:
        """Build a system message (used to initialise a conversation).

        The message is returned as well as recorded, so callers that assemble
        a message list from the return value do not end up with `None`
        entries.
        NOTE(review): VannaBase's prompt builders appear to rely on the return
        value - confirm against the installed vanna version.
        """
        return self._record_message("system", message)

    def user_message(self, message: str) -> Dict[str, str]:
        """Build a user message; it is recorded in `chat_history` and returned."""
        return self._record_message("user", message)

    def assistant_message(self, message: str) -> Dict[str, str]:
        """Build an assistant message; it is recorded in `chat_history` and returned."""
        return self._record_message("assistant", message)

    def submit_prompt(self, prompt: Any, **kwargs) -> str:
        """Send a prompt to the Qwen API as-is and return the model's text.

        Args:
            prompt: Either a plain string (sent as a single user message) or
                an iterable of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The text generated by the model.

        Raises:
            Exception: If the API call fails or the reply contains no text.
        """
        # The prompt must not be wrapped in the SQL-generation template here:
        # it may already be a complete prompt for a different task.
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = list(prompt)
        return self._call_qwen(messages)

    # ---------- Other methods the interface expects ----------
    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Stub: always reports success.

        NOTE: nothing is removed - stored entries carry no id to match against.
        """
        return True

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Stub: embeddings are not used by this backend, so return an empty list."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Stub: no similarity search is implemented, so return an empty list."""
        return []

    # ---------- SQL generation ----------
    def _call_qwen(self, messages: List[Dict[str, Any]]) -> str:
        """POST `messages` to the DashScope endpoint and return the reply text.

        Raises:
            Exception: On a non-200 status, or when the response body holds no
                generated text in either supported layout.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": {"messages": messages},
            "parameters": {},
        }

        response = requests.post(
            self.API_URL,
            headers=headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            raise Exception(f"Qwen API 错误: {response.text}")

        output = response.json().get("output") or {}
        text = output.get("text")
        if text is None:
            # Fall back to the chat-style layout ("choices" -> "message" ->
            # "content") in case the service answers in message format.
            choices = output.get("choices") or []
            if choices:
                text = (choices[0].get("message") or {}).get("content")
        if text is None:
            raise Exception(f"Qwen API 错误: {response.text}")
        return text

    def generate_sql(self, question: str, **kwargs) -> str:
        """Ask Qwen for a MySQL query that answers `question`.

        The prompt contains every stored DDL statement and, when any exists,
        every stored documentation entry, followed by the question.

        Returns:
            The raw text returned by the model (expected to be SQL).

        Raises:
            Exception: If the API call fails or the reply contains no text.
        """
        prompt = (
            "基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。"
            f"\n\n表结构:\n{self.get_related_ddl(question)}"
        )
        # Documentation added through training would otherwise never reach the
        # model; the section is omitted entirely when there is none.
        documentation = self.get_related_documentation(question)
        if documentation:
            prompt += f"\n\n补充说明:\n{documentation}"
        prompt += f"\n\n问题:{question}"

        return self._call_qwen([{"role": "user", "content": prompt}])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----------------------------
# Create the shared instance and load its training data
# ----------------------------
vn = QwenChat()

# Register training data; it is only kept in memory on the instance.
vn.train(
    ddl="""
    CREATE TABLE employees (
        id INT PRIMARY KEY,
        name VARCHAR(100),
        department VARCHAR(50),
        salary DECIMAL(10, 2),
        hire_date DATE
    );
    """
)

vn.train(documentation="部门 'Sales' 的员工工资超过 50000")

# Run the demo query only when executed as a script, so that importing this
# module (e.g. to reuse `vn`) does not send a network request.
if __name__ == "__main__":
    question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000"
    try:
        sql = vn.generate_sql(question=question)
        print("生成的 SQL:\n", sql)
    except Exception as e:  # top-level boundary: report the failure instead of a traceback
        print("错误:", str(e))
|