You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import re
import uuid
from typing import List, Dict, Any

import requests
from vanna.base import VannaBase

from Config import *
class DeepSeekVanna(VannaBase):
    """Vanna backend that generates PostgreSQL through Alibaba Cloud DashScope.

    Despite the class name, requests go to the DashScope text-generation
    endpoint with the model named by ``QWEN_MODEL_NAME`` and the API key
    ``MODEL_API_KEY`` (both star-imported from ``Config``).

    Training data (DDL, documentation, question/SQL pairs) lives only in the
    in-memory list ``self.training_data``; nothing is persisted or embedded,
    and the "related"/"similar" lookups simply return everything of a type.
    """

    # A fenced markdown code block (```lang\n<body>```); group 1 is the body.
    _FENCED_BLOCK_RE = re.compile(r"```[A-Za-z]*[ \t]*\r?\n(.*?)```", re.DOTALL)
    # From the first standalone SELECT up to the first ';' that is outside a
    # single-quoted literal ('' is an escaped quote), or to the end of text.
    _SELECT_STMT_RE = re.compile(r"\bSELECT\b(?:'(?:[^']|'')*'|[^';])*;?", re.IGNORECASE)
    # Either a single-quoted literal (group 1, kept verbatim) or a whitespace run.
    _WS_OUTSIDE_LITERALS_RE = re.compile(r"('(?:[^']|'')*')|\s+")
    # Two or more consecutive SELECT keywords, to be merged into one.
    _REPEATED_SELECT_RE = re.compile(r"\bSELECT\s+(?:SELECT\s+)+", re.IGNORECASE)

    def __init__(self, timeout: float = 120.0):
        """Set up the DashScope client.

        Args:
            timeout: Timeout in seconds passed to ``requests.post`` (applies to
                connecting and to each read). Without it a stalled connection
                would block forever.
        """
        super().__init__()
        self.api_key = MODEL_API_KEY
        # Alibaba Cloud DashScope-specific endpoint.
        self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
        # Adjust to the actual model name in use.
        self.model = QWEN_MODEL_NAME
        self.timeout = timeout
        self.training_data = []
        self.chat_history = []

    # ---------- Abstract methods required by VannaBase ----------

    def _add_training_item(self, kind: str, **fields: str) -> str:
        """Append one training record tagged with ``kind`` and return its new id."""
        item_id = f"{uuid.uuid4()}-{kind}"
        self.training_data.append({"type": kind, **fields, "id": item_id})
        return item_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns an id accepted by remove_training_data."""
        return self._add_training_item("ddl", content=ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet; returns an id accepted by remove_training_data."""
        return self._add_training_item("documentation", content=doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL example pair; returns an id accepted by remove_training_data."""
        return self._add_training_item("qa", question=question, sql=sql)

    def get_related_ddl(self, question: str, **kwargs) -> str:
        """Return every stored DDL statement joined by newlines (``question`` is unused)."""
        # NOTE(review): VannaBase declares a list return here; this class returns
        # a str because _build_sql_prompt interpolates it directly.
        return "\n".join(item["content"] for item in self.training_data if item["type"] == "ddl")

    def get_related_documentation(self, question: str, **kwargs) -> str:
        """Return every stored documentation snippet joined by newlines (``question`` is unused)."""
        return "\n".join(
            item["content"] for item in self.training_data if item["type"] == "documentation"
        )

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        """Return the live list of stored training records."""
        return self.training_data

    # ---------- Chat methods ----------
    # These return the message dict (as well as recording it in chat_history)
    # so callers can assemble message lists for submit_prompt. Vanna's built-in
    # chat backends behave this way — verify against the installed vanna version.

    def system_message(self, message: str) -> Dict[str, str]:
        """Reset chat_history to a single system message and return that message."""
        msg = {"role": "system", "content": message}
        self.chat_history = [msg]
        return msg

    def user_message(self, message: str) -> Dict[str, str]:
        """Append a user message to chat_history and return it."""
        msg = {"role": "user", "content": message}
        self.chat_history.append(msg)
        return msg

    def assistant_message(self, message: str) -> Dict[str, str]:
        """Append an assistant message to chat_history and return it."""
        msg = {"role": "assistant", "content": message}
        self.chat_history.append(msg)
        return msg

    def submit_prompt(self, prompt: Any, **kwargs) -> str:
        """Send ``prompt`` to the model and return its raw text reply.

        Args:
            prompt: A plain string (sent as one user message) or a sequence of
                ``{"role": ..., "content": ...}`` dicts; non-dict entries are
                skipped.

        Returns:
            The model's reply, or ``""`` if there is nothing to send or the
            request fails (the error is printed).
        """
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = [m for m in (prompt or []) if isinstance(m, dict)]
        if not messages:
            return ""
        return self._complete_or_empty(messages)

    # ---------- Other methods ----------

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training record with the given id.

        Returns:
            True if a record was removed, False if no record has that id.
        """
        for index, item in enumerate(self.training_data):
            if item.get("id") == id:
                del self.training_data[index]
                return True
        return False

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        """No embeddings are used by this backend; always returns an empty list."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        """Return every stored question/SQL pair (no similarity ranking is done)."""
        return [
            {"question": item["question"], "sql": item["sql"]}
            for item in self.training_data
            if item["type"] == "qa"
        ]

    def _request_completion(self, messages: List[Dict[str, str]]) -> str:
        """POST ``messages`` to DashScope and return the reply text.

        Raises:
            requests.RequestException: network error, timeout, or non-2xx status.
            ValueError: the response body is not valid JSON.
            KeyError: the JSON has no ``output.text``.
            TypeError: the JSON shape or ``output.text`` type is unexpected.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "input": {"messages": messages},
            "parameters": {
                "temperature": 0.1,
                # NOTE(review): the allowed max_tokens is model-dependent —
                # confirm 5000 is accepted for QWEN_MODEL_NAME.
                "max_tokens": 5000,
                # With "text", the reply is read from output.text below.
                "result_format": "text",
            },
        }
        response = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout)
        response.raise_for_status()
        text = response.json()["output"]["text"]
        if not isinstance(text, str):
            raise TypeError(f"unexpected output.text type: {type(text).__name__}")
        return text

    def _complete_or_empty(self, messages: List[Dict[str, str]]) -> str:
        """Call _request_completion; on an API/response error print it and return ""."""
        try:
            return self._request_completion(messages)
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            print(f"\nAPI请求错误: {str(e)}")
            return ""

    def _clean_sql_output(self, raw_sql: str) -> str:
        """Extract one normalized SELECT statement from a raw model reply.

        Steps:
          1. If the reply contains a fenced code block, keep only its body;
             any stray ``` markers are removed.
          2. Start at the first standalone word SELECT (case-insensitive),
             dropping any preamble.
          3. End at the first ';' outside a single-quoted literal, or at the
             end of the text if there is none (trailing prose is then kept).
          4. Collapse whitespace runs to one space except inside single-quoted
             literals, and merge consecutive duplicated SELECT keywords.

        If no SELECT is found, the fence-stripped reply is returned trimmed.

        NOTE: only the first SELECT is located, so a ``WITH ... AS (SELECT ...)``
        query is still reduced to its first inner SELECT.
        """
        if not raw_sql:
            return ""
        text = raw_sql
        fenced = self._FENCED_BLOCK_RE.search(text)
        if fenced:
            text = fenced.group(1)
        text = text.replace("```", "")
        statement = self._SELECT_STMT_RE.search(text)
        if statement is None:
            return text.strip()
        sql = self._WS_OUTSIDE_LITERALS_RE.sub(
            lambda m: m.group(1) or " ", statement.group(0)
        ).strip()
        return self._REPEATED_SELECT_RE.sub("SELECT ", sql)

    def _build_sql_prompt(self, question: str) -> str:
        """Build the single-turn prompt sent by generate_sql.

        The prompt always contains the stored DDL, the question, and the
        fixed generation rules. Stored documentation and question/SQL examples
        are added as extra sections only when some exist, so with DDL-only
        training data the prompt holds just the DDL, question and rules.
        """
        sections = [
            "严格按以下要求生成Postgresql查询",
            "表结构:",
            self.get_related_ddl(question),
        ]
        documentation = self.get_related_documentation(question)
        if documentation:
            sections += ["业务文档:", documentation]
        examples = self.get_similar_question_sql(question)
        if examples:
            sections.append("参考示例:")
            sections += [f"示例问题:{ex['question']}\n示例SQL:{ex['sql']}" for ex in examples]
        sections += [
            f"问题:{question}",
            "生成规则:",
            "1. 只输出一个标准的SELECT语句",
            "2. 绝对不要使用任何代码块标记",
            "3. 语句以分号结尾",
            "4. 不要包含任何解释或注释",
            "5. 若需要多表查询使用显式JOIN语法",
            "6. 确保没有重复的SELECT关键字",
        ]
        return "\n".join(sections) + "\n"

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate a PostgreSQL SELECT statement for ``question`` (synchronous).

        Returns:
            The cleaned SQL, or ``""`` if the API request fails or the
            response is malformed (the error is printed).
        """
        messages = [{"role": "user", "content": self._build_sql_prompt(question)}]
        return self._clean_sql_output(self._complete_or_empty(messages))