main
HuangHai 4 months ago
parent 8aca60cf07
commit ae86048959

@ -0,0 +1,70 @@
import os
from typing import Optional

from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy import func
from sqlalchemy.orm import declarative_base # updated import path
from sqlalchemy.orm import sessionmaker
# Database connection settings.
# The URL can be overridden with the DATABASE_URL environment variable so that
# deployments do not have to rely on credentials committed to source control.
# The literal is kept only as a backward-compatible fallback.
# NOTE(review): the fallback embeds a plaintext password - remove it once every
# deployment provides DATABASE_URL.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db",
)
# Database engine built from the URL above.
engine = create_engine(DATABASE_URL)
# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models inherit from.
Base = declarative_base()
# ORM model for the t_bi_question table.
class TBIQuestion(Base):
    """SQLAlchemy mapping for one row of the ``t_bi_question`` table."""

    __tablename__ = "t_bi_question"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String(255), nullable=False)
    state_id = Column(SmallInteger, nullable=False, default=0)
    sql = Column(String(2048))
    bar_chart_x_columns = Column(String(255))
    bar_chart_y_columns = Column(String(255))
    pie_chart_category_columns = Column(String(255))
    pie_chart_value_column = Column(String(255))
    session_id = Column(String(255), nullable=False)
    excel_file_name = Column(String(255))
    bar_chart_file_name = Column(String(255))
    pie_chart_file_name = Column(String(255))
    report_file_name = Column(String(255))
    # Use a SQL expression so the database supplies the current date when the
    # row is inserted. The previous default was the Python string
    # "CURRENT_TIMESTAMP", which is bound as a literal value rather than
    # evaluated as SQL and is not a valid value for a Date column.
    create_time = Column(Date, nullable=False, default=func.current_date())
    is_system = Column(SmallInteger, nullable=False, default=0)
    is_collect = Column(SmallInteger, nullable=False, default=0)
# Pydantic models.
class TBIQuestionCreate(BaseModel):
    """Schema carrying the fields of a question record.

    ``question`` and ``session_id`` are required. The integer flags default
    to 0 and every other field is optional and defaults to None.
    """

    question: str
    state_id: int = 0
    # Fields that may be absent are declared Optional so that an explicit
    # null/None value validates, not only an omitted field.
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: str
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: int = 0
    is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
    """Schema for updating a question record; every field is optional.

    All fields default to None. They are declared Optional so that an
    explicit null/None value validates as well as an omitted field.
    """

    question: Optional[str] = None
    state_id: Optional[int] = None
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: Optional[str] = None
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: Optional[int] = None
    is_collect: Optional[int] = None

@ -1,6 +1,9 @@
import io
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
from openpyxl.utils import get_column_letter
import pandas as pd
def save_to_excel(data, filename):
"""
@ -61,4 +64,77 @@ def save_to_excel(data, filename):
column_width = min(max(max_length + 2, 10)*2, 120) # 加 2 是为了留出一些空白
# 设置列宽
column_letter = get_column_letter(idx + 1)
worksheet.column_dimensions[column_letter].width = column_width
worksheet.column_dimensions[column_letter].width = column_width
def save_to_excel_stream(data):
    """
    Build a formatted Excel workbook from a dataset and return it as a stream.

    Args:
        data: Dataset as a list of dicts, e.g.
            [{"col1": "v1", "col2": "v2"}, ...].

    Returns:
        BytesIO holding the .xlsx content, positioned at the start.
    """
    # Convert the dataset once and reuse the frame for rows and column widths.
    df = pd.DataFrame(data)

    # Create the workbook and name the sheet.
    wb = Workbook()
    ws = wb.active
    ws.title = '统计报表'

    # Write the column names first so that row 1 really is the header row that
    # gets the header style below; previously only data rows were written.
    if len(df.columns) > 0:
        ws.append(list(df.columns))
    # name=None yields plain tuples, which is all ws.append() needs.
    for row in df.itertuples(index=False, name=None):
        ws.append(row)

    # Thin border applied to every cell.
    thin_border = Border(left=Side(style='thin'),
                         right=Side(style='thin'),
                         top=Side(style='thin'),
                         bottom=Side(style='thin'))

    # Uniform row height plus borders on all cells.
    for row in ws.iter_rows():
        ws.row_dimensions[row[0].row].height = 20
        for cell in row:
            cell.border = thin_border

    # Header style: bold, light-blue fill, centred.
    header_font = Font(bold=True, size=14)
    header_fill = PatternFill(start_color='ADD8E6', end_color='ADD8E6', fill_type='solid')
    for cell in ws[1]:
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = Alignment(horizontal='center', vertical='center')

    # Data rows: regular font, vertically centred, wrapped text.
    data_font = Font(size=14)
    for row in ws.iter_rows(min_row=2):
        for cell in row:
            cell.font = data_font
            cell.alignment = Alignment(vertical='center', wrap_text=True)

    # Size each column from its longest value or its name, whichever is longer.
    has_rows = len(df) > 0
    for idx, column in enumerate(df.columns):
        # With zero rows .max() would be NaN, so fall back to 0.
        data_length = int(df[column].astype(str).map(len).max()) if has_rows else 0
        max_length = max(data_length, len(str(column)))
        # Add 2 for padding, keep at least 10, double it, and cap at 120.
        column_width = min(max(max_length + 2, 10) * 2, 120)
        column_letter = get_column_letter(idx + 1)
        ws.column_dimensions[column_letter].width = column_width

    # Save the workbook into an in-memory buffer and rewind it for the caller.
    output = io.BytesIO()
    wb.save(output)
    output.seek(0)
    return output

@ -1,61 +1,128 @@
import re
from typing import List, Dict, Any
import requests
import sqlite3
from vanna.base import VannaBase
from Config import *
class VannaUtil(VannaBase):
def __init__(self):
def __init__(self, db_type='sqlite', db_uri=None):
    """Set up model settings, in-memory state and the backing database.

    Args:
        db_type: Storage backend name; defaults to 'sqlite'.
        db_uri: Connection target handed to the driver. Falls back to
            'vanna.db' when it is None or empty.
    """
    super().__init__()
    # Model endpoint settings (presumably supplied by `from Config import *`).
    # base_url: Alibaba Cloud dedicated API address.
    # model: adjust to the actual model name.
    self.api_key, self.base_url, self.model = (
        MODEL_API_KEY,
        MODEL_GENERATION_TEXT_URL,
        QWEN_MODEL_NAME,
    )
    # In-memory state.
    self.training_data = []
    self.chat_history = []
    # Storage backend; SQLite file 'vanna.db' is the default target.
    self.db_type = db_type
    self.db_uri = db_uri if db_uri else 'vanna.db'
    self._init_db()
def _init_db(self):
"""初始化数据库连接"""
if self.db_type == 'sqlite':
self.conn = sqlite3.connect(self.db_uri)
self._create_tables()
elif self.db_type == 'postgres':
import psycopg2
self.conn = psycopg2.connect(self.db_uri)
self._create_tables()
else:
raise ValueError(f"Unsupported database type: {self.db_type}")
def _create_tables(self):
"""创建训练数据表"""
cursor = self.conn.cursor()
if self.db_type == 'sqlite':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
elif self.db_type == 'postgres':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id SERIAL PRIMARY KEY,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
self.conn.commit()
# ---------- 必须实现的抽象方法 ----------
def add_ddl(self, ddl: str, **kwargs) -> None:
    """Store a DDL statement as a training_data row of type 'ddl'.

    Args:
        ddl: DDL text written to the ``content`` column.
        **kwargs: Accepted for interface compatibility; not used.
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(
        f'INSERT INTO training_data (type, content) VALUES ({mark}, {mark})',
        ('ddl', ddl),
    )
    self.conn.commit()
def add_documentation(self, doc: str, **kwargs) -> None:
    """Store a documentation text as a training_data row of type 'documentation'.

    Args:
        doc: Documentation text written to the ``content`` column.
        **kwargs: Accepted for interface compatibility; not used.
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(
        f'INSERT INTO training_data (type, content) VALUES ({mark}, {mark})',
        ('documentation', doc),
    )
    self.conn.commit()
def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
    """Store a question/SQL pair as a training_data row of type 'qa'.

    Args:
        question: Text written to the ``question`` column.
        sql: SQL text written to the ``sql`` column.
        **kwargs: Accepted for interface compatibility; not used.
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(
        f'INSERT INTO training_data (type, question, sql) VALUES ({mark}, {mark}, {mark})',
        ('qa', question, sql),
    )
    self.conn.commit()
def get_related_ddl(self, question: str, **kwargs) -> str:
    """Return all stored DDL statements joined by newlines.

    Args:
        question: Not used by the query; every 'ddl' row is returned.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The ``content`` of every 'ddl' row, one per line ('' when none).
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(f'SELECT content FROM training_data WHERE type = {mark}', ('ddl',))
    return "\n".join(row[0] for row in cursor.fetchall())
def get_related_documentation(self, question: str, **kwargs) -> str:
    """Return all stored documentation texts joined by newlines.

    Args:
        question: Not used by the query; every 'documentation' row is returned.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The ``content`` of every 'documentation' row, one per line
        ('' when none).
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(f'SELECT content FROM training_data WHERE type = {mark}', ('documentation',))
    return "\n".join(row[0] for row in cursor.fetchall())
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
    """Return every training_data row as a dict keyed by column name.

    Args:
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        One dict per row, mapping each column name to its value.
    """
    cursor = self.conn.cursor()
    cursor.execute('SELECT * FROM training_data')
    # cursor.description holds one entry per column; item 0 is the name.
    columns = [column[0] for column in cursor.description]
    return [dict(zip(columns, row)) for row in cursor.fetchall()]
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete the training_data row with the given id.

    Args:
        id: Primary-key value of the row to delete.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        True if at least one row was deleted, otherwise False.
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(f'DELETE FROM training_data WHERE id = {mark}', (id,))
    self.conn.commit()
    return cursor.rowcount > 0
def generate_embedding(self, text: str, **kwargs) -> List[float]:
    """Generate an embedding vector.

    Placeholder implementation: ``text`` is ignored and an empty list is
    always returned.
    """
    return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
    """Get similar question/SQL pairs.

    Placeholder implementation: ``question`` is ignored and an empty list is
    always returned.
    """
    return []
# ---------- 对话方法 ----------
def system_message(self, message: str) -> None:
    """Reset the chat history so it holds only the given system message."""
    system_entry = dict(role="system", content=message)
    self.chat_history = [system_entry]
def user_message(self, message: str) -> None:
    """Append a user message to the chat history."""
    user_entry = dict(role="user", content=message)
    self.chat_history.append(user_entry)
def assistant_message(self, message: str) -> None:
    """Append an assistant message to the chat history."""
    assistant_entry = dict(role="assistant", content=message)
    self.chat_history.append(assistant_entry)
def submit_prompt(self, prompt: str, **kwargs) -> str:
    """Submit a prompt by passing it to generate_sql() as the question."""
    generated = self.generate_sql(question=prompt)
    return generated
# ---------- 其他方法 ----------
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete the training_data row with the given id.

    The earlier stub returned True without deleting anything.

    Args:
        id: Primary-key value of the row to delete.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        True if at least one row was deleted, otherwise False.
    """
    # sqlite3 expects '?' placeholders while psycopg2 expects '%s'.
    mark = '%s' if self.db_type == 'postgres' else '?'
    cursor = self.conn.cursor()
    cursor.execute(f'DELETE FROM training_data WHERE id = {mark}', (id,))
    self.conn.commit()
    return cursor.rowcount > 0
def generate_embedding(self, text: str, **kwargs) -> List[float]:
    """Placeholder: ignores ``text`` and always returns an empty list."""
    return []
def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
    """Placeholder: ignores ``question`` and always returns an empty list."""
    return []
def _clean_sql_output(self, raw_sql: str) -> str:
"""增强版清洗逻辑"""
# 移除所有非SQL内容
@ -119,4 +186,4 @@ class VannaUtil(VannaBase):
except Exception as e:
print(f"\nAPI请求错误: {str(e)}")
return ""
return ""

@ -1,84 +1,13 @@
import re
from typing import io
import pandas as pd
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy.orm import declarative_base # 更新导入路径
from sqlalchemy.orm import sessionmaker, Session
import uvicorn # 导入 uvicorn
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.SaveToExcel import save_to_excel_stream
from Text2Sql.Util.VannaUtil import VannaUtil
# Database connection settings.
# The URL can be overridden with the DATABASE_URL environment variable so that
# deployments do not have to rely on credentials committed to source control.
# The literal is kept only as a backward-compatible fallback.
# NOTE(review): the fallback embeds a plaintext password - remove it once every
# deployment provides DATABASE_URL.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db",
)
# Database engine built from the URL above.
engine = create_engine(DATABASE_URL)
# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class that the ORM models inherit from.
Base = declarative_base()
# ORM model for the t_bi_question table.
class TBIQuestion(Base):
    """SQLAlchemy mapping for one row of the ``t_bi_question`` table."""

    __tablename__ = "t_bi_question"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String(255), nullable=False)
    state_id = Column(SmallInteger, nullable=False, default=0)
    sql = Column(String(2048))
    bar_chart_x_columns = Column(String(255))
    bar_chart_y_columns = Column(String(255))
    pie_chart_category_columns = Column(String(255))
    pie_chart_value_column = Column(String(255))
    session_id = Column(String(255), nullable=False)
    excel_file_name = Column(String(255))
    bar_chart_file_name = Column(String(255))
    pie_chart_file_name = Column(String(255))
    report_file_name = Column(String(255))
    # Use a SQL expression so the database supplies the current date when the
    # row is inserted. The previous default was the Python string
    # "CURRENT_TIMESTAMP", which is bound as a literal value rather than
    # evaluated as SQL and is not a valid value for a Date column.
    create_time = Column(Date, nullable=False, default=func.current_date())
    is_system = Column(SmallInteger, nullable=False, default=0)
    is_collect = Column(SmallInteger, nullable=False, default=0)
# Pydantic models.
class TBIQuestionCreate(BaseModel):
    """Schema carrying the fields of a question record.

    ``question`` and ``session_id`` are required. The integer flags default
    to 0 and every other field is optional and defaults to None.
    """

    question: str
    state_id: int = 0
    # Fields that may be absent are declared Optional so that an explicit
    # null/None value validates, not only an omitted field.
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: str
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: int = 0
    is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
    """Schema for updating a question record; every field is optional.

    All fields default to None. They are declared Optional so that an
    explicit null/None value validates as well as an omitted field.
    """

    question: Optional[str] = None
    state_id: Optional[int] = None
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: Optional[str] = None
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: Optional[int] = None
    is_collect: Optional[int] = None
from Model.biModel import *
# 初始化 FastAPI
app = FastAPI()
@ -87,6 +16,7 @@ app = FastAPI()
def read_root():
return {"message": "Hello, World!"}
# 获取数据库会话
def get_db():
db = SessionLocal()
@ -95,6 +25,7 @@ def get_db():
finally:
db.close()
# 创建记录
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
@ -104,6 +35,7 @@ def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
db.refresh(db_question)
return db_question
# 读取记录
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
@ -112,6 +44,7 @@ def read_question(question_id: int, db: Session = Depends(get_db)):
raise HTTPException(status_code=404, detail="Question not found")
return db_question
# 更新记录
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
@ -125,6 +58,7 @@ def update_question(question_id: int, question: TBIQuestionUpdate, db: Session =
db.refresh(db_question)
return db_question
# 删除记录
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
@ -135,16 +69,17 @@ def delete_question(question_id: int, db: Session = Depends(get_db)):
db.commit()
return {"message": "Question deleted successfully"}
# 通过语义生成Excel
@app.post("/questions/getExcel")
def getExcel(question: str):
@app.post("/questions/get_excel")
def get_excel(question: str):
# 指定学段
question = '''
查询:
1发布时间是2024年度
2每个学段每个科目上传课程数量按由多到少排序
3字段名: 学段,科目,排名,课程数量
'''
# question = '''
# 查询:
# 1、发布时间是2024年度
# 2、每个学段每个科目上传课程数量按由多到少排序
# 3、字段名: 学段,科目,排名,课程数量
# '''
common_prompt = '''
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
@ -165,45 +100,10 @@ def getExcel(question: str):
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={"Content-Disposition": "attachment; filename=导出信息.xlsx"}
)
def save_to_excel_stream(data):
    """Write a dataset to an in-memory Excel file and return the stream.

    Args:
        data: Dataset accepted by ``pd.DataFrame`` (e.g. a list of dicts).

    Returns:
        BytesIO holding the .xlsx content, positioned at the start.
    """
    # Import the standard-library io module locally: this module's
    # `from typing import io` binds the name `io` to typing's pseudo-module,
    # which does not provide BytesIO.
    import io

    df = pd.DataFrame(data)
    output = io.BytesIO()
    with pd.ExcelWriter(output, engine='openpyxl') as writer:
        df.to_excel(writer, index=False)
    output.seek(0)  # rewind so the caller reads from the beginning
    return output
# Entry point: train, then serve.
def main():
    """Train the module-level ``vn`` object from the SQL files, then start uvicorn."""
    # Announce the start of training.
    print("开始训练...")

    # Train on the schema DDL file.
    with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as ddl_file:
        ddl = ddl_file.read()
    vn.train(ddl=ddl)

    # Documentation training (business terms/definitions) is disabled:
    # vn.train(documentation="Sql/AreaSchoolLesson.md")

    # Train on example SQL: each snippet is a /* comment */ followed by a
    # statement that ends with ';'.
    with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as sql_file:
        sql_content = sql_file.read()
    snippet_re = re.compile(r'/\*(.*?)\*/(.*?);', re.DOTALL)
    for found in snippet_re.finditer(sql_content):
        comment, sql = found.groups()
        vn.train(sql=f"{comment.strip()}\n{sql.strip()}\n")

    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
# Start the FastAPI application only when this script is run directly.
if __name__ == "__main__":
    # Module-level instance that main() uses for training.
    vn = VannaUtil()
    main()
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)

Binary file not shown.
Loading…
Cancel
Save