main
HuangHai 4 months ago
parent d29e6d681a
commit 9ac006e604

@ -16,47 +16,6 @@ from Util.EchartsUtil import *
3应该有类似于 保存为用例查询历史等功能让用户方便利旧
'''
def infer_axes_fields(field_names, sample_data):
"""
使用 AI 大模型推断 X 轴和 Y 轴的字段
:param field_names: 数据字段名列表
:param sample_data: 数据示例前几行
:return: X 轴字段, Y 轴字段
"""
client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) # 初始化OpenAI客户端
# 构造提示词
prompt = f"""
以下是数据的字段名和示例数据
字段名: {field_names}
示例数据: {sample_data}
请根据这些数据推荐适合用于柱状图的 X 轴和 Y 轴字段
X 轴应为分类字段如学段科目行政区划等规则如下
1. 如果学段科目同时存在返回学段+科目
2. 如果学段科目都不存在返回学段
3. 如果只有学段或科目存在返回存在的字段
Y 轴应为数值字段如课程数量数量等
请直接返回 X 轴和 Y 轴的字段名格式为X_轴字段, Y_轴字段
"""
# 调用 AI 大模型
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个数据分析助手,帮助用户选择合适的字段生成图表。"},
{"role": "user", "content": prompt}
],
max_tokens=50
)
# 解析 AI 返回的结果
result = response.choices[0].message.content.strip()
x_column, y_column = result.split(", ")
return x_column, y_column
if __name__ == "__main__":
vn = VannaUtil()
@ -70,7 +29,7 @@ if __name__ == "__main__":
ddl=ddl
)
# 添加有关业务术语或定义的文档
vn.train(documentation="Sql/AreaSchoolLesson.md")
# vn.train(documentation="Sql/AreaSchoolLesson.md")
# 使用 SQL 进行训练
with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
@ -127,26 +86,20 @@ if __name__ == "__main__":
field_names = list(_data[0].keys()) if _data else []
sample_data = _data[:3] # 取前 3 行作为示例数据
# 推断 X 轴和 Y 轴字段
x_column, y_column = infer_axes_fields(field_names, sample_data)
x_columns = x_column.split('+')
y_columns = y_column.split('+')
# 1、生成柱状图
generate_bar_chart(
_data=_data,
title="学段+科目课程数量柱状图",
x_columns=x_columns, # 动态指定 X 轴列
y_columns=y_columns, # 动态指定 Y 轴列
x_columns=['学段', '科目'], # 动态指定 X 轴列
y_columns=['课程数量'], # 动态指定 Y 轴列
output_file="d:/lesson_bar_chart.html"
)
# 2、生成饼状图
generate_pie_chart(
_data=_data,
title="学段+科目分布",
category_columns=x_columns, # 多列组合参数
value_column=y_columns[0],
category_columns=['学段', '科目'], # 多列组合参数
value_column='课程数量',
output_file="d:/lesson_pie_chart.html"
)

@ -0,0 +1,188 @@
import re
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy.orm import declarative_base # 更新导入路径
from sqlalchemy.orm import sessionmaker, Session
import uvicorn # 导入 uvicorn
from starlette.staticfiles import StaticFiles
from Text2Sql.Util.VannaUtil import VannaUtil
# 数据库连接配置
DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"
# 创建数据库引擎
engine = create_engine(DATABASE_URL)
# 创建 SessionLocal 类
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建 Base 类
Base = declarative_base()
# 定义 t_bi_question 表模型
class TBIQuestion(Base):
__tablename__ = "t_bi_question"
id = Column(Integer, primary_key=True, index=True)
question = Column(String(255), nullable=False)
state_id = Column(SmallInteger, nullable=False, default=0)
sql = Column(String(2048))
bar_chart_x_columns = Column(String(255))
bar_chart_y_columns = Column(String(255))
pie_chart_category_columns = Column(String(255))
pie_chart_value_column = Column(String(255))
session_id = Column(String(255), nullable=False)
excel_file_name = Column(String(255))
bar_chart_file_name = Column(String(255))
pie_chart_file_name = Column(String(255))
report_file_name = Column(String(255))
create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP")
is_system = Column(SmallInteger, nullable=False, default=0)
is_collect = Column(SmallInteger, nullable=False, default=0)
# 创建 Pydantic 模型
class TBIQuestionCreate(BaseModel):
question: str
state_id: int = 0
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = 0
is_collect: int = 0
class TBIQuestionUpdate(BaseModel):
question: str = None
state_id: int = None
sql: str = None
bar_chart_x_columns: str = None
bar_chart_y_columns: str = None
pie_chart_category_columns: str = None
pie_chart_value_column: str = None
session_id: str = None
excel_file_name: str = None
bar_chart_file_name: str = None
pie_chart_file_name: str = None
report_file_name: str = None
is_system: int = None
is_collect: int = None
# 初始化 FastAPI
app = FastAPI()
@app.get("/")
def read_root():
return {"message": "Hello, World!"}
# 获取数据库会话
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# 创建记录
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
db_question = TBIQuestion(**question.dict())
db.add(db_question)
db.commit()
db.refresh(db_question)
return db_question
# 读取记录
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
if db_question is None:
raise HTTPException(status_code=404, detail="Question not found")
return db_question
# 更新记录
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
if db_question is None:
raise HTTPException(status_code=404, detail="Question not found")
for key, value in question.dict().items():
if value is not None:
setattr(db_question, key, value)
db.commit()
db.refresh(db_question)
return db_question
# 删除记录
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
if db_question is None:
raise HTTPException(status_code=404, detail="Question not found")
db.delete(db_question)
db.commit()
return {"message": "Question deleted successfully"}
@app.post("/questions/generate_sql")
def generate_sql(question: str):
# 指定学段
question = '''
查询:
1发布时间是2024年度
2每个学段每个科目上传课程数量按由多到少排序
3字段名: 学段,科目,排名,课程数量
'''
common_prompt = '''
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
2目标数据库是Postgresql 16
'''
question = question + common_prompt
# 开始查询
print("开始查询...")
# 获取完整 SQL
sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql)
# 添加 main 函数
def main():
# 开始训练
print("开始训练...")
# 打开AreaSchoolLesson.sql文件内容
with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
ddl = file.read()
# 训练数据
vn.train(
ddl=ddl
)
# 添加有关业务术语或定义的文档
# vn.train(documentation="Sql/AreaSchoolLesson.md")
# 使用 SQL 进行训练
with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
sql_content = file.read()
# 使用正则表达式提取注释和 SQL 语句
sql_pattern = r'/\*(.*?)\*/(.*?);'
sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL)
# 打印提取的注释和 SQL 语句
for i, (comment, sql) in enumerate(sql_snippets, 1):
vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n')
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
# 确保直接运行脚本时启动 FastAPI 应用
if __name__ == "__main__":
vn = VannaUtil()
main()
Loading…
Cancel
Save