|
|
|
@ -26,6 +26,7 @@ milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=M
|
|
|
|
|
collection_name = MS_COLLECTION_NAME
|
|
|
|
|
collection_manager = MilvusCollectionManager(collection_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 Lifespan Events 处理应用启动和关闭逻辑
|
|
|
|
|
@asynccontextmanager
|
|
|
|
|
async def lifespan(app: FastAPI):
|
|
|
|
@ -42,6 +43,7 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
await app.state.mysql_pool.wait_closed()
|
|
|
|
|
logger.info("Milvus 和 MySQL 连接池已关闭。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI 应用
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
@ -51,6 +53,7 @@ client = AsyncOpenAI(
|
|
|
|
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/reply")
|
|
|
|
|
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
"""
|
|
|
|
@ -140,8 +143,6 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
await asyncio.to_thread(collection_manager.insert_data, entities)
|
|
|
|
|
logger.info("用户输入和大模型反馈已记录到向量数据库。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 调用 TTS 生成 MP3
|
|
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
|
|
timestamp = int(time.time())
|
|
|
|
@ -149,7 +150,8 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
|
|
|
|
|
# 生成 TTS 音频数据(不落盘)
|
|
|
|
|
t = TTS(None) # 传入 None 表示不保存到本地文件
|
|
|
|
|
audio_data, duration = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
|
|
|
|
|
audio_data, duration = await asyncio.to_thread(t.generate_audio,
|
|
|
|
|
result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
|
|
|
|
|
print(f"音频时长: {duration} 秒")
|
|
|
|
|
|
|
|
|
|
# 将音频数据直接上传到 OSS
|
|
|
|
@ -160,13 +162,14 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file
|
|
|
|
|
|
|
|
|
|
# 记录聊天数据到 MySQL
|
|
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result,url,duration)
|
|
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result, url, duration)
|
|
|
|
|
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"success": True,
|
|
|
|
|
"url": url,
|
|
|
|
|
"search_time": end_time - start_time, # 返回查询耗时
|
|
|
|
|
"duration": duration, # 返回大模型的回复时长
|
|
|
|
|
"response": result # 返回大模型的回复
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
@ -182,9 +185,9 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
# 获取聊天记录
|
|
|
|
|
@app.get("/get_chat_log")
|
|
|
|
|
async def get_chat_log(
|
|
|
|
|
session_id: str,
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
|
|
|
|
|
session_id: str,
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
根据 session_id 查询聊天记录,并按 id 降序分页
|
|
|
|
@ -194,14 +197,15 @@ async def get_chat_log(
|
|
|
|
|
:return: 分页数据
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
result = await get_chat_log_by_session(app.state.mysql_pool,session_id, page, page_size)
|
|
|
|
|
result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size)
|
|
|
|
|
return result
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"查询聊天记录失败: {str(e)}")
|
|
|
|
|
raise HTTPException(status_code=500, detail=f"查询聊天记录失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 运行 FastAPI 应用
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
|
|
|
|
uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)
|
|
|
|
|
uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)
|
|
|
|
|