|
|
|
@ -1,13 +1,17 @@
|
|
|
|
|
import json
|
|
|
|
|
import uuid
|
|
|
|
|
|
|
|
|
|
import uvicorn # 导入 uvicorn
|
|
|
|
|
from datetime import date, datetime
|
|
|
|
|
from asyncpg.pool import Pool
|
|
|
|
|
from fastapi import FastAPI, Form, Query
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from fastapi import FastAPI, Depends
|
|
|
|
|
import asyncpg
|
|
|
|
|
import uvicorn
|
|
|
|
|
from openai import AsyncOpenAI
|
|
|
|
|
from starlette.responses import StreamingResponse
|
|
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
|
|
|
|
|
|
from Config import MODEL_API_KEY, MODEL_API_URL, QWEN_MODEL_NAME
|
|
|
|
|
from Config import *
|
|
|
|
|
from Model.biModel import *
|
|
|
|
|
from Text2Sql.Util.MarkdownToDocxUtil import markdown_to_docx
|
|
|
|
|
from Text2Sql.Util.SaveToExcel import save_to_excel
|
|
|
|
@ -15,22 +19,66 @@ from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
# 配置静态文件目录
|
|
|
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
|
|
|
# 初始化一次vanna的类
|
|
|
|
|
vn = VannaUtil()
|
|
|
|
|
# 初始化 FastAPI 应用
|
|
|
|
|
@asynccontextmanager
|
|
|
|
|
async def lifespan(app: FastAPI):
|
|
|
|
|
# 启动时初始化连接池
|
|
|
|
|
app.state.pool = await asyncpg.create_pool(
|
|
|
|
|
host=PG_HOST,
|
|
|
|
|
port=PG_PORT,
|
|
|
|
|
database=PG_DATABASE,
|
|
|
|
|
user=PG_USER,
|
|
|
|
|
password=PG_PASSWORD,
|
|
|
|
|
min_size=1,
|
|
|
|
|
max_size=10
|
|
|
|
|
)
|
|
|
|
|
yield
|
|
|
|
|
# 关闭时释放连接池
|
|
|
|
|
await app.state.pool.close()
|
|
|
|
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
# 依赖注入连接池
|
|
|
|
|
async def get_db():
|
|
|
|
|
async with app.state.pool.acquire() as connection:
|
|
|
|
|
yield connection
|
|
|
|
|
|
|
|
|
|
class PostgreSQLUtil:
|
|
|
|
|
def __init__(self, pool: Pool):
|
|
|
|
|
self.pool = pool
|
|
|
|
|
|
|
|
|
|
async def execute_query(self, sql, params=None):
|
|
|
|
|
async with self.pool.acquire() as connection:
|
|
|
|
|
result = await connection.fetch(sql, params)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
async def query_to_json(self, sql, params=None):
|
|
|
|
|
data = await self.execute_query(sql, params)
|
|
|
|
|
return json.dumps(data, default=self.json_serializer)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def json_serializer(obj):
|
|
|
|
|
"""处理JSON无法序列化的类型"""
|
|
|
|
|
if isinstance(obj, (date, datetime)):
|
|
|
|
|
return obj.isoformat()
|
|
|
|
|
raise TypeError(f"Type {type(obj)} not serializable")
|
|
|
|
|
|
|
|
|
|
async def create_pool():
|
|
|
|
|
return await asyncpg.create_pool(
|
|
|
|
|
host=PG_HOST,
|
|
|
|
|
port=PG_PORT,
|
|
|
|
|
database=PG_DATABASE,
|
|
|
|
|
user=PG_USER,
|
|
|
|
|
password=PG_PASSWORD,
|
|
|
|
|
min_size=1,
|
|
|
|
|
max_size=10
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
|
|
|
|
|
def read_root():
|
|
|
|
|
return {"message": "Welcome to AI SQL World!"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 通过语义生成Excel
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_excel
|
|
|
|
|
@app.post("/questions/get_excel")
|
|
|
|
|
def get_excel(question_id: str = Form(...), question_str: str = Form(...)):
|
|
|
|
|
async def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: asyncpg.Connection = Depends(get_db)):
|
|
|
|
|
# 只接受guid号
|
|
|
|
|
if len(question_id) != 36:
|
|
|
|
|
return {"success": False, "message": "question_id格式错误"}
|
|
|
|
@ -43,34 +91,40 @@ def get_excel(question_id: str = Form(...), question_str: str = Form(...)):
|
|
|
|
|
question = question_str + common_prompt
|
|
|
|
|
|
|
|
|
|
# 先删除后插入,防止重复插入
|
|
|
|
|
delete_question(question_id)
|
|
|
|
|
insert_question(question_id, question)
|
|
|
|
|
await delete_question(db, question_id)
|
|
|
|
|
await insert_question(db, question_id, question)
|
|
|
|
|
|
|
|
|
|
# 获取完整 SQL
|
|
|
|
|
sql = vn.generate_sql(question)
|
|
|
|
|
print("生成的查询 SQL:\n", sql)
|
|
|
|
|
|
|
|
|
|
# 更新question_id
|
|
|
|
|
update_question_by_id(question_id=question_id, sql=sql, state_id=1)
|
|
|
|
|
await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
|
|
|
|
|
|
|
|
|
|
# 执行SQL查询
|
|
|
|
|
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
|
|
|
|
|
_data = db.execute_query(sql)
|
|
|
|
|
_data = await db.fetch(sql)
|
|
|
|
|
# 在static目录下,生成一个guid号的临时文件
|
|
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
|
|
filename = f"static/{uuid_str}.xlsx"
|
|
|
|
|
save_to_excel(_data, filename)
|
|
|
|
|
# 更新EXCEL文件名称
|
|
|
|
|
update_question_by_id(question_id, excel_file_name=filename)
|
|
|
|
|
await update_question_by_id(db, question_id, excel_file_name=filename)
|
|
|
|
|
# 返回静态文件URL
|
|
|
|
|
return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{uuid_str}.xlsx"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
|
|
|
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
client = AsyncOpenAI(
|
|
|
|
|
api_key=MODEL_API_KEY,
|
|
|
|
|
base_url=MODEL_API_URL,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
|
|
|
|
|
async def get_docx_stream(
|
|
|
|
|
question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数
|
|
|
|
|
question_id_get: str = Query(None, description="问题ID(GET请求)") # GET 请求参数
|
|
|
|
|
question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数
|
|
|
|
|
db: asyncpg.Connection = Depends(get_db)
|
|
|
|
|
):
|
|
|
|
|
# 根据请求方式获取 question_id
|
|
|
|
|
if question_id is not None: # POST 请求
|
|
|
|
@ -81,8 +135,9 @@ async def get_docx_stream(
|
|
|
|
|
return {"success": False, "message": "缺少问题ID参数"}
|
|
|
|
|
|
|
|
|
|
# 根据问题ID获取查询sql
|
|
|
|
|
sql = get_question_by_id(question_id)[0]['sql']
|
|
|
|
|
# 4、生成word报告
|
|
|
|
|
sql = (await db.fetch("SELECT * FROM t_bi_question WHERE id = $1", question_id))[0]['sql']
|
|
|
|
|
|
|
|
|
|
# 生成word报告
|
|
|
|
|
prompt = '''
|
|
|
|
|
请根据以下 JSON 数据,整理出2000字左右的话描述当前数据情况。要求:
|
|
|
|
|
1、以Markdown格式返回,我将直接通过markdown格式生成Word。
|
|
|
|
@ -91,19 +146,15 @@ async def get_docx_stream(
|
|
|
|
|
4、尽量以条目列出,这样更清晰
|
|
|
|
|
5、数据:
|
|
|
|
|
'''
|
|
|
|
|
with PostgreSQLUtil(postgresql_pool.getconn()) as db:
|
|
|
|
|
_data = db.execute_query(sql)
|
|
|
|
|
prompt = prompt + json.dumps(_data, ensure_ascii=False)
|
|
|
|
|
|
|
|
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
client = OpenAI(
|
|
|
|
|
api_key=MODEL_API_KEY,
|
|
|
|
|
base_url=MODEL_API_URL,
|
|
|
|
|
)
|
|
|
|
|
_data = await db.fetch(sql)
|
|
|
|
|
#print(_data)
|
|
|
|
|
# 将 asyncpg.Record 转换为 JSON 格式
|
|
|
|
|
json_data = json.dumps([dict(record) for record in _data], ensure_ascii=False)
|
|
|
|
|
print(json_data) # 打印 JSON 数据
|
|
|
|
|
prompt = prompt + json.dumps(json_data, ensure_ascii=False)
|
|
|
|
|
|
|
|
|
|
# 调用 OpenAI API 生成总结(流式输出)
|
|
|
|
|
response = client.chat.completions.create(
|
|
|
|
|
#model=MODEL_NAME,
|
|
|
|
|
response = await client.chat.completions.create(
|
|
|
|
|
model=QWEN_MODEL_NAME,
|
|
|
|
|
messages=[
|
|
|
|
|
{"role": "system", "content": "你是一个数据分析助手,擅长从 JSON 数据中提取关键信息并生成详细的总结。"},
|
|
|
|
@ -122,7 +173,7 @@ async def get_docx_stream(
|
|
|
|
|
async def generate_stream():
|
|
|
|
|
summary = ""
|
|
|
|
|
try:
|
|
|
|
|
for chunk in response:
|
|
|
|
|
async for chunk in response: # 使用 async for 处理流式响应
|
|
|
|
|
if chunk.choices[0].delta.content: # 检查是否有内容
|
|
|
|
|
chunk_content = chunk.choices[0].delta.content
|
|
|
|
|
# 逐字拆分并返回
|
|
|
|
@ -135,7 +186,7 @@ async def get_docx_stream(
|
|
|
|
|
markdown_to_docx(summary, output_file=filename)
|
|
|
|
|
|
|
|
|
|
# 记录到数据库
|
|
|
|
|
update_question_by_id(question_id, docx_file_name=filename)
|
|
|
|
|
await db.execute("UPDATE t_bi_question SET docx_file_name = $1 WHERE id = $2", filename, question_id)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# 如果发生异常,返回错误信息
|
|
|
|
@ -149,7 +200,7 @@ async def get_docx_stream(
|
|
|
|
|
finally:
|
|
|
|
|
# 确保资源释放
|
|
|
|
|
if "response" in locals():
|
|
|
|
|
response.close()
|
|
|
|
|
await response.aclose()
|
|
|
|
|
|
|
|
|
|
# 使用 StreamingResponse 返回流式结果
|
|
|
|
|
return StreamingResponse(
|
|
|
|
@ -164,60 +215,6 @@ async def get_docx_stream(
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 返回生成的Word文件下载地址
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
|
|
|
|
|
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
|
|
|
|
|
async def get_docx_file(
|
|
|
|
|
question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数
|
|
|
|
|
question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数
|
|
|
|
|
):
|
|
|
|
|
# 根据请求方式获取 question_id
|
|
|
|
|
if question_id is not None: # POST 请求
|
|
|
|
|
question_id = question_id
|
|
|
|
|
elif question_id_get is not None: # GET 请求
|
|
|
|
|
question_id = question_id_get
|
|
|
|
|
else:
|
|
|
|
|
return {"success": False, "message": "缺少问题ID参数"}
|
|
|
|
|
|
|
|
|
|
# 根据问题ID获取查询docx_file_name
|
|
|
|
|
docx_file_name = get_question_by_id(question_id)[0]['docx_file_name']
|
|
|
|
|
|
|
|
|
|
# 返回成功和静态文件的URL
|
|
|
|
|
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 设置问题为系统推荐问题 ,0:取消,1:设置
|
|
|
|
|
@app.post("/questions/set_system_recommend")
|
|
|
|
|
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)):
|
|
|
|
|
set_system_recommend_questions(question_id, flag)
|
|
|
|
|
# 提示保存成功
|
|
|
|
|
return {"success": True, "message": "保存成功"}
|
|
|
|
|
|
|
|
|
|
# 设置问题为用户收藏问题 ,0:取消,1:设置
|
|
|
|
|
@app.post("/questions/set_user_collect")
|
|
|
|
|
def set_user_collect(question_id: str = Form(...), flag: str = Form(...)):
|
|
|
|
|
set_user_collect_questions(question_id, flag)
|
|
|
|
|
# 提示保存成功
|
|
|
|
|
return {"success": True, "message": "保存成功"}
|
|
|
|
|
|
|
|
|
|
# 查询有哪些系统推荐问题
|
|
|
|
|
@app.get("/questions/get_system_recommend")
|
|
|
|
|
def get_system_recommend():
|
|
|
|
|
# 查询所有系统推荐问题
|
|
|
|
|
system_recommend_questions = get_system_recommend_questions()
|
|
|
|
|
# 返回查询结果
|
|
|
|
|
return {"success": True, "data": system_recommend_questions}
|
|
|
|
|
|
|
|
|
|
# 查询有哪些用户收藏问题
|
|
|
|
|
@app.get("/questions/get_user_collect")
|
|
|
|
|
def get_user_collect():
|
|
|
|
|
# 查询所有用户收藏问题
|
|
|
|
|
user_collect_questions = get_user_collect_questions()
|
|
|
|
|
# 返回查询结果
|
|
|
|
|
return {"success": True, "data": user_collect_questions}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 确保直接运行脚本时启动 FastAPI 应用
|
|
|
|
|
# 启动 FastAPI
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
|
|
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|