|
|
|
@ -1,13 +1,15 @@
|
|
|
|
|
import json
|
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
|
|
import uvicorn # 导入 uvicorn
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Form
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
from starlette.responses import JSONResponse
|
|
|
|
|
from fastapi import FastAPI, Depends, Form
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
|
|
|
|
|
|
from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME
|
|
|
|
|
from Model.biModel import *
|
|
|
|
|
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
|
|
|
|
|
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
|
|
|
|
|
from Text2Sql.Util.PostgreSQLUtil import get_db
|
|
|
|
|
from Text2Sql.Util.SaveToExcel import save_to_excel
|
|
|
|
|
from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
|
|
|
|
@ -15,81 +17,113 @@ from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
# 配置静态文件目录
|
|
|
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
|
|
|
# 初始化一次vanna的类
|
|
|
|
|
vn = VannaUtil()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
|
|
|
|
|
def read_root():
|
|
|
|
|
return {"message": "Welcome to Vanna AI SQL !"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建记录
|
|
|
|
|
@app.post("/questions/", response_model=TBIQuestionCreate)
|
|
|
|
|
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = TBIQuestion(**question.dict())
|
|
|
|
|
db.add(db_question)
|
|
|
|
|
db.commit()
|
|
|
|
|
db.refresh(db_question)
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 读取记录
|
|
|
|
|
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
|
|
|
|
|
def read_question(question_id: int, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新记录
|
|
|
|
|
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
|
|
|
|
|
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
for key, value in question.dict().items():
|
|
|
|
|
if value is not None:
|
|
|
|
|
setattr(db_question, key, value)
|
|
|
|
|
db.commit()
|
|
|
|
|
db.refresh(db_question)
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 删除记录
|
|
|
|
|
@app.delete("/questions/{question_id}")
|
|
|
|
|
def delete_question(question_id: int, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
db.delete(db_question)
|
|
|
|
|
db.commit()
|
|
|
|
|
return {"message": "Question deleted successfully"}
|
|
|
|
|
return {"message": "Welcome to AI SQL World!"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 通过语义生成Excel
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_excel
|
|
|
|
|
# 参数:question
|
|
|
|
|
@app.post("/questions/get_excel")
|
|
|
|
|
def get_excel(question: str = Form(...)):
|
|
|
|
|
def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
|
|
|
|
|
# 只接受guid号
|
|
|
|
|
if len(question_id) != 36:
|
|
|
|
|
return {"success": False, "message": "question_id格式错误"}
|
|
|
|
|
|
|
|
|
|
common_prompt = '''
|
|
|
|
|
返回的信息要求:
|
|
|
|
|
1、行政区划为NULL 或者是空字符的不参加统计
|
|
|
|
|
2、目标数据库是Postgresql 16
|
|
|
|
|
'''
|
|
|
|
|
question = question + common_prompt
|
|
|
|
|
question = question_str + common_prompt
|
|
|
|
|
|
|
|
|
|
# 先删除后插入,防止重复插入
|
|
|
|
|
delete_question(db, question_id)
|
|
|
|
|
insert_question(db, question_id, question)
|
|
|
|
|
|
|
|
|
|
# 获取完整 SQL
|
|
|
|
|
sql = vn.generate_sql(question)
|
|
|
|
|
print("生成的查询 SQL:\n", sql)
|
|
|
|
|
# 执行SQL查询
|
|
|
|
|
with PostgreSQLUtil() as db:
|
|
|
|
|
_data = db.execute_query(sql)
|
|
|
|
|
# 在static目录下,生成一个guid号的临时文件
|
|
|
|
|
uuidStr = str(uuid.uuid4())
|
|
|
|
|
filename = f"static/{uuidStr}.xlsx"
|
|
|
|
|
save_to_excel(_data, filename)
|
|
|
|
|
# 返回静态文件URL
|
|
|
|
|
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuidStr}.xlsx"}
|
|
|
|
|
|
|
|
|
|
# 更新question_id
|
|
|
|
|
update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
|
|
|
|
|
|
|
|
|
|
# 执行SQL查询
|
|
|
|
|
_data = db.execute_query(sql)
|
|
|
|
|
# 在static目录下,生成一个guid号的临时文件
|
|
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
|
|
filename = f"static/{uuid_str}.xlsx"
|
|
|
|
|
save_to_excel(_data, filename)
|
|
|
|
|
# 更新EXCEL文件名称
|
|
|
|
|
update_question_by_id(db, question_id, excel_file_name=filename)
|
|
|
|
|
# 返回静态文件URL
|
|
|
|
|
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取docx
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_docx
|
|
|
|
|
@app.post("/questions/get_docx")
|
|
|
|
|
def get_docx(question_id: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
|
|
|
|
|
select_sql = """
|
|
|
|
|
select * from t_bi_question where id=%s
|
|
|
|
|
"""
|
|
|
|
|
_data = db.execute_query(select_sql, (question_id,))
|
|
|
|
|
sql = _data[0]['sql']
|
|
|
|
|
|
|
|
|
|
# 4、生成word报告
|
|
|
|
|
prompt = '''
|
|
|
|
|
请根据以下 JSON 数据,整理出2000字左右的话描述当前数据情况。要求:
|
|
|
|
|
1、以Markdown格式返回,我将直接通过markdown格式生成Word。
|
|
|
|
|
2、标题统一为:长春云校数据分析报告
|
|
|
|
|
3、内容中不要提到JSON数据,统一称:数据
|
|
|
|
|
4、尽量以条目列出,这样更清晰
|
|
|
|
|
5、数据:
|
|
|
|
|
'''
|
|
|
|
|
_data = db.execute_query(sql)
|
|
|
|
|
prompt = prompt + json.dumps(_data, ensure_ascii=False)
|
|
|
|
|
|
|
|
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
client = OpenAI(
|
|
|
|
|
api_key=MODEL_API_KEY,
|
|
|
|
|
base_url=MODEL_API_URL,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 调用 OpenAI API 生成总结(流式输出)
|
|
|
|
|
response = client.chat.completions.create(
|
|
|
|
|
model=MODEL_NAME,
|
|
|
|
|
messages=[
|
|
|
|
|
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
|
|
|
|
|
{"role": "user", "content": prompt}
|
|
|
|
|
],
|
|
|
|
|
max_tokens=3000, # 控制生成内容的长度
|
|
|
|
|
temperature=0.7, # 控制生成内容的创造性
|
|
|
|
|
stream=True # 启用流式输出
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 初始化变量用于存储流式输出的内容
|
|
|
|
|
summary = ""
|
|
|
|
|
|
|
|
|
|
# 处理流式输出
|
|
|
|
|
for chunk in response:
|
|
|
|
|
if chunk.choices[0].delta.content: # 检查是否有内容
|
|
|
|
|
chunk_content = chunk.choices[0].delta.content
|
|
|
|
|
print(chunk_content, end="", flush=True) # 实时打印到控制台
|
|
|
|
|
summary += chunk_content # 将内容拼接到 summary 中
|
|
|
|
|
|
|
|
|
|
# 最终 summary 为完整的 Markdown 内容
|
|
|
|
|
print("\n\n流式输出完成,summary 已拼接为完整字符串。")
|
|
|
|
|
# 生成 Word 文档
|
|
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
|
|
filename = f"static/{uuid_str}.docx"
|
|
|
|
|
markdown_to_docx(summary, output_file=filename)
|
|
|
|
|
|
|
|
|
|
# 返回静态文件URL
|
|
|
|
|
return {"success": True, "message": "Word文件生成成功", "download_url": f"/static/{uuid_str}.docx"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 确保直接运行脚本时启动 FastAPI 应用
|
|
|
|
|