You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import requests
from vanna.base import VannaBase
from typing import List, Dict, Any
from Config import *
class QwenChat(VannaBase):
    """Vanna adapter that generates SQL by calling the Qwen (DashScope) HTTP API.

    Training data (DDL, documentation, question/SQL pairs) is kept in an
    in-memory list on the instance. No embeddings are computed and no
    similarity search is performed: every stored DDL statement is sent with
    every SQL-generation request.
    """

    # DashScope text-generation endpoint used for every request.
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
    # Seconds to wait for the API. Without an explicit timeout,
    # requests.post can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 60

    def __init__(self):
        super().__init__()
        self.api_key = MODEL_API_KEY
        self.model = "qwen-max"
        # Dicts tagged with "type": "ddl", "documentation" or "qa".
        self.training_data = []
        # Log of every message built through the *_message helpers.
        # It is only appended to; nothing in this class reads it back.
        self.chat_history = []

    # ---------- Training-data storage ----------

    def add_ddl(self, ddl: str, **kwargs) -> None:
        """Store a DDL statement as training data."""
        self.training_data.append({"type": "ddl", "content": ddl})

    def add_documentation(self, doc: str, **kwargs) -> None:
        """Store a free-text documentation note as training data."""
        self.training_data.append({"type": "documentation", "content": doc})

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        """Store a question/SQL example pair as training data."""
        self.training_data.append({"type": "qa", "question": question, "sql": sql})

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return all stored DDL joined by newlines; ``question`` is not used for filtering."""
        return self._joined_content("ddl")

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return all stored documentation joined by newlines; ``question`` is not used for filtering."""
        return self._joined_content("documentation")

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return the live list of stored training items (not a copy)."""
        return self.training_data

    def _joined_content(self, kind: str) -> str:
        """Join the ``content`` of every training item whose ``type`` equals ``kind``."""
        return "\n".join(
            item["content"] for item in self.training_data if item["type"] == kind
        )

    # ---------- Chat-message helpers ----------

    def system_message(self, message: str) -> Dict[str, str]:
        """Build a system-role message, record it in ``chat_history`` and return it."""
        return self._record_message("system", message)

    def user_message(self, message: str) -> Dict[str, str]:
        """Build a user-role message, record it in ``chat_history`` and return it."""
        return self._record_message("user", message)

    def assistant_message(self, message: str) -> Dict[str, str]:
        """Build an assistant-role message, record it in ``chat_history`` and return it."""
        return self._record_message("assistant", message)

    def _record_message(self, role: str, content: str) -> Dict[str, str]:
        """Append a ``{"role", "content"}`` dict to ``chat_history`` and return it.

        Returning the dict (rather than ``None``) lets callers assemble a
        message list from the return values of the *_message helpers.
        """
        entry = {"role": role, "content": content}
        self.chat_history.append(entry)
        return entry

    def submit_prompt(self, prompt: Any, **kwargs) -> str:
        """Send a prompt to the Qwen API and return the model's text reply.

        Args:
            prompt: Either a list/tuple of ``{"role", "content"}`` message
                dicts, which is sent to the API as-is, or any other value
                (typically a ``str``), which is treated as a question and
                routed through ``generate_sql``.

        Raises:
            RuntimeError: If the API call fails or returns no text.
        """
        if isinstance(prompt, (list, tuple)):
            # A message list is already a complete conversation; wrapping it
            # in the SQL prompt template would embed the list's repr as text.
            return self._request_completion(list(prompt))
        return self.generate_sql(question=prompt)

    # ---------- Remaining interface stubs ----------

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Report success without removing anything (stored items carry no ids)."""
        return True

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """Return an empty embedding; this adapter does not compute embeddings."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Return no examples; stored question/SQL pairs are not searched."""
        return []

    # ---------- SQL generation ----------

    def generate_sql(self, question: str, **kwargs) -> str:
        """Ask Qwen for a MySQL query answering ``question``.

        The request is a single user message containing all stored DDL and
        the question.

        Raises:
            RuntimeError: If the API call fails or returns no text.
        """
        prompt = f"基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。\n\n表结构:\n{self.get_related_ddl(question)}\n\n问题:{question}"
        return self._request_completion([{"role": "user", "content": prompt}])

    def _request_completion(self, messages: List[Dict[str, Any]]) -> str:
        """POST ``messages`` to the DashScope endpoint and return the reply text.

        Raises:
            RuntimeError: On a non-200 status, a body that is not JSON, or a
                JSON body that contains no reply text.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model,
            "input": {"messages": messages},
            "parameters": {},
        }
        response = requests.post(
            self.API_URL, headers=headers, json=data, timeout=self.REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise RuntimeError(f"Qwen API 错误: {response.text}")

        try:
            payload = response.json()
        except ValueError as err:
            # A 200 response whose body is not JSON.
            raise RuntimeError(f"Qwen API 错误: {response.text}") from err

        output = payload.get("output") if isinstance(payload, dict) else None
        output = output if isinstance(output, dict) else {}
        text = output.get("text")
        if text is None:
            # Fallback for the message-style layout
            # (output.choices[0].message.content).
            # NOTE(review): confirm this shape against the DashScope API reference.
            choices = output.get("choices") or []
            if choices and isinstance(choices[0], dict):
                text = (choices[0].get("message") or {}).get("content")
        if text is None:
            raise RuntimeError(f"Qwen API 错误: {response.text}")
        return text
# ----------------------------
# Create the client and run a small end-to-end demo
# ----------------------------
vn = QwenChat()

# Register training data: one table definition and one documentation note.
# The DDL lines stay flush-left so the string value passed to train() is
# exactly the text between the triple quotes.
employees_ddl = """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10, 2),
hire_date DATE
);
"""
vn.train(ddl=employees_ddl)
vn.train(documentation="部门 'Sales' 的员工工资超过 50000")

# Ask for SQL; any failure is printed rather than propagated.
question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000"
try:
    sql = vn.generate_sql(question=question)
    print("生成的 SQL:\n", sql)
except Exception as err:
    print("错误:", str(err))