import os import uuid import time import jieba from fastapi import FastAPI, Form, HTTPException from openai import AsyncOpenAI # 使用异步客户端 from gensim.models import KeyedVectors from contextlib import asynccontextmanager from TtsConfig import * from WxMini.OssUtil import upload_mp3_to_oss from WxMini.TtsUtil import TTS from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Milvus.Config.MulvusConfig import * import asyncio # 引入异步支持 import logging # 增加日志记录 import jieba.analyse # 配置日志 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # 提取用户输入的关键词 def extract_keywords(text, topK=3): """ 提取用户输入的关键词 :param text: 用户输入的文本 :param topK: 返回的关键词数量 :return: 关键词列表 """ keywords = jieba.analyse.extract_tags(text, topK=topK) return keywords # 初始化 Word2Vec 模型 model_path = MS_MODEL_PATH model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT) logger.info(f"模型加载成功,词向量维度: {model.vector_size}") # 初始化 Milvus 连接池 milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) # 初始化集合管理器 collection_name = MS_COLLECTION_NAME collection_manager = MilvusCollectionManager(collection_name) # 将文本转换为嵌入向量 def text_to_embedding(text): words = jieba.lcut(text) # 使用 jieba 分词 logger.info(f"文本: {text}, 分词结果: {words}") embeddings = [model[word] for word in words if word in model] logger.info(f"有效词向量数量: {len(embeddings)}") if embeddings: avg_embedding = sum(embeddings) / len(embeddings) logger.info(f"生成的平均向量: {avg_embedding[:5]}...") # 打印前 5 维 return avg_embedding else: logger.warning("未找到有效词,返回零向量") return [0.0] * model.vector_size # 使用 Lifespan Events 处理应用启动和关闭逻辑 @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时加载集合到内存 collection_manager.load_collection() logger.info(f"集合 '{collection_name}' 已加载到内存。") yield # 应用关闭时释放连接池 milvus_pool.close() logger.info("Milvus 连接池已关闭。") # 初始化 FastAPI 应用 app = FastAPI(lifespan=lifespan) # 初始化异步 OpenAI 客户端 client = AsyncOpenAI( api_key=MODEL_API_KEY, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", ) @app.post("/reply") async def reply(session_id: str = Form(...), prompt: str = Form(...)): """ 接收用户输入的 prompt,调用大模型并返回结果 :param session_id: 用户会话 ID :param prompt: 用户输入的 prompt :return: 大模型的回复 """ try: logger.info(f"收到用户输入: {prompt}") # 从连接池中获取一个连接 connection = milvus_pool.get_connection() # 将用户输入转换为嵌入向量 current_embedding = text_to_embedding(prompt) # 查询与当前对话最相关的历史交互 search_params = { "metric_type": "L2", # 使用 L2 距离度量方式 "params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数 } start_time = time.time() results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行 collection_manager.search, data=current_embedding, # 输入向量 search_params=search_params, # 搜索参数 expr=f"session_id == '{session_id}'", # 按 session_id 过滤 limit=5 # 返回 5 条结果 ) end_time = time.time() # 构建历史交互提示词 history_prompt = "" if results: for hits in results: for hit in hits: try: # 查询非向量字段 record = await asyncio.to_thread(collection_manager.query_by_id, hit.id) if record: logger.info(f"查询到的记录: {record}") # 添加历史交互 history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n" except Exception as e: logger.error(f"查询失败: {e}") logger.info(f"历史交互提示词: {history_prompt}") # 调用大模型,将历史交互作为提示词 response = await client.chat.completions.create( # 使用异步调用 model=MODEL_NAME, messages=[ {"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"}, {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"} # 将历史交互和当前输入一起发送 ], max_tokens=500 ) # 提取生成的回复 if response.choices and response.choices[0].message.content: result = response.choices[0].message.content.strip() logger.info(f"大模型回复: {result}") # 记录用户输入和大模型反馈到向量数据库 timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) entities = [ [session_id], # session_id [prompt[:500]], # user_input,截断到 500 字符 [result[:500]], # model_response,截断到 500 字符 [timestamp], # timestamp [current_embedding] # embedding ] await asyncio.to_thread(collection_manager.insert_data, entities) logger.info("用户输入和大模型反馈已记录到向量数据库。") # 调用 TTS 生成 MP3 uuid_str = str(uuid.uuid4()) tts_file = "audio/" + uuid_str + ".mp3" t = TTS(tts_file) await asyncio.to_thread(t.start, result) # 将 TTS 生成放到线程池中执行 # 文件上传到 OSS await asyncio.to_thread(upload_mp3_to_oss, tts_file, tts_file) # 删除临时文件 try: os.remove(tts_file) logger.info(f"临时文件 {tts_file} 已删除") except Exception as e: logger.error(f"删除临时文件失败: {e}") # 完整的 URL url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file return { "success": True, "url": url, "search_time": end_time - start_time, # 返回查询耗时 "response": result # 返回大模型的回复 } else: raise HTTPException(status_code=500, detail="大模型未返回有效结果") except Exception as e: logger.error(f"调用大模型失败: {str(e)}") raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") finally: # 释放连接 milvus_pool.release_connection(connection) # 运行 FastAPI 应用 if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=5600)