diff --git a/AI/Text2Sql/Util/PostgreSQLUtil.py b/AI/Text2Sql/Util/PostgreSQLUtil.py new file mode 100644 index 00000000..e896e6c9 --- /dev/null +++ b/AI/Text2Sql/Util/PostgreSQLUtil.py @@ -0,0 +1,40 @@ +import json +from datetime import date, datetime + +from asyncpg.pool import Pool +from sqlalchemy.dialects.postgresql import asyncpg + +from Config import * + + +class PostgreSQLUtil: + def __init__(self, pool: Pool): + self.pool = pool + + async def execute_query(self, sql, params=None): + async with self.pool.acquire() as connection: + result = await connection.fetch(sql, params) + return result + + async def query_to_json(self, sql, params=None): + data = await self.execute_query(sql, params) + return json.dumps(data, default=self.json_serializer) + + @staticmethod + def json_serializer(obj): + """处理JSON无法序列化的类型""" + if isinstance(obj, (date, datetime)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + +async def create_pool(): + return await asyncpg.create_pool( + host=PG_HOST, + port=PG_PORT, + database=PG_DATABASE, + user=PG_USER, + password=PG_PASSWORD, + min_size=1, + max_size=10 + ) + diff --git a/AI/Text2Sql/__pycache__/app.cpython-310.pyc b/AI/Text2Sql/__pycache__/app.cpython-310.pyc index dd63c21a..d95a34ee 100644 Binary files a/AI/Text2Sql/__pycache__/app.cpython-310.pyc and b/AI/Text2Sql/__pycache__/app.cpython-310.pyc differ diff --git a/AI/Text2Sql/app.py b/AI/Text2Sql/app.py index fba6225f..e4928ab5 100644 --- a/AI/Text2Sql/app.py +++ b/AI/Text2Sql/app.py @@ -1,14 +1,12 @@ import asyncio import json import uuid -from datetime import date, datetime -from asyncpg.pool import Pool -from fastapi import FastAPI, Form, Query -from fastapi.staticfiles import StaticFiles from contextlib import asynccontextmanager -from fastapi import FastAPI, Depends -import asyncpg + import uvicorn +from fastapi import FastAPI, Depends +from fastapi import Form, Query +from fastapi.staticfiles import StaticFiles from openai import AsyncOpenAI from starlette.responses import StreamingResponse @@ -47,37 +45,6 @@ async def get_db(): async with app.state.pool.acquire() as connection: yield connection -class PostgreSQLUtil: - def __init__(self, pool: Pool): - self.pool = pool - - async def execute_query(self, sql, params=None): - async with self.pool.acquire() as connection: - result = await connection.fetch(sql, params) - return result - - async def query_to_json(self, sql, params=None): - data = await self.execute_query(sql, params) - return json.dumps(data, default=self.json_serializer) - - @staticmethod - def json_serializer(obj): - """处理JSON无法序列化的类型""" - if isinstance(obj, (date, datetime)): - return obj.isoformat() - raise TypeError(f"Type {type(obj)} not serializable") - -async def create_pool(): - return await asyncpg.create_pool( - host=PG_HOST, - port=PG_PORT, - database=PG_DATABASE, - user=PG_USER, - password=PG_PASSWORD, - min_size=1, - max_size=10 - ) - @app.post("/questions/get_excel") async def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: asyncpg.Connection = Depends(get_db)): # 只接受guid号 @@ -218,6 +185,58 @@ async def get_docx_stream( } ) + +# 返回生成的Word文件下载地址 +# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443 +@app.api_route("/questions/get_docx_file", methods=["POST", "GET"]) +async def get_docx_file( + question_id: str = Form(None, description="问题ID(POST请求)"), # POST 请求参数 + question_id_get: str = Query(None, description="问题ID(GET请求)"), # GET 请求参数 +): + # 根据请求方式获取 question_id + if question_id is not None: # POST 请求 + question_id = question_id + elif question_id_get is not None: # GET 请求 + question_id = question_id_get + else: + return {"success": False, "message": "缺少问题ID参数"} + + # 根据问题ID获取查询docx_file_name + docx_file_name = get_question_by_id(question_id)[0]['docx_file_name'] + + # 返回成功和静态文件的URL + return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"} + + +# 设置问题为系统推荐问题 ,0:取消,1:设置 +@app.post("/questions/set_system_recommend") +def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)): + set_system_recommend_questions(question_id, flag) + # 提示保存成功 + return {"success": True, "message": "保存成功"} + +# 设置问题为用户收藏问题 ,0:取消,1:设置 +@app.post("/questions/set_user_collect") +def set_user_collect(question_id: str = Form(...), flag: str = Form(...)): + set_user_collect_questions(question_id, flag) + # 提示保存成功 + return {"success": True, "message": "保存成功"} + +# 查询有哪些系统推荐问题 +@app.get("/questions/get_system_recommend") +def get_system_recommend(): + # 查询所有系统推荐问题 + system_recommend_questions = get_system_recommend_questions() + # 返回查询结果 + return {"success": True, "data": system_recommend_questions} + +# 查询有哪些用户收藏问题 +@app.get("/questions/get_user_collect") +def get_user_collect(): + # 查询所有用户收藏问题 + user_collect_questions = get_user_collect_questions() + # 返回查询结果 + return {"success": True, "data": user_collect_questions} # 启动 FastAPI if __name__ == "__main__": uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file