You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

255 lines
9.4 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import json
import uuid
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Form, Query
from fastapi.staticfiles import StaticFiles
from openai import AsyncOpenAI
from starlette.responses import StreamingResponse
from Config import *
from Model.biModel import *
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil
# 初始化 FastAPI
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
vn = VannaUtil()
# 初始化 FastAPI 应用
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时初始化连接池
app.state.pool = await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=1,
max_size=10
)
yield
# 关闭时释放连接池
await app.state.pool.close()
app = FastAPI(lifespan=lifespan)
# 依赖注入连接池
async def get_db():
async with app.state.pool.acquire() as connection:
yield connection
@app.post("/questions/get_excel")
async def get_excel(question_id: str = Form(...), question_str: str = Form(...),
db: asyncpg.Connection = Depends(get_db)):
# 只接受guid号
if len(question_id) != 36:
return {"success": False, "message": "question_id格式错误"}
common_prompt = '''
返回的信息要求:
1、行政区划为NULL 或者是空字符的不参加统计
2、目标数据库是Postgresql 16
'''
question = question_str + common_prompt
# 先删除后插入,防止重复插入
await delete_question(db, question_id)
await insert_question(db, question_id, question)
# 获取完整 SQL
sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql)
# 更新question_id
await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
# 执行SQL查询
_data = await db.fetch(sql)
# 在static目录下生成一个guid号的临时文件
uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.xlsx"
save_to_excel(_data, filename)
# 更新EXCEL文件名称
await update_question_by_id(db, question_id, excel_file_name=filename)
# 返回静态文件URL
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
# 初始化 OpenAI 客户端
client = AsyncOpenAI(
api_key=MODEL_API_KEY,
base_url=MODEL_API_URL,
)
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: asyncpg.Connection = Depends(get_db)
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询sql
sql = (await db.fetch("SELECT * FROM t_bi_question WHERE id = $1", question_id))[0]['sql']
# 生成word报告
prompt = '''
请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况。要求
1、以Markdown格式返回我将直接通过markdown格式生成Word。
2、标题统一为长春云校数据分析报告
3、内容中不要提到JSON数据统一称数据
4、尽量以条目列出这样更清晰
5、数据
'''
_data = await db.fetch(sql)
# print(_data)
# 将 asyncpg.Record 转换为 JSON 格式
json_data = json.dumps([dict(record) for record in _data], ensure_ascii=False)
print(json_data) # 打印 JSON 数据
prompt = prompt + json.dumps(json_data, ensure_ascii=False)
# 调用 OpenAI API 生成总结(流式输出)
response = await client.chat.completions.create(
model=QWEN_MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
{"role": "user", "content": prompt}
],
max_tokens=3000, # 控制生成内容的长度
temperature=0.7, # 控制生成内容的创造性
stream=True # 启用流式输出
)
# 生成 Word 文档的文件名
uuid_str = str(uuid.uuid4())
filename = f"static/{uuid_str}.docx"
# 定义一个生成器函数,用于逐字返回流式结果
async def generate_stream():
summary = ""
try:
async for chunk in response: # 使用 async for 处理流式响应
if chunk.choices[0].delta.content: # 检查是否有内容
chunk_content = chunk.choices[0].delta.content
# 逐字拆分并返回
for char in chunk_content:
print(char, end="", flush=True) # 逐字输出到控制台
yield char.encode("utf-8") # 将字符编码为 UTF-8 字节
summary += char # 将内容拼接到 summary 中
# 流式传输完成后,生成 Word 文档
markdown_to_docx(summary, output_file=filename)
# 记录到数据库
await db.execute("UPDATE t_bi_question SET docx_file_name = $1 WHERE id = $2", filename, question_id)
except asyncio.CancelledError:
# 客户端提前断开连接,无需处理
print("客户端断开连接")
except Exception as e:
# 如果发生异常,返回错误信息
error_response = json.dumps({
"success": False,
"message": f"生成Word文件失败: {str(e)}"
})
print(error_response) # 输出错误信息到控制台
yield error_response.encode("utf-8") # 将错误信息编码为 UTF-8 字节
finally:
# 确保资源释放
if "response" in locals():
await response.close()
# 使用 StreamingResponse 返回流式结果
return StreamingResponse(
generate_stream(),
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
headers={
"Cache-Control": "no-cache", # 禁用缓存
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx
}
)
# 返回生成的Word文件下载地址
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name
docx_file_name = get_question_by_id(question_id)[0]['docx_file_name']
# 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
# 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)):
set_system_recommend_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 设置问题为用户收藏问题 ,0:取消1设置
@app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...)):
set_user_collect_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 查询有哪些系统推荐问题
@app.get("/questions/get_system_recommend")
def get_system_recommend():
# 查询所有系统推荐问题
system_recommend_questions = get_system_recommend_questions()
# 返回查询结果
return {"success": True, "data": system_recommend_questions}
# 查询有哪些用户收藏问题
@app.get("/questions/get_user_collect")
def get_user_collect():
# 查询所有用户收藏问题
user_collect_questions = get_user_collect_questions()
# 返回查询结果
return {"success": True, "data": user_collect_questions}
# 启动 FastAPI
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True, workers=4)