main
HuangHai 4 months ago
parent 8c804a2e8f
commit 9265818e27

@ -0,0 +1,40 @@
import json
from datetime import date, datetime
from asyncpg.pool import Pool
from sqlalchemy.dialects.postgresql import asyncpg
from Config import *
class PostgreSQLUtil:
def __init__(self, pool: Pool):
self.pool = pool
async def execute_query(self, sql, params=None):
async with self.pool.acquire() as connection:
result = await connection.fetch(sql, params)
return result
async def query_to_json(self, sql, params=None):
data = await self.execute_query(sql, params)
return json.dumps(data, default=self.json_serializer)
@staticmethod
def json_serializer(obj):
"""处理JSON无法序列化的类型"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
async def create_pool():
return await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=1,
max_size=10
)

@ -1,14 +1,12 @@
import asyncio
import json
import uuid
from datetime import date, datetime
from asyncpg.pool import Pool
from fastapi import FastAPI, Form, Query
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
import asyncpg
import uvicorn
from fastapi import FastAPI, Depends
from fastapi import Form, Query
from fastapi.staticfiles import StaticFiles
from openai import AsyncOpenAI
from starlette.responses import StreamingResponse
@ -47,37 +45,6 @@ async def get_db():
async with app.state.pool.acquire() as connection:
yield connection
class PostgreSQLUtil:
def __init__(self, pool: Pool):
self.pool = pool
async def execute_query(self, sql, params=None):
async with self.pool.acquire() as connection:
result = await connection.fetch(sql, params)
return result
async def query_to_json(self, sql, params=None):
data = await self.execute_query(sql, params)
return json.dumps(data, default=self.json_serializer)
@staticmethod
def json_serializer(obj):
"""处理JSON无法序列化的类型"""
if isinstance(obj, (date, datetime)):
return obj.isoformat()
raise TypeError(f"Type {type(obj)} not serializable")
async def create_pool():
return await asyncpg.create_pool(
host=PG_HOST,
port=PG_PORT,
database=PG_DATABASE,
user=PG_USER,
password=PG_PASSWORD,
min_size=1,
max_size=10
)
@app.post("/questions/get_excel")
async def get_excel(question_id: str = Form(...), question_str: str = Form(...), db: asyncpg.Connection = Depends(get_db)):
# 只接受guid号
@ -218,6 +185,58 @@ async def get_docx_stream(
}
)
# 返回生成的Word文件下载地址
# http://10.10.21.20:8000/questions/get_docx_file?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
@app.api_route("/questions/get_docx_file", methods=["POST", "GET"])
async def get_docx_file(
question_id: str = Form(None, description="问题IDPOST请求"), # POST 请求参数
question_id_get: str = Query(None, description="问题IDGET请求"), # GET 请求参数
):
# 根据请求方式获取 question_id
if question_id is not None: # POST 请求
question_id = question_id
elif question_id_get is not None: # GET 请求
question_id = question_id_get
else:
return {"success": False, "message": "缺少问题ID参数"}
# 根据问题ID获取查询docx_file_name
docx_file_name = get_question_by_id(question_id)[0]['docx_file_name']
# 返回成功和静态文件的URL
return {"success": True, "message": "Word文件生成成功", "download_url": f"{docx_file_name}"}
# 设置问题为系统推荐问题 ,0:取消1设置
@app.post("/questions/set_system_recommend")
def set_system_recommend(question_id: str = Form(...), flag: str = Form(...)):
set_system_recommend_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 设置问题为用户收藏问题 ,0:取消1设置
@app.post("/questions/set_user_collect")
def set_user_collect(question_id: str = Form(...), flag: str = Form(...)):
set_user_collect_questions(question_id, flag)
# 提示保存成功
return {"success": True, "message": "保存成功"}
# 查询有哪些系统推荐问题
@app.get("/questions/get_system_recommend")
def get_system_recommend():
# 查询所有系统推荐问题
system_recommend_questions = get_system_recommend_questions()
# 返回查询结果
return {"success": True, "data": system_recommend_questions}
# 查询有哪些用户收藏问题
@app.get("/questions/get_user_collect")
def get_user_collect():
# 查询所有用户收藏问题
user_collect_questions = get_user_collect_questions()
# 返回查询结果
return {"success": True, "data": user_collect_questions}
# 启动 FastAPI
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
Loading…
Cancel
Save