You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

376 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import json
import uuid
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Form, Query
from fastapi.staticfiles import StaticFiles
from openai import AsyncOpenAI
from starlette.responses import StreamingResponse
from Config import *
from Model.biModel import *
from Text2Sql.Util.EchartsUtil import generate_chart
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil
from fastapi.middleware.cors import CORSMiddleware # 导入跨域中间件
vn = VannaUtil()
# 初始化 FastAPI 应用
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时初始化连接池
app.state.pool = await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=5,
max_size=20
)
yield
# 关闭时释放连接池
await app.state.pool.close()
# 初始化 FastAPI
app = FastAPI(lifespan=lifespan)
# 添加跨域支持
app.add_middleware(
CORSMiddleware, # 直接使用 CORSMiddleware 类
allow_origins=["*"], # 允许所有来源
allow_credentials=True, # 允许携带凭证(如 cookies
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 头
)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="static"), name="static")
# 依赖注入连接池
async def get_db():
async with app.state.pool.acquire() as connection:
yield connection
# 初始化 OpenAI 客户端
client = AsyncOpenAI(
api_key=MODEL_API_KEY,
base_url=MODEL_API_URL,
)
@app.post("/questions/get_excel")
async def get_excel(question_id: str = Form(...), question_str: str = Form(...),
db: asyncpg.Connection = Depends(get_db)):
# 只接受guid号
if len(question_id) != 36:
return {"success": False, "message": "question_id格式错误"}
# 检查此question_id 是不是已存在存在的不能再次生成需要全新创建一个新问题ID
if await get_question_by_id(db, question_id):
return {"success": False, "message": "question_id已存在请重新生成"}
common_prompt = '''
返回的信息要求:
1、行政区划为NULL 或者是空字符的不参加统计
2、目标数据库是Postgresql 16
'''
question = question_str + common_prompt
# 获取完整 SQL
sql = vn.generate_sql(question)
print("生成的查询 SQL:\n", sql)
# 插入数据
await insert_question(db, question_id, question)
# 检查如果sql为空则返回错误信息
if not sql:
return {"success": False, "message": "无法生成相应的SQL语句"}
# 检查如果SQL无法正确执行返回错误消息
if not await get_data_by_sql(db, sql):
return {"success": False, "message": "无法生成相应的SQL语句"}
# 更新question_id
await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
# 执行SQL查询
_data = await db.fetch(sql)
# 在static目录下生成一个guid号的临时文件
uuid_str = str(uuid.uuid4())
filename = f"static/xlsx/{uuid_str}.xlsx"
save_to_excel(_data, filename)
# 更新EXCEL文件名称
await update_question_by_id(db, question_id, excel_file_name=filename)
# 返回静态文件URL
return {"success": True, "message": "Excel文件生成成功", "download_url": filename}
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
async def get_docx_stream(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
db: asyncpg.Connection = Depends(get_db)
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询sql
sql = (await get_question_by_id(db, question_id))[0]['sql']
# 生成word报告
prompt = '''
请根据以下 JSON 数据整理出2000字左右的话描述当前数据情况。要求
1、以Markdown格式返回我将直接通过markdown格式生成Word。
2、标题统一为长春云校数据分析报告
3、内容中不要提到JSON数据统一称数据
4、尽量以条目列出这样更清晰
5、数据
'''
_data = await db.fetch(sql)
# print(_data)
# 将 asyncpg.Record 转换为 JSON 格式
from decimal import Decimal
# Define a custom function to handle Decimal objects
def decimal_default(obj):
if isinstance(obj, Decimal):
return float(obj) # Convert Decimal to float
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
# In your get_docx_stream function, modify the json.dumps calls:
json_data = json.dumps([dict(record) for record in _data], ensure_ascii=False, default=decimal_default)
prompt = prompt + json_data # No need to call json.dumps again on json_data
# 调用 OpenAI API 生成总结(流式输出)
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
{"role": "user", "content": prompt}
],
max_tokens=3000, # 控制生成内容的长度
temperature=0.7, # 控制生成内容的创造性
stream=True # 启用流式输出
)
# 生成 Word 文档的文件名
uuid_str = str(uuid.uuid4())
filename = f"static/docx/{uuid_str}.docx"
# 定义一个生成器函数,用于逐字返回流式结果
async def generate_stream():
summary = ""
try:
# 获取数据库连接
db = await app.state.pool.acquire()
async for chunk in response:
if chunk.choices[0].delta.content:
chunk_content = chunk.choices[0].delta.content
for char in chunk_content:
print(char, end="", flush=True)
yield char.encode("utf-8")
summary += char
# 流式传输完成后,生成 Word 文档
markdown_to_docx(summary, output_file=filename)
# 记录到数据库
await update_question_by_id(db, question_id, docx_file_name=filename)
except asyncio.CancelledError:
# 客户端提前断开连接,无需处理
print("客户端断开连接")
except Exception as e:
error_response = json.dumps({
"success": False,
"message": f"生成Word文件失败: {str(e)}"
})
print(error_response)
yield error_response.encode("utf-8")
finally:
# 确保资源释放
if "response" in locals():
await response.close()
if "db" in locals():
await app.state.pool.release(db) # 释放连接回连接池
# 使用 StreamingResponse 返回流式结果
return StreamingResponse(
generate_stream(),
media_type="text/plain; charset=utf-8", # 明确指定字符编码为 UTF-8
headers={
"Cache-Control": "no-cache", # 禁用缓存
"Content-Type": "text/event-stream; charset=utf-8", # 设置内容类型和字符编码
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲(如果使用 Nginx
}
)
# 返回生成的Word文件下载地址
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数
question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数
db: asyncpg.Connection = Depends(get_db) # 添加 db 参数
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name
docx_file_name = (await get_question_by_id(db, question_id))[0]['docx_file_name']
# 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
# 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend")
async def set_system_recommend(
question_id: str = Form(...),
flag: str = Form(...),
db: asyncpg.Connection = Depends(get_db) # 添加 db 参数
):
await set_system_recommend_questions(db, question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
@app.post("/questions/set_user_collect")
async def set_user_collect(
question_id: str = Form(...),
flag: str = Form(...),
db: asyncpg.Connection = Depends(get_db) # 添加 db 参数
):
await set_user_collect_questions(db, question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
@app.get("/questions/get_system_recommend")
async def get_system_recommend(
db: asyncpg.Connection = Depends(get_db), # 数据库连接
page: int = Query(1, description="当前页码,默认为 1"), # 页码参数
page_size: int = Query(10, description="每页数据量,默认为 10") # 每页数据量参数
):
# 计算 OFFSET
offset = (page - 1) * page_size
# 查询分页后的系统推荐问题
system_recommend_questions = await get_system_recommend_questions(db, offset, page_size)
# 查询总数据量
total = await get_system_recommend_questions_count(db)
# 返回查询结果
return {
"success": True,
"data": system_recommend_questions,
"pagination": {
"page": page,
"page_size": page_size,
"total": total # 总数据量
}
}
@app.get("/questions/get_user_questions")
async def get_user_questions(
db: asyncpg.Connection = Depends(get_db), # 数据库连接
page: int = Query(1, description="当前页码,默认为 1"), # 页码参数
page_size: int = Query(10, description="每页数据量,默认为 10"), # 每页数据量参数
type_id: int = Query(1, description="1用户收藏的问题0全部")
):
# 计算 OFFSET
offset = (page - 1) * page_size
# 查询分页后的用户收藏问题
user_questions = await get_user_publish_questions(db, type_id, offset, page_size)
# 查询总数据量
total = await get_user_publish_questions_count(db)
# 返回查询结果
return {
"success": True,
"data": user_questions,
"pagination": {
"page": page,
"page_size": page_size,
"total": total # 总数据量
}
}
# 获取数据库字段名
@app.get("/questions/get_data_scheme_by_id")
async def get_data_scheme_by_id(
question_id: str = Query(None, description="问题ID(GET请求)"),
db: asyncpg.Connection = Depends(get_db)
):
sql = (await get_question_by_id(db, question_id))[0]['sql']
column_names = (await get_column_names(db, sql))
return {"success": True, "data": column_names}
# 生成图表
@app.post("/questions/get_chart")
async def get_chart(
question_id: str = Form(None, description="问题IDPOST请求"),
title: str = Form(None, description="图表标题POST请求"),
db: asyncpg.Connection = Depends(get_db)
):
# 根据问题ID获取查询SQL
sql = (await get_question_by_id(db, question_id))[0]['sql']
# 执行SQL获取数据集
_data = await get_data_by_sql(db, sql)
# 调用函数
category_columns_str, value_column_str = generate_columns_with_ai(_data)
print(f"category_columns_str: {category_columns_str}")
print(f"value_column_str: {value_column_str}")
# 图表文件名称
uuid_str = str(uuid.uuid4())
filename = f"static/html/{uuid_str}.html"
# 根据图表类型生成图表
generate_chart(
_data=_data,
title=title,
category_columns=category_columns_str.split(","),
value_column=value_column_str,
output_file=filename
)
# 返回静态文件URL
return {"success": True, "message": "图表文件生成成功", "url": filename}
# 启动 FastAPI
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8000, workers=4)