You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
6.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
import os
import uuid

from fastapi import FastAPI, Depends, Form, Query
from openai import OpenAI
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles
import uvicorn  # ASGI server used by the __main__ entry point

from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME
# NOTE(review): this star import presumably supplies PostgreSQLUtil and the
# question helpers (delete_question, insert_question, update_question_by_id,
# get_question_sql_by_id) used below — consider importing them explicitly.
from Model.biModel import *
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.PostgreSQLUtil import get_db
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil
# Create the FastAPI application.
app = FastAPI()
# Generated Excel/Word files are written under ./static (relative to the
# working directory) and served from /static. StaticFiles raises a
# RuntimeError at startup when the directory does not exist, so create it
# up front instead of failing on a fresh checkout/deployment.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Build the Vanna helper once and share it across all requests.
# NOTE(review): get_excel runs in FastAPI's threadpool, so this instance may be
# used concurrently — confirm VannaUtil tolerates that.
vn = VannaUtil()
@app.get("/")
def read_root():
    """Landing endpoint: respond with a fixed welcome message."""
    greeting = "Welcome to AI SQL World!"
    return dict(message=greeting)
# Generate an Excel file from a natural-language question.
# Example: POST http://10.10.21.20:8000/questions/get_excel
@app.post("/questions/get_excel")
def get_excel(
    question_id: str = Form(...),
    question_str: str = Form(...),
    db: PostgreSQLUtil = Depends(get_db),
):
    """Turn a natural-language question into SQL, run it and export the rows to Excel.

    Declared as a plain ``def`` so FastAPI runs it in a worker thread; the
    blocking database and LLM calls below therefore do not stall the event loop.

    Args:
        question_id: Caller-supplied id; must be a canonical 36-character UUID
            string (``xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx``, any letter case).
        question_str: The user's question in natural language.
        db: Database helper injected by ``get_db``.

    Returns:
        A dict with ``success`` and ``message``; on success it also carries
        ``download_url`` pointing at the generated file under ``/static``.
    """
    # Only accept canonical UUID strings. A length-only check let any
    # 36-character string through, and uuid.UUID alone also accepts
    # non-canonical spellings (no hyphens, braces, "urn:uuid:" prefix),
    # hence the round-trip comparison.
    try:
        is_valid_id = str(uuid.UUID(question_id)) == question_id.lower()
    except ValueError:
        is_valid_id = False
    if not is_valid_id:
        return {"success": False, "message": "question_id格式错误"}

    # Extra instructions appended to every question sent to the SQL generator.
    common_prompt = (
        "\n"
        "返回的信息要求:\n"
        "1、行政区划为NULL 或者是空字符的不参加统计\n"
        "2、目标数据库是Postgresql 16\n"
    )
    question = question_str + common_prompt

    # Delete before inserting so re-submitting the same question_id does not
    # leave duplicate rows behind.
    delete_question(db, question_id)
    insert_question(db, question_id, question)

    # Ask Vanna for the complete SQL statement.
    sql = vn.generate_sql(question)
    print("生成的查询 SQL:\n", sql)
    # Guard in case the generator yields nothing; executing None/"" would
    # only fail later with a less helpful error.
    if not sql:
        return {"success": False, "message": "未能根据问题生成SQL"}

    # Store the generated SQL on the question record (state_id=1).
    update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)

    # SECURITY NOTE(review): this executes LLM-generated SQL verbatim; make sure
    # the connection returned by get_db runs under a read-only role.
    _data = db.execute_query(sql)

    # Write the result set to a uniquely named file under ./static.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.xlsx"
    save_to_excel(_data, filename)

    # Remember which Excel file belongs to this question.
    update_question_by_id(db, question_id, excel_file_name=filename)

    # Hand back the public URL of the static file.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
# Generate a Word report for a question whose SQL was stored by get_excel.
# Example: GET http://10.10.21.20:8000/questions/get_docx?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx", methods=["POST", "GET"])
async def get_docx(
    question_id: str = Form(None, description="问题IDPOST请求"),  # POST form field
    question_id_get: str = Query(None, description="问题IDGET请求"),  # GET query parameter
    db: PostgreSQLUtil = Depends(get_db)
):
    """Stream an LLM-written Markdown report for a stored question and save it as .docx.

    The question id comes from the POST form field ``question_id`` when present,
    otherwise from the GET query parameter ``question_id_get``. The SQL stored
    for that question is re-run, its rows are sent to the model, and the
    model's Markdown is streamed back as UTF-8 plain text. Once the model
    finishes, the collected Markdown is converted to ``static/<uuid>.docx``, and
    a JSON object (``success``/``message``/``download_url``) is appended to the
    same stream. On any failure during streaming, an error JSON object is
    appended instead.

    Returns:
        A ``StreamingResponse``, or a ``{"success": False, ...}`` dict when the
        id is missing or no SQL is stored for it.
    """
    # Prefer the POST form field; fall back to the GET query parameter.
    if question_id is None:
        question_id = question_id_get
    if question_id is None:
        return {"success": False, "message": "缺少问题ID参数"}

    # The DB helpers are synchronous. Run them in a worker thread so this
    # async endpoint does not block the event loop while they wait on I/O.
    sql = await run_in_threadpool(get_question_sql_by_id, db, question_id)
    # NOTE(review): assumes an unknown id yields None/"" rather than raising — confirm.
    if not sql:
        return {"success": False, "message": "未找到该问题对应的SQL"}
    _data = await run_in_threadpool(db.execute_query, sql)

    # Instructions for the report; the query result is appended as JSON.
    prompt = (
        "\n"
        "请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况。要求\n"
        "1、以Markdown格式返回我将直接通过markdown格式生成Word。\n"
        "2、标题统一为长春云校数据分析报告\n"
        "3、内容中不要提到JSON数据统一称数据\n"
        "4、尽量以条目列出这样更清晰\n"
        "5、数据\n"
    )
    # default=str: only consulted for values json cannot encode natively
    # (e.g. Decimal/date, if execute_query returns them), which would
    # otherwise raise TypeError; a string form is fine inside a prompt.
    prompt = prompt + json.dumps(_data, ensure_ascii=False, default=str)

    client = OpenAI(
        api_key=MODEL_API_KEY,
        base_url=MODEL_API_URL,
    )
    # create(stream=True) blocks until the HTTP response starts, so it also
    # runs in a worker thread.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
            {"role": "user", "content": prompt},
        ],
        max_tokens=3000,  # upper bound on the generated length
        temperature=0.7,  # sampling creativity
        stream=True,  # receive the answer incrementally
    )

    # Target file for the Word document.
    uuid_str = str(uuid.uuid4())
    filename = f"static/{uuid_str}.docx"

    def generate_stream():
        """Relay model output as UTF-8 bytes, then build the .docx and emit a JSON trailer.

        A plain (sync) generator on purpose: StreamingResponse iterates sync
        iterators in a threadpool, so reading the blocking OpenAI stream no
        longer stalls the event loop. Output is yielded per model delta rather
        than per character. The bytes are identical, but the per-character
        version cost one threadpool hop per character.
        """
        parts = []
        try:
            for chunk in response:
                # Some OpenAI-compatible servers send chunks whose `choices`
                # list is empty; indexing [0] would raise and abort the report.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if not content:
                    continue
                print(content, end="", flush=True)  # mirror output to the console
                parts.append(content)
                yield content.encode("utf-8")
            # Streaming finished: turn the collected Markdown into a Word file.
            markdown_to_docx("".join(parts), output_file=filename)
            # NOTE: the JSON trailer follows the Markdown text directly, with no
            # separator; clients must split it off themselves.
            final_response = json.dumps({
                "success": True,
                "message": "Word文件生成成功",
                "download_url": f"/static/{uuid_str}.docx"
            })
            print(final_response)
            yield final_response.encode("utf-8")
        except Exception as e:
            # Report the failure in-band; the HTTP status is already sent.
            error_response = json.dumps({
                "success": False,
                "message": f"生成Word文件失败: {str(e)}"
            })
            print(error_response)
            yield error_response.encode("utf-8")
        finally:
            # Always release the underlying HTTP connection of the stream.
            response.close()

    # media_type already sets "text/plain; charset=utf-8" as Content-Type.
    return StreamingResponse(
        generate_stream(),
        media_type="text/plain; charset=utf-8",
        headers={"Cache-Control": "no-cache"},  # disable caching of the live stream
    )
# Start the development server when this module is executed as a script.
if __name__ == "__main__":
    # uvicorn is given an import string ("module:attribute") rather than the
    # app object; the string names module "app", so this file is expected to
    # be app.py. reload=True enables auto-reload on code changes.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )