import asyncio import logging import time import uuid from contextlib import asynccontextmanager from fastapi import FastAPI, Form, HTTPException from openai import AsyncOpenAI from WxMini.Milvus.Config.MulvusConfig import * from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory from WxMini.Utils.TtsUtil import TTS from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # 初始化 Milvus 连接池 milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) # 初始化集合管理器 collection_name = MS_COLLECTION_NAME collection_manager = MilvusCollectionManager(collection_name) # 使用 Lifespan Events 处理应用启动和关闭逻辑 @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时加载集合到内存 collection_manager.load_collection() logger.info(f"集合 '{collection_name}' 已加载到内存。") # 初始化 MySQL 连接池 app.state.mysql_pool = await init_mysql_pool() logger.info("MySQL 连接池已初始化。") yield # 应用关闭时释放连接池 milvus_pool.close() app.state.mysql_pool.close() await app.state.mysql_pool.wait_closed() logger.info("Milvus 和 MySQL 连接池已关闭。") # 初始化 FastAPI 应用 app = FastAPI(lifespan=lifespan) # 初始化异步 OpenAI 客户端 client = AsyncOpenAI( api_key=MODEL_API_KEY, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", ) @app.post("/reply") async def reply(session_id: str = Form(...), prompt: str = Form(...)): """ 接收用户输入的 prompt,调用大模型并返回结果 :param session_id: 用户会话 ID :param prompt: 用户输入的 prompt :return: 大模型的回复 """ try: logger.info(f"收到用户输入: {prompt}") # 从连接池中获取一个连接 connection = milvus_pool.get_connection() # 将用户输入转换为嵌入向量 current_embedding = text_to_embedding(prompt) # 查询与当前对话最相关的历史交互 search_params = { "metric_type": "L2", # 使用 L2 距离度量方式 "params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数 } start_time = time.time() results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行 collection_manager.search, data=current_embedding, # 输入向量 search_params=search_params, # 搜索参数 expr=f"session_id == '{session_id}'", # 按 session_id 过滤 limit=5 # 返回 5 条结果 ) end_time = time.time() # 构建历史交互提示词 history_prompt = "" if results: for hits in results: for hit in hits: try: # 查询非向量字段 record = await asyncio.to_thread(collection_manager.query_by_id, hit.id) if record: logger.info(f"查询到的记录: {record}") # 添加历史交互 history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n" except Exception as e: logger.error(f"查询失败: {e}") # 限制历史交互提示词长度 history_prompt = history_prompt[:2000] logger.info(f"历史交互提示词: {history_prompt}") # 调用大模型,将历史交互作为提示词 try: response = await asyncio.wait_for( client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"}, {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"} ], max_tokens=500 ), timeout=60 # 设置超时时间为 60 秒 ) except asyncio.TimeoutError: logger.error("大模型调用超时") raise HTTPException(status_code=500, detail="大模型调用超时") # 提取生成的回复 if response.choices and response.choices[0].message.content: result = response.choices[0].message.content.strip() logger.info(f"大模型回复: {result}") # 记录用户输入和大模型反馈到向量数据库 timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) entities = [ [session_id], # session_id [prompt[:500]], # user_input,截断到 500 字符 [result[:500]], # model_response,截断到 500 字符 [timestamp], # timestamp [current_embedding] # embedding ] if len(prompt) > 500: logger.warning(f"用户输入被截断,原始长度: {len(prompt)}") if len(result) > 500: logger.warning(f"大模型回复被截断,原始长度: {len(result)}") await asyncio.to_thread(collection_manager.insert_data, entities) logger.info("用户输入和大模型反馈已记录到向量数据库。") # 记录聊天数据到 MySQL await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result) logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。") # 调用 TTS 生成 MP3 uuid_str = str(uuid.uuid4()) timestamp = int(time.time()) tts_file = f"audio/{uuid_str}_{timestamp}.mp3" # 生成 TTS 音频数据(不落盘) t = TTS(None) # 传入 None 表示不保存到本地文件 audio_data = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据 # 将音频数据直接上传到 OSS await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data) logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}") # 完整的 URL url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file return { "success": True, "url": url, "search_time": end_time - start_time, # 返回查询耗时 "response": result # 返回大模型的回复 } else: raise HTTPException(status_code=500, detail="大模型未返回有效结果") except Exception as e: logger.error(f"调用大模型失败: {str(e)}") raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") finally: # 释放连接 milvus_pool.release_connection(connection) # 运行 FastAPI 应用 if __name__ == "__main__": import uvicorn uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)