import json
import uuid

import uvicorn  # ASGI server used by the __main__ entry point below
from fastapi import FastAPI, Depends, Form, Query
from openai import OpenAI
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME
from Model.biModel import *
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.PostgreSQLUtil import get_db
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil

# Create the FastAPI application.
app = FastAPI()

# Serve generated Excel/Word files from the local "static" directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Single VannaUtil instance, created once at import time and shared by all requests.
vn = VannaUtil()


def _resolve_question_id(question_id, question_id_get):
    """Return the POST form value when present, otherwise the GET query value."""
    return question_id if question_id is not None else question_id_get


def _first_question_row(db, question_id):
    """Return the first row found for question_id, or None when there is none."""
    rows = get_question_by_id(db, question_id)
    return rows[0] if rows else None


@app.get("/")
def read_root():
    return {"message": "Welcome to AI SQL World!"}


# Generate an Excel file from a natural-language question.
# Example: http://10.10.21.20:8000/questions/get_excel
#
# Form fields:
#   question_id  - 36-character identifier of the question
#   question_str - the natural-language question
# Returns a JSON dict with "success", "message" and, on success, "download_url".
@app.post("/questions/get_excel")
def get_excel(question_id: str = Form(...), question_str: str = Form(...),
              db: PostgreSQLUtil = Depends(get_db)):
    # Only 36-character (GUID-length) identifiers are accepted.
    if len(question_id) != 36:
        return {"success": False, "message": "question_id格式错误"}

    # Constraints appended to every question before it is sent to SQL generation.
    common_prompt = '''
    返回的信息要求:
    1、行政区划为NULL 或者是空字符的不参加统计
    2、目标数据库是Postgresql 16
    '''
    question = question_str + common_prompt

    # Delete before inserting so that a repeated request does not create duplicates.
    delete_question(db, question_id)
    insert_question(db, question_id, question)

    # Generate the complete SQL statement for the question.
    sql = vn.generate_sql(question)
    print("生成的查询 SQL:\n", sql)

    # Store the generated SQL on the question record.
    update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)

    # Run the generated query.
    _data = db.execute_query(sql)

    # Write the result to a uniquely named file under the static directory.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.xlsx"
    save_to_excel(_data, filename)

    # Record the Excel file name on the question.
    update_question_by_id(db, question_id, excel_file_name=filename)

    # Return the URL under which the static file is served.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}


# Stream an LLM-written Markdown report for a question and save it as a Word file.
# Example: http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
#
# The question ID is read from the form field "question_id" (POST) or from the
# query parameter "question_id_get" (GET). When the ID is missing, unknown, or has
# no stored SQL, a JSON dict with success=False is returned instead of a stream.
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream(
        question_id: str = Form(None, description="问题ID(POST请求)"),  # POST parameter
        question_id_get: str = Query(None, description="问题ID(GET请求)"),  # GET parameter
        db: PostgreSQLUtil = Depends(get_db)
):
    # Pick the ID from whichever request style supplied it.
    question_id = _resolve_question_id(question_id, question_id_get)
    if question_id is None:
        return {"success": False, "message": "缺少问题ID参数"}

    # Look up the SQL stored for this question; an unknown ID must not crash on [0].
    question_row = _first_question_row(db, question_id)
    if question_row is None:
        return {"success": False, "message": "问题ID不存在"}
    sql = question_row['sql']
    if not sql:
        return {"success": False, "message": "该问题尚未生成SQL"}

    # Prompt that asks the model for the Word report content.
    prompt = '''
    请根据以下 JSON 数据,整理出2000字左右的话描述当前数据情况。要求:
    1、以Markdown格式返回,我将直接通过markdown格式生成Word。
    2、标题统一为:长春云校数据分析报告
    3、内容中不要提到JSON数据,统一称:数据
    4、尽量以条目列出,这样更清晰
    5、数据:
    '''
    # NOTE(review): the query and the API call below are synchronous calls made
    # inside an async endpoint — confirm whether they should also be offloaded.
    _data = db.execute_query(sql)
    prompt = prompt + json.dumps(_data, ensure_ascii=False)

    # Create the OpenAI client.
    client = OpenAI(
        api_key=MODEL_API_KEY,
        base_url=MODEL_API_URL,
    )

    # Request the summary with streaming enabled.
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
            {"role": "user", "content": prompt}
        ],
        max_tokens=3000,  # limits the length of the generated text
        temperature=0.7,  # controls how creative the generated text is
        stream=True  # enable streaming output
    )

    # File name of the Word document to generate.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.docx"

    # Plain (sync) generator: `response` is iterated synchronously, and Starlette
    # runs sync iterators in a threadpool, so the event loop is not blocked while
    # the model is streaming.
    def generate_stream():
        pieces = []
        try:
            for chunk in response:
                # Skip chunks that carry no choices instead of failing on [0];
                # some OpenAI-compatible backends may send such chunks.
                if not chunk.choices:
                    continue
                chunk_content = chunk.choices[0].delta.content
                if not chunk_content:
                    continue
                pieces.append(chunk_content)
                # Emit one character at a time, each encoded as UTF-8 bytes.
                for char in chunk_content:
                    yield char.encode("utf-8")

            # Streaming finished: build the Word document from the full text.
            summary = "".join(pieces)
            markdown_to_docx(summary, output_file=filename)

            # Record the Word file name on the question.
            update_question_by_id(db, question_id, docx_file_name=filename)
        except Exception as e:
            # Report the failure to the client as a JSON payload in the stream.
            error_response = json.dumps({
                "success": False,
                "message": f"生成Word文件失败: {str(e)}"
            })
            yield error_response.encode("utf-8")
        finally:
            # Always release the streaming response.
            response.close()

    # The explicit "Content-Type" header used to override a conflicting media_type;
    # media_type now carries that same effective value on its own.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream; charset=utf-8",
        headers={
            "Cache-Control": "no-cache",  # disable caching
            "Transfer-Encoding": "chunked",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable Nginx buffering (if Nginx is used)
        }
    )


# Return the download address of the Word file generated for a question.
# Example: http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
#
# The question ID is read from the form field "question_id" (POST) or from the
# query parameter "question_id_get" (GET).
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
        question_id: str = Form(None, description="问题ID(POST请求)"),  # POST parameter
        question_id_get: str = Query(None, description="问题ID(GET请求)"),  # GET parameter
        db: PostgreSQLUtil = Depends(get_db)
):
    # Pick the ID from whichever request style supplied it.
    question_id = _resolve_question_id(question_id, question_id_get)
    if question_id is None:
        return {"success": False, "message": "缺少问题ID参数"}

    # Look up the stored Word file name; an unknown ID must not crash on [0].
    question_row = _first_question_row(db, question_id)
    if question_row is None:
        return {"success": False, "message": "问题ID不存在"}

    # Without a stored file name there is nothing to download, so do not report success.
    docx_file_name = question_row['docx_file_name']
    if not docx_file_name:
        return {"success": False, "message": "Word文件尚未生成"}

    # NOTE(review): get_docx_stream stores "static/<uuid>.docx" (no leading slash),
    # while get_excel returns "/static/<uuid>.xlsx" — confirm which form clients expect.
    return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}


# Mark or unmark a question as system-recommended (flag: 0 = unset, 1 = set).
@app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...),
                         db: PostgreSQLUtil = Depends(get_db)):
    set_system_recommend_questions(db, question_id, flag)
    # Report that the change was saved.
    return {"success": True, "message": "保存成功"}


# Mark or unmark a question as a user favourite (flag: 0 = unset, 1 = set).
@app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...),
                     db: PostgreSQLUtil = Depends(get_db)):
    set_user_collect_questions(db, question_id, flag)
    # Report that the change was saved.
    return {"success": True, "message": "保存成功"}


# List the system-recommended questions.
@app.get("/questions/get_system_recommend")
def get_system_recommend(db: PostgreSQLUtil = Depends(get_db)):
    # Fetch all system-recommended questions.
    system_recommend_questions = get_system_recommend_questions(db)
    # Return the query result.
    return {"success": True, "data": system_recommend_questions}


# List the questions collected (favourited) by users.
@app.get("/questions/get_user_collect")
def get_user_collect(db: PostgreSQLUtil = Depends(get_db)):
    # Fetch all user-collected questions.
    user_collect_questions = get_user_collect_questions(db)
    # Return the query result.
    return {"success": True, "data": user_collect_questions}


# Start the FastAPI application when this script is run directly.
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)