main
黄海 5 months ago
parent 498971b9f3
commit 30eaeccea7

@ -5,7 +5,7 @@ from pathlib import Path
MODEL_API_KEY = "sk-01d13a39e09844038322108ecdbd1bbc"
MODEL_API_URL= 'https://dashscope.aliyuncs.com/compatible-mode/v1'
MODEL_NAME = "deepseek-v3"
#MODEL_NAME='qwen-plus'
QWEN_MODEL_NAME='qwen-plus'
# 华为云

@ -0,0 +1,61 @@
CREATE TABLE "public"."t_crawler_book" (
"book_id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL,
"book_name" varchar(255) COLLATE "pg_catalog"."default",
"scheme_id" varchar(255) COLLATE "pg_catalog"."default",
"subject_id" varchar(255) COLLATE "pg_catalog"."default",
"stage_id" varchar(255) COLLATE "pg_catalog"."default",
"id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL
)
;
COMMENT ON COLUMN "public"."t_crawler_book"."book_id" IS '册ID';
COMMENT ON COLUMN "public"."t_crawler_book"."book_name" IS '册的名称';
COMMENT ON COLUMN "public"."t_crawler_book"."scheme_id" IS '版本ID比如人教版本ID,与t_crawler_scheme表中的scheme_id关联';
COMMENT ON COLUMN "public"."t_crawler_book"."subject_id" IS '科目id,与t_crawler_subject中的subject_id关联';
COMMENT ON COLUMN "public"."t_crawler_book"."stage_id" IS '学段ID与t_crawler_stage中的stage_id关联';
COMMENT ON COLUMN "public"."t_crawler_book"."id" IS '主键,无实际意义';
COMMENT ON TABLE "public"."t_crawler_book" IS '课程章节目录中的册概念,比如三年级上册,四年级下册';
DROP TABLE IF EXISTS "public"."t_crawler_scheme";
CREATE TABLE "public"."t_crawler_scheme" (
"scheme_id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL,
"scheme_name" varchar(255) COLLATE "pg_catalog"."default",
"subject_id" varchar(255) COLLATE "pg_catalog"."default",
"stage_id" varchar(255) COLLATE "pg_catalog"."default",
"id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL
)
;
COMMENT ON COLUMN "public"."t_crawler_scheme"."scheme_id" IS '版本ID';
COMMENT ON COLUMN "public"."t_crawler_scheme"."scheme_name" IS '版本名称';
COMMENT ON COLUMN "public"."t_crawler_scheme"."subject_id" IS '学科ID';
COMMENT ON COLUMN "public"."t_crawler_scheme"."stage_id" IS '学段ID';
COMMENT ON COLUMN "public"."t_crawler_scheme"."id" IS '主键';
COMMENT ON TABLE "public"."t_crawler_scheme" IS '教材版本,目前一般一个学科一个版本';
DROP TABLE IF EXISTS "public"."t_crawler_stage";
CREATE TABLE "public"."t_crawler_stage" (
"stage_id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL,
"stage_name" varchar(255) COLLATE "pg_catalog"."default"
)
;
COMMENT ON COLUMN "public"."t_crawler_stage"."stage_id" IS '学段ID';
COMMENT ON COLUMN "public"."t_crawler_stage"."stage_name" IS '学段名称';
COMMENT ON TABLE "public"."t_crawler_stage" IS '学段表';
DROP TABLE IF EXISTS "public"."t_crawler_subject";
CREATE TABLE "public"."t_crawler_subject" (
"subject_id" varchar(255) COLLATE "pg_catalog"."default" NOT NULL,
"subject_name" varchar(255) COLLATE "pg_catalog"."default",
"stage_id" varchar(255) COLLATE "pg_catalog"."default"
)
;
COMMENT ON COLUMN "public"."t_crawler_subject"."subject_id" IS '科目ID';
COMMENT ON COLUMN "public"."t_crawler_subject"."subject_name" IS '科目名称';
COMMENT ON COLUMN "public"."t_crawler_subject"."stage_id" IS '学段ID';
COMMENT ON TABLE "public"."t_crawler_subject" IS '学科表';
ALTER TABLE "public"."t_crawler_book" ADD CONSTRAINT "t_crawler_booke_pkey" PRIMARY KEY ("id");
ALTER TABLE "public"."t_crawler_scheme" ADD CONSTRAINT "t_crawler_scheme_pkey" PRIMARY KEY ("id");
ALTER TABLE "public"."t_crawler_subject" ADD CONSTRAINT "t_crawler_subject_pkey" PRIMARY KEY ("subject_id");

@ -0,0 +1,89 @@
import psycopg2
from psycopg2 import OperationalError
from psycopg2.extras import RealDictCursor
import json
from datetime import date, datetime
class PostgreSQLUtil:
def __init__(self, host="10.10.14.71", port=5432,
dbname="szjz_db", user="postgres", password="DsideaL147258369"):
self.conn_params = {
"host": host,
"port": port,
"dbname": dbname,
"user": user,
"password": password
}
self.connection = None
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def connect(self):
try:
self.connection = psycopg2.connect(**self.conn_params)
print("成功连接到PostgreSQL数据库")
except OperationalError as e:
print(f"连接错误: {e}")
raise
def close(self):
if self.connection:
self.connection.close()
print("数据库连接已关闭")
def execute_query(self, sql, params=None, return_dict=True):
"""执行查询并返回结果"""
try:
with self.connection.cursor(
cursor_factory=RealDictCursor if return_dict else None
) as cursor:
cursor.execute(sql, params)
if cursor.description:
columns = [desc[0] for desc in cursor.description]
results = cursor.fetchall()
# 转换字典格式
if return_dict:
return results
else:
return [dict(zip(columns, row)) for row in results]
else:
return {"rowcount": cursor.rowcount}
except Exception as e:
print(f"执行SQL出错: {e}")
self.connection.rollback()
raise
finally:
self.connection.commit()
def query_to_json(self, sql, params=None):
"""返回JSON格式结果"""
data = self.execute_query(sql, params)
return json.dumps(data, default=self.json_serializer)
@staticmethod
def json_serializer(obj):
"""处理JSON无法序列化的类型"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
# 使用示例
if __name__ == "__main__":
with PostgreSQLUtil() as db:
# 示例查询
result = db.execute_query("SELECT version()")
print("数据库版本:", result)
# 返回JSON
json_data = db.query_to_json("SELECT * FROM t_base_class LIMIT 2")
print("JSON结果:", json_data)

@ -0,0 +1,69 @@
import pandas as pd
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.utils import get_column_letter
def save_to_excel(data, filename):
"""
将数据集保存为格式化的Excel文件
参数
data - 数据集 (列表字典格式例如[{"列1": "值1", "列2": "值2"}, ...])
filename - 输出文件名 (需包含.xlsx扩展名)
"""
# 转换数据为DataFrame
df = pd.DataFrame(data)
# 创建Excel写入对象
with pd.ExcelWriter(filename, engine='openpyxl') as writer:
df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 定义边框样式
thin_border = Border(left=Side(style='thin'),
right=Side(style='thin'),
top=Side(style='thin'),
bottom=Side(style='thin'))
# 设置全局行高
for row in worksheet.iter_rows():
worksheet.row_dimensions[row[0].row].height = 20
# 为所有单元格添加边框
for cell in row:
cell.border = thin_border
# 设置标题样式
header_font = Font(bold=True, size=14)
header_fill = PatternFill(start_color='ADD8E6', end_color='ADD8E6', fill_type='solid')
for cell in worksheet[1]:
cell.font = header_font
cell.fill = header_fill
cell.alignment = Alignment(horizontal='center', vertical='center')
# 设置数据行样式
data_font = Font(size=14)
for row in worksheet.iter_rows(min_row=2):
for cell in row:
cell.font = data_font
cell.alignment = Alignment(vertical='center', wrap_text=True)
# 设置列宽第一列固定40其他自动调整
for idx, column in enumerate(worksheet.columns):
column_letter = get_column_letter(idx + 1)
worksheet.column_dimensions[column_letter].width = 60
max_length = 0
for cell in column:
try:
value_len = len(str(cell.value))
if value_len > max_length:
max_length = value_len
except:
pass
adjusted_width = (max_length * 1.2 + 5)
worksheet.column_dimensions[column_letter].width = adjusted_width

@ -0,0 +1,39 @@
import os
import platform
from Text2SqlUtil import *
from Text2Sql.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.SaveToExcel import save_to_excel
if __name__ == "__main__":
vn = DeepSeekVanna()
# 打开CreateTable.sql文件内容
with open("CreateTable.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 训练数据
vn.train(
ddl=ddl
)
# 流式生成演示
question = "查询小学语文有哪些册,返回册信息的所有相关属性。"
# 获取完整 SQL
sql = vn.generate_sql(question)
print("最终 SQL:\n", sql)
# 执行SQL查询
with PostgreSQLUtil() as db:
sample_data = db.execute_query(sql)
filename = "导出信息.xlsx"
save_to_excel(sample_data, filename)
# 在代码最后添加自动打开逻辑
if platform.system() == "Windows":
try:
full_path = os.path.abspath(filename)
print(f"\n✅ 文件已保存到:{full_path}")
os.startfile(full_path) # 关键代码
except Exception as e:
print(f"\n⚠️ 自动打开失败: {str(e)},请手动打开文件")
else:
print("\n⚠️ 非Windows系统请手动打开文件")

@ -0,0 +1,123 @@
import re
from typing import List, Dict, Any
import requests
from vanna.base import VannaBase
from Config import *
class DeepSeekVanna(VannaBase):
def __init__(self):
super().__init__()
self.api_key = MODEL_API_KEY
self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" # 阿里云专用API地址
self.model = QWEN_MODEL_NAME # 根据实际模型名称调整
self.training_data = []
self.chat_history = []
# ---------- 必须实现的抽象方法 ----------
def add_ddl(self, ddl: str, **kwargs) -> None:
self.training_data.append({"type": "ddl", "content": ddl})
def add_documentation(self, doc: str, **kwargs) -> None:
self.training_data.append({"type": "documentation", "content": doc})
def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
self.training_data.append({"type": "qa", "question": question, "sql": sql})
def get_related_ddl(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"])
def get_related_documentation(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"])
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
return self.training_data
# ---------- 对话方法 ----------
def system_message(self, message: str) -> None:
self.chat_history = [{"role": "system", "content": message}]
def user_message(self, message: str) -> None:
self.chat_history.append({"role": "user", "content": message})
def assistant_message(self, message: str) -> None:
self.chat_history.append({"role": "assistant", "content": message})
def submit_prompt(self, prompt: str, **kwargs) -> str:
return self.generate_sql(question=prompt)
# ---------- 其他方法 ----------
def remove_training_data(self, id: str, **kwargs) -> bool:
return True
def generate_embedding(self, text: str, **kwargs) -> List[float]:
return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
return []
def _clean_sql_output(self, raw_sql: str) -> str:
"""增强版清洗逻辑"""
# 移除所有非SQL内容
cleaned = re.sub(r'^.*?(?=SELECT)', '', raw_sql, flags=re.IGNORECASE|re.DOTALL)
# 提取第一个完整SQL语句
match = re.search(r'(SELECT\s.+?;)', cleaned, re.IGNORECASE|re.DOTALL)
if match:
# 标准化空格和换行
clean_sql = re.sub(r'\s+', ' ', match.group(1)).strip()
# 确保没有重复SELECT
clean_sql = re.sub(r'(SELECT\s+)+', 'SELECT ', clean_sql, flags=re.IGNORECASE)
return clean_sql
return raw_sql
def _build_sql_prompt(self, question: str) -> str:
"""强化提示词"""
return f"""严格按以下要求生成Postgresql查询
表结构
{self.get_related_ddl(question)}
问题{question}
生成规则
1. 只输出一个标准的SELECT语句
2. 绝对不要使用任何代码块标记
3. 语句以分号结尾
4. 不要包含任何解释或注释
5. 若需要多表查询使用显式JOIN语法
6. 确保没有重复的SELECT关键字
"""
def generate_sql(self, question: str, **kwargs) -> str:
"""同步生成SQL"""
try:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model,
"input": {
"messages": [{
"role": "user",
"content": self._build_sql_prompt(question)
}]
},
"parameters": {
"temperature": 0.1,
"max_tokens": 5000,
"result_format": "text"
}
}
response = requests.post(self.base_url, headers=headers, json=data)
response.raise_for_status()
raw_sql = response.json()['output']['text']
return self._clean_sql_output(raw_sql)
except Exception as e:
print(f"\nAPI请求错误: {str(e)}")
return ""

Binary file not shown.

@ -1,110 +0,0 @@
import os
import requests
from vanna.base import VannaBase
from typing import List, Dict, Any
from Config import *
class QwenChat(VannaBase):
def __init__(self):
super().__init__()
self.api_key = MODEL_API_KEY
self.model = "qwen-max"
self.training_data = []
self.chat_history = [] # 新增:用于存储对话历史
# ---------- 必须实现的抽象方法 ----------
def add_ddl(self, ddl: str, **kwargs) -> None:
self.training_data.append({"type": "ddl", "content": ddl})
def add_documentation(self, doc: str, **kwargs) -> None:
self.training_data.append({"type": "documentation", "content": doc})
def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
self.training_data.append({"type": "qa", "question": question, "sql": sql})
def get_related_ddl(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"])
def get_related_documentation(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"])
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
return self.training_data
# ---------- 新增的必需方法 ----------
def system_message(self, message: str) -> None:
"""系统消息(用于初始化对话)"""
self.chat_history.append({"role": "system", "content": message})
def user_message(self, message: str) -> None:
"""用户消息"""
self.chat_history.append({"role": "user", "content": message})
def assistant_message(self, message: str) -> None:
"""助手消息"""
self.chat_history.append({"role": "assistant", "content": message})
def submit_prompt(self, prompt: str, **kwargs) -> str:
"""提交提示词(这里直接调用 Qwen API"""
return self.generate_sql(question=prompt)
# ---------- 其他需要实现的方法 ----------
def remove_training_data(self, id: str, **kwargs) -> bool:
return True
def generate_embedding(self, text: str, **kwargs) -> List[float]:
return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
return []
# ---------- 原有生成 SQL 的方法 ----------
def generate_sql(self, question: str, **kwargs) -> str:
url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
prompt = f"基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。\n\n表结构:\n{self.get_related_ddl(question)}\n\n问题:{question}"
data = {
"model": self.model,
"input": {"messages": [{"role": "user", "content": prompt}]},
"parameters": {}
}
response = requests.post(url, headers=headers, json=data)
if response.status_code == 200:
return response.json()["output"]["text"]
else:
raise Exception(f"Qwen API 错误: {response.text}")
# ----------------------------
# 初始化并使用实例
# ----------------------------
vn = QwenChat()
# 添加训练数据
vn.train(
ddl="""
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10, 2),
hire_date DATE
);
"""
)
vn.train(documentation="部门 'Sales' 的员工工资超过 50000")
# 生成 SQL
question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000"
try:
sql = vn.generate_sql(question=question)
print("生成的 SQL:\n", sql)
except Exception as e:
print("错误:", str(e))

Binary file not shown.

After

Width:  |  Height:  |  Size: 545 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 161 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

Loading…
Cancel
Save