You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

222 lines
8.7 KiB

4 months ago
import json
4 months ago
import uuid
4 months ago
import uvicorn # 导入 uvicorn
4 months ago
from fastapi import FastAPI, Depends, Form, Query
4 months ago
from openai import OpenAI
4 months ago
from starlette.responses import StreamingResponse
4 months ago
from starlette.staticfiles import StaticFiles
4 months ago
from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME
4 months ago
from Model.biModel import *
4 months ago
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.PostgreSQLUtil import get_db
4 months ago
from Text2Sql.Util.SaveToExcel import save_to_excel
4 months ago
from Text2Sql.Util.VannaUtil import VannaUtil
# 初始化 FastAPI
app = FastAPI()
4 months ago
# 配置静态文件目录
app.mount("/static", StaticFiles(directory="static"), name="static")
4 months ago
# 初始化一次vanna的类
4 months ago
vn = VannaUtil()
4 months ago
4 months ago
@app.get("/")
def read_root():
4 months ago
return {"message": "Welcome to AI SQL World!"}
4 months ago
4 months ago
4 months ago
# 通过语义生成Excel
4 months ago
# http://10.10.21.20:8000/questions/get_excel
4 months ago
@app.post("/questions/get_excel")
4 months ago
def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
# 只接受guid号
if len(question_id) != 36:
return {"success": False, "message": "question_id格式错误"}
4 months ago
4 months ago
common_prompt = '''
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
2目标数据库是Postgresql 16
'''
4 months ago
question = question_str + common_prompt
# 先删除后插入,防止重复插入
delete_question(db, question_id)
insert_question(db, question_id, question)
4 months ago
# 获取完整 SQL
sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql)
4 months ago
# 更新question_id
update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
# 执行SQL查询
_data = db.execute_query(sql)
# 在static目录下生成一个guid号的临时文件
uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.xlsx"
save_to_excel(_data, filename)
# 更新EXCEL文件名称
update_question_by_id(db, question_id, excel_file_name=filename)
# 返回静态文件URL
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
4 months ago
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: PostgreSQLUtil = Depends(get_db)
4 months ago
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询sql
4 months ago
sql = get_question_by_id(db, question_id)[0]['sql']
4 months ago
# 4、生成word报告
prompt = '''
请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况要求
1以Markdown格式返回我将直接通过markdown格式生成Word
2标题统一为长春云校数据分析报告
3内容中不要提到JSON数据统一称数据
4尽量以条目列出这样更清晰
5数据
'''
_data = db.execute_query(sql)
prompt = prompt + json.dumps(_data, ensure_ascii=False)
# 初始化 OpenAI 客户端
client = OpenAI(
api_key=MODEL_API_KEY,
base_url=MODEL_API_URL,
)
# 调用 OpenAI API 生成总结(流式输出)
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
{"role": "user", "content": prompt}
],
max_tokens=3000, # 控制生成内容的长度
temperature=0.7, # 控制生成内容的创造性
stream=True # 启用流式输出
)
4 months ago
# 生成 Word 文档的文件名
4 months ago
uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.docx"
4 months ago
4 months ago
# 定义一个生成器函数,用于逐字返回流式结果
async def generate_stream():
summary = ""
try:
for chunk in response:
if chunk.choices[0].delta.content: # 检查是否有内容
chunk_content = chunk.choices[0].delta.content
# 逐字拆分并返回
for char in chunk_content:
yield char.encode("utf-8") # 将字符编码为 UTF-8 字节
summary += char # 将内容拼接到 summary 中
# 流式传输完成后,生成 Word 文档
markdown_to_docx(summary, output_file=filename)
4 months ago
# 记录到数据库
update_question_by_id(db, question_id, docx_file_name=filename)
4 months ago
4 months ago
except Exception as e:
# 如果发生异常,返回错误信息
error_response = json.dumps({
"success": False,
"message": f"生成Word文件失败: {str(e)}"
})
yield error_response.encode("utf-8") # 将错误信息编码为 UTF-8 字节
4 months ago
4 months ago
finally:
# 确保资源释放
if "response" in locals():
response.close()
# 使用 StreamingResponse 返回流式结果
return StreamingResponse(
generate_stream(),
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
headers={
"Cache-Control": "no-cache", # 禁用缓存
4 months ago
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
4 months ago
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx
4 months ago
}
)
4 months ago
4 months ago
4 months ago
# 返回生成的Word文件下载地址
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: PostgreSQLUtil = Depends(get_db)
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name
docx_file_name = get_question_by_id(db, question_id)[0]['docx_file_name']
# 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
4 months ago
# 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
4 months ago
set_system_recommend_questions(db, question_id, flag)
4 months ago
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 设置问题为用户收藏问题 ,0:取消1设置
@app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...), db: PostgreSQLUtil = Depends(get_db)):
4 months ago
set_user_collect_questions(db, question_id, flag)
4 months ago
# 提示保存成功
return {"success": True, "message": "保存成功"}
4 months ago
# 查询有哪些系统推荐问题
@app.get("/questions/get_system_recommend")
def get_system_recommend(db: PostgreSQLUtil = Depends(get_db)):
# 查询所有系统推荐问题
system_recommend_questions = get_system_recommend_questions(db)
# 返回查询结果
return {"success": True, "data": system_recommend_questions}
# 查询有哪些用户收藏问题
@app.get("/questions/get_user_collect")
def get_user_collect(db: PostgreSQLUtil = Depends(get_db)):
# 查询所有用户收藏问题
user_collect_questions = get_user_collect_questions(db)
# 返回查询结果
return {"success": True, "data": user_collect_questions}
4 months ago
4 months ago
# 确保直接运行脚本时启动 FastAPI 应用
if __name__ == "__main__":
4 months ago
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)