main
HuangHai 4 months ago
parent 8aca60cf07
commit ae86048959

@ -0,0 +1,70 @@
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy.orm import declarative_base # 更新导入路径
from sqlalchemy.orm import sessionmaker
# 数据库连接配置
DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"
# 创建数据库引擎
engine = create_engine(DATABASE_URL)
# 创建 SessionLocal 类
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建 Base 类
Base = declarative_base()
# 定义 t_bi_question 表模型
class TBIQuestion(Base):
__tablename__ = "t_bi_question"
id = Column(Integer, primary_key=True, index=True)
question = Column(String(255), nullable=False)
state_id = Column(SmallInteger, nullable=False, default=0)
sql = Column(String(2048))
bar_chart_x_columns = Column(String(255))
bar_chart_y_columns = Column(String(255))
pie_chart_category_columns = Column(String(255))
pie_chart_value_column = Column(String(255))
session_id = Column(String(255), nullable=False)
excel_file_name = Column(String(255))
bar_chart_file_name = Column(String(255))
pie_chart_file_name = Column(String(255))
report_file_name = Column(String(255))
create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP")
is_system = Column(SmallInteger, nullable=False, default=0)
is_collect = Column(SmallInteger, nullable=False, default=0)
# 创建 Pydantic 模型
class TBIQuestionCreate(BaseModel):
question: str
state_id: int = 0
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = 0
is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
question: str = None
state_id: int = None
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str = None
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = None
is_collect: int = None

@ -1,6 +1,9 @@
import io
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.utils import get_column_letter from openpyxl.utils import get_column_letter
import pandas as pd
def save_to_excel(data, filename): def save_to_excel(data, filename):
""" """
@ -61,4 +64,77 @@ def save_to_excel(data, filename):
column_width = min(max(max_length + 2, 10)*2, 120) # 加 2 是为了留出一些空白 column_width = min(max(max_length + 2, 10)*2, 120) # 加 2 是为了留出一些空白
# 设置列宽 # 设置列宽
column_letter = get_column_letter(idx + 1) column_letter = get_column_letter(idx + 1)
worksheet.column_dimensions[column_letter].width = column_width worksheet.column_dimensions[column_letter].width = column_width
def save_to_excel_stream(data):
"""
将数据集保存为格式化的Excel文件流
参数
data - 数据集 (列表字典格式例如[{"列1": "值1", "列2": "值2"}, ...])
返回
BytesIO - 包含Excel文件内容的流
"""
# 转换数据为DataFrame
df = pd.DataFrame(data)
# 创建一个 BytesIO 对象作为缓冲区
output = io.BytesIO()
# 创建Excel工作簿
wb = Workbook()
ws = wb.active
ws.title = '统计报表'
# 写入数据
for row in pd.DataFrame(data).itertuples(index=False):
ws.append(row)
# 定义边框样式
thin_border = Border(left=Side(style='thin'),
right=Side(style='thin'),
top=Side(style='thin'),
bottom=Side(style='thin'))
# 设置全局行高
for row in ws.iter_rows():
ws.row_dimensions[row[0].row].height = 20
# 为所有单元格添加边框
for cell in row:
cell.border = thin_border
# 设置标题样式
header_font = Font(bold=True, size=14)
header_fill = PatternFill(start_color='ADD8E6', end_color='ADD8E6', fill_type='solid')
for cell in ws[1]:
cell.font = header_font
cell.fill = header_fill
cell.alignment = Alignment(horizontal='center', vertical='center')
# 设置数据行样式
data_font = Font(size=14)
for row in ws.iter_rows(min_row=2):
for cell in row:
cell.font = data_font
cell.alignment = Alignment(vertical='center', wrap_text=True)
# 动态设置列宽
for idx, column in enumerate(df.columns):
# 获取列的最大长度
max_length = max(
df[column].astype(str).map(len).max(), # 数据列的最大长度
len(str(column)) # 列名的长度
)
# 计算列宽,确保在 10 到 120 之间
column_width = min(max(max_length + 2, 10) * 2, 120) # 加 2 是为了留出一些空白
# 设置列宽
column_letter = get_column_letter(idx + 1)
ws.column_dimensions[column_letter].width = column_width
# 将工作簿保存到 BytesIO 对象
wb.save(output)
output.seek(0) # 将指针移动到流的开头
return output

@ -1,61 +1,128 @@
import re import re
from typing import List, Dict, Any from typing import List, Dict, Any
import requests import requests
import sqlite3
from vanna.base import VannaBase from vanna.base import VannaBase
from Config import * from Config import *
class VannaUtil(VannaBase): class VannaUtil(VannaBase):
def __init__(self): def __init__(self, db_type='sqlite', db_uri=None):
super().__init__() super().__init__()
self.api_key = MODEL_API_KEY self.api_key = MODEL_API_KEY
self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址 self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址
self.model = QWEN_MODEL_NAME # 根据实际模型名称调整 self.model = QWEN_MODEL_NAME # 根据实际模型名称调整
self.training_data = [] self.training_data = []
self.chat_history = [] self.chat_history = []
self.db_type = db_type
self.db_uri = db_uri or 'vanna.db' # 默认使用 SQLite
self._init_db()
def _init_db(self):
"""初始化数据库连接"""
if self.db_type == 'sqlite':
self.conn = sqlite3.connect(self.db_uri)
self._create_tables()
elif self.db_type == 'postgres':
import psycopg2
self.conn = psycopg2.connect(self.db_uri)
self._create_tables()
else:
raise ValueError(f"Unsupported database type: {self.db_type}")
def _create_tables(self):
"""创建训练数据表"""
cursor = self.conn.cursor()
if self.db_type == 'sqlite':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
elif self.db_type == 'postgres':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id SERIAL PRIMARY KEY,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
self.conn.commit()
# ---------- 必须实现的抽象方法 ----------
def add_ddl(self, ddl: str, **kwargs) -> None: def add_ddl(self, ddl: str, **kwargs) -> None:
self.training_data.append({"type": "ddl", "content": ddl}) """添加 DDL"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl))
self.conn.commit()
def add_documentation(self, doc: str, **kwargs) -> None: def add_documentation(self, doc: str, **kwargs) -> None:
self.training_data.append({"type": "documentation", "content": doc}) """添加文档"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc))
self.conn.commit()
def add_question_sql(self, question: str, sql: str, **kwargs) -> None: def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
self.training_data.append({"type": "qa", "question": question, "sql": sql}) """添加问答对"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql))
self.conn.commit()
def get_related_ddl(self, question: str, **kwargs) -> str: def get_related_ddl(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"]) """获取相关 DDL"""
cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('ddl',))
return "\n".join(row[0] for row in cursor.fetchall())
def get_related_documentation(self, question: str, **kwargs) -> str: def get_related_documentation(self, question: str, **kwargs) -> str:
return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"]) """获取相关文档"""
cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('documentation',))
return "\n".join(row[0] for row in cursor.fetchall())
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]: def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
return self.training_data """获取所有训练数据"""
cursor = self.conn.cursor()
cursor.execute('SELECT * FROM training_data')
columns = [column[0] for column in cursor.description]
return [dict(zip(columns, row)) for row in cursor.fetchall()]
def remove_training_data(self, id: str, **kwargs) -> bool:
"""删除训练数据"""
cursor = self.conn.cursor()
cursor.execute('DELETE FROM training_data WHERE id = ?', (id,))
self.conn.commit()
return cursor.rowcount > 0
def generate_embedding(self, text: str, **kwargs) -> List[float]:
"""生成嵌入向量"""
return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
"""获取相似问答对"""
return []
# ---------- 对话方法 ----------
def system_message(self, message: str) -> None: def system_message(self, message: str) -> None:
"""添加系统消息"""
self.chat_history = [{"role": "system", "content": message}] self.chat_history = [{"role": "system", "content": message}]
def user_message(self, message: str) -> None: def user_message(self, message: str) -> None:
"""添加用户消息"""
self.chat_history.append({"role": "user", "content": message}) self.chat_history.append({"role": "user", "content": message})
def assistant_message(self, message: str) -> None: def assistant_message(self, message: str) -> None:
"""添加助手消息"""
self.chat_history.append({"role": "assistant", "content": message}) self.chat_history.append({"role": "assistant", "content": message})
def submit_prompt(self, prompt: str, **kwargs) -> str: def submit_prompt(self, prompt: str, **kwargs) -> str:
"""提交提示词"""
return self.generate_sql(question=prompt) return self.generate_sql(question=prompt)
# ---------- 其他方法 ----------
def remove_training_data(self, id: str, **kwargs) -> bool:
return True
def generate_embedding(self, text: str, **kwargs) -> List[float]:
return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
return []
def _clean_sql_output(self, raw_sql: str) -> str: def _clean_sql_output(self, raw_sql: str) -> str:
"""增强版清洗逻辑""" """增强版清洗逻辑"""
# 移除所有非SQL内容 # 移除所有非SQL内容
@ -119,4 +186,4 @@ class VannaUtil(VannaBase):
except Exception as e: except Exception as e:
print(f"\nAPI请求错误: {str(e)}") print(f"\nAPI请求错误: {str(e)}")
return "" return ""

@ -1,84 +1,13 @@
import re import re
from typing import io
import pandas as pd
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy.orm import declarative_base # 更新导入路径
from sqlalchemy.orm import sessionmaker, Session
import uvicorn # 导入 uvicorn import uvicorn # 导入 uvicorn
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.Util.SaveToExcel import save_to_excel from Text2Sql.Util.SaveToExcel import save_to_excel_stream
from Text2Sql.Util.VannaUtil import VannaUtil from Text2Sql.Util.VannaUtil import VannaUtil
from Model.biModel import *
# 数据库连接配置
DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"
# 创建数据库引擎
engine = create_engine(DATABASE_URL)
# 创建 SessionLocal 类
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建 Base 类
Base = declarative_base()
# 定义 t_bi_question 表模型
class TBIQuestion(Base):
__tablename__ = "t_bi_question"
id = Column(Integer, primary_key=True, index=True)
question = Column(String(255), nullable=False)
state_id = Column(SmallInteger, nullable=False, default=0)
sql = Column(String(2048))
bar_chart_x_columns = Column(String(255))
bar_chart_y_columns = Column(String(255))
pie_chart_category_columns = Column(String(255))
pie_chart_value_column = Column(String(255))
session_id = Column(String(255), nullable=False)
excel_file_name = Column(String(255))
bar_chart_file_name = Column(String(255))
pie_chart_file_name = Column(String(255))
report_file_name = Column(String(255))
create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP")
is_system = Column(SmallInteger, nullable=False, default=0)
is_collect = Column(SmallInteger, nullable=False, default=0)
# 创建 Pydantic 模型
class TBIQuestionCreate(BaseModel):
question: str
state_id: int = 0
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = 0
is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
question: str = None
state_id: int = None
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str = None
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = None
is_collect: int = None
# 初始化 FastAPI # 初始化 FastAPI
app = FastAPI() app = FastAPI()
@ -87,6 +16,7 @@ app = FastAPI()
def read_root(): def read_root():
return {"message": "Hello, World!"} return {"message": "Hello, World!"}
# 获取数据库会话 # 获取数据库会话
def get_db(): def get_db():
db = SessionLocal() db = SessionLocal()
@ -95,6 +25,7 @@ def get_db():
finally: finally:
db.close() db.close()
# 创建记录 # 创建记录
@app.post("/questions/", response_model=TBIQuestionCreate) @app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)): def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
@ -104,6 +35,7 @@ def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
db.refresh(db_question) db.refresh(db_question)
return db_question return db_question
# 读取记录 # 读取记录
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate) @app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)): def read_question(question_id: int, db: Session = Depends(get_db)):
@ -112,6 +44,7 @@ def read_question(question_id: int, db: Session = Depends(get_db)):
raise HTTPException(status_code=404, detail="Question not found") raise HTTPException(status_code=404, detail="Question not found")
return db_question return db_question
# 更新记录 # 更新记录
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate) @app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)): def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
@ -125,6 +58,7 @@ def update_question(question_id: int, question: TBIQuestionUpdate, db: Session =
db.refresh(db_question) db.refresh(db_question)
return db_question return db_question
# 删除记录 # 删除记录
@app.delete("/questions/{question_id}") @app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)): def delete_question(question_id: int, db: Session = Depends(get_db)):
@ -135,16 +69,17 @@ def delete_question(question_id: int, db: Session = Depends(get_db)):
db.commit() db.commit()
return {"message": "Question deleted successfully"} return {"message": "Question deleted successfully"}
# 通过语义生成Excel # 通过语义生成Excel
@app.post("/questions/getExcel") @app.post("/questions/get_excel")
def getExcel(question: str): def get_excel(question: str):
# 指定学段 # 指定学段
question = ''' # question = '''
查询: # 查询:
1发布时间是2024年度 # 1、发布时间是2024年度
2每个学段每个科目上传课程数量按由多到少排序 # 2、每个学段每个科目上传课程数量按由多到少排序
3字段名: 学段,科目,排名,课程数量 # 3、字段名: 学段,科目,排名,课程数量
''' # '''
common_prompt = ''' common_prompt = '''
返回的信息要求 返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计 1行政区划为NULL 或者是空字符的不参加统计
@ -165,45 +100,10 @@ def getExcel(question: str):
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": "attachment; filename=导出信息.xlsx"} headers={"Content-Disposition": "attachment; filename=导出信息.xlsx"}
) )
def save_to_excel_stream(data):
# 将数据保存为Excel文件流
df = pd.DataFrame(data)
output = io.BytesIO()
with pd.ExcelWriter(output, engine='openpyxl') as writer:
df.to_excel(writer, index=False)
output.seek(0) # 将指针移动到流的开头
return output
# 添加 main 函数
def main():
# 开始训练
print("开始训练...")
# 打开AreaSchoolLesson.sql文件内容
with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 训练数据
vn.train(
ddl=ddl
)
# 添加有关业务术语或定义的文档
# vn.train(documentation="Sql/AreaSchoolLesson.md")
# 使用 SQL 进行训练
with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
sql_content = file.read()
# 使用正则表达式提取注释和 SQL 语句
sql_pattern = r'/\*(.*?)\*/(.*?);'
sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL)
# 打印提取的注释和 SQL 语句
for i, (comment, sql) in enumerate(sql_snippets, 1):
vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n')
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
# 确保直接运行脚本时启动 FastAPI 应用 # 确保直接运行脚本时启动 FastAPI 应用
if __name__ == "__main__": if __name__ == "__main__":
vn = VannaUtil() vn = VannaUtil()
main() uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)

Binary file not shown.
Loading…
Cancel
Save