diff --git a/AI/WxMini/BigStart.py b/AI/WxMini/BigStart.py
new file mode 100644
index 00000000..2246e935
--- /dev/null
+++ b/AI/WxMini/BigStart.py
@@ -0,0 +1,190 @@
+import os
+import uuid
+import time
+import jieba
+from fastapi import FastAPI, Form, HTTPException
+from openai import AsyncOpenAI  # async client
+from gensim.models import KeyedVectors
+from contextlib import asynccontextmanager
+from TtsConfig import *
+from WxMini.OssUtil import upload_mp3_to_oss
+from WxMini.TtsUtil import TTS
+from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
+from WxMini.Milvus.Utils.MilvusConnectionPool import *
+from WxMini.Milvus.Config.MulvusConfig import *
+import asyncio  # async support
+import logging  # logging
+
+import jieba.analyse
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+# Extract keywords from the user input
+def extract_keywords(text, topK=3):
+    """
+    Extract keywords from the user input
+    :param text: the user's input text
+    :param topK: number of keywords to return
+    :return: list of keywords
+    """
+    keywords = jieba.analyse.extract_tags(text, topK=topK)
+    return keywords
+
+# Initialize the Word2Vec model
+model_path = MS_MODEL_PATH
+model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
+logger.info(f"Model loaded, vector dimension: {model.vector_size}")
+
+# Initialize the Milvus connection pool
+milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
+
+# Initialize the collection manager
+collection_name = MS_COLLECTION_NAME
+collection_manager = MilvusCollectionManager(collection_name)
+
+# Convert text into an embedding vector
+def text_to_embedding(text):
+    words = jieba.lcut(text)  # tokenize with jieba
+    logger.info(f"Text: {text}, tokens: {words}")
+    embeddings = [model[word] for word in words if word in model]
+    logger.info(f"Number of valid word vectors: {len(embeddings)}")
+    if embeddings:
+        avg_embedding = sum(embeddings) / len(embeddings)
+        logger.info(f"Averaged vector: {avg_embedding[:5]}...")  # print the first 5 dimensions
+        return avg_embedding
+    else:
+        logger.warning("No valid words found, returning a zero vector")
+        return [0.0] * model.vector_size
+
+# Use lifespan events to handle application startup and shutdown
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Load the collection into memory on startup
+    collection_manager.load_collection()
+    logger.info(f"Collection '{collection_name}' loaded into memory.")
+    yield
+    # Release the connection pool on shutdown
+    milvus_pool.close()
+    logger.info("Milvus connection pool closed.")
+
+# Initialize the FastAPI application
+app = FastAPI(lifespan=lifespan)
+
+# Initialize the async OpenAI client
+client = AsyncOpenAI(
+    api_key=MODEL_API_KEY,
+    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+)
+
+@app.post("/reply")
+async def reply(session_id: str = Form(...), prompt: str = Form(...)):
+    """
+    Receive the user's prompt, call the LLM, and return the result
+    :param session_id: user session ID
+    :param prompt: the user's prompt
+    :return: the LLM's reply
+    """
+    # Get a connection from the pool before the try block, so the finally clause can always release it
+    connection = milvus_pool.get_connection()
+    try:
+        logger.info(f"Received user input: {prompt}")
+
+        # Convert the user input into an embedding vector
+        current_embedding = text_to_embedding(prompt)
+
+        # Search for the historical interactions most relevant to the current conversation
+        search_params = {
+            "metric_type": "L2",  # use L2 distance
+            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for the IVF_FLAT index
+        }
+        start_time = time.time()
+        results = await asyncio.to_thread(  # run the blocking search in a thread pool
+            collection_manager.search,
+            data=current_embedding,  # query vector
+            search_params=search_params,  # search parameters
+            expr=f"session_id == '{session_id}'",  # filter by session_id
+            limit=5  # return 5 results
+        )
+        end_time = time.time()
+
+        # Build the history prompt
+        history_prompt = ""
+        if results:
+            for hits in results:
+                for hit in hits:
+                    try:
+                        # Query the non-vector fields
+                        record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
+                        if record:
+                            logger.info(f"Record found: {record}")
+                            # Append the historical interaction
+                            history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
f"用户: {record['user_input']}\n大模型: {record['model_response']}\n" + except Exception as e: + logger.error(f"查询失败: {e}") + + logger.info(f"历史交互提示词: {history_prompt}") + + # 调用大模型,将历史交互作为提示词 + response = await client.chat.completions.create( # 使用异步调用 + model=MODEL_NAME, + messages=[ + {"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"}, + {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"} # 将历史交互和当前输入一起发送 + ], + max_tokens=500 + ) + + # 提取生成的回复 + if response.choices and response.choices[0].message.content: + result = response.choices[0].message.content.strip() + logger.info(f"大模型回复: {result}") + + # 记录用户输入和大模型反馈到向量数据库 + timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + entities = [ + [session_id], # session_id + [prompt[:500]], # user_input,截断到 500 字符 + [result[:500]], # model_response,截断到 500 字符 + [timestamp], # timestamp + [current_embedding] # embedding + ] + await asyncio.to_thread(collection_manager.insert_data, entities) + logger.info("用户输入和大模型反馈已记录到向量数据库。") + + # 调用 TTS 生成 MP3 + uuid_str = str(uuid.uuid4()) + tts_file = "audio/" + uuid_str + ".mp3" + t = TTS(tts_file) + await asyncio.to_thread(t.start, result) # 将 TTS 生成放到线程池中执行 + # 文件上传到 OSS + await asyncio.to_thread(upload_mp3_to_oss, tts_file, tts_file) + # 删除临时文件 + try: + os.remove(tts_file) + logger.info(f"临时文件 {tts_file} 已删除") + except Exception as e: + logger.error(f"删除临时文件失败: {e}") + # 完整的 URL + url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file + return { + "success": True, + "url": url, + "search_time": end_time - start_time, # 返回查询耗时 + "response": result # 返回大模型的回复 + } + else: + raise HTTPException(status_code=500, detail="大模型未返回有效结果") + except Exception as e: + logger.error(f"调用大模型失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") + finally: + # 释放连接 + milvus_pool.release_connection(connection) + +# 运行 FastAPI 应用 +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=5600) \ No newline at end of file diff --git a/AI/WxMini/Milvus/T1_create_collection.py b/AI/WxMini/Milvus/Test/T1_create_collection.py similarity index 100% rename from AI/WxMini/Milvus/T1_create_collection.py rename to AI/WxMini/Milvus/Test/T1_create_collection.py diff --git a/AI/WxMini/Milvus/T2_create_index.py b/AI/WxMini/Milvus/Test/T2_create_index.py similarity index 100% rename from AI/WxMini/Milvus/T2_create_index.py rename to AI/WxMini/Milvus/Test/T2_create_index.py diff --git a/AI/WxMini/Milvus/T3_insert_data.py b/AI/WxMini/Milvus/Test/T3_insert_data.py similarity index 100% rename from AI/WxMini/Milvus/T3_insert_data.py rename to AI/WxMini/Milvus/Test/T3_insert_data.py diff --git a/AI/WxMini/Milvus/T4_select_all_data.py b/AI/WxMini/Milvus/Test/T4_select_all_data.py similarity index 100% rename from AI/WxMini/Milvus/T4_select_all_data.py rename to AI/WxMini/Milvus/Test/T4_select_all_data.py diff --git a/AI/WxMini/Milvus/T5_search_near_data.py b/AI/WxMini/Milvus/Test/T5_search_near_data.py similarity index 100% rename from AI/WxMini/Milvus/T5_search_near_data.py rename to AI/WxMini/Milvus/Test/T5_search_near_data.py diff --git a/AI/WxMini/Milvus/Test/__init__.py b/AI/WxMini/Milvus/Test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/AI/WxMini/Milvus/X1_create_collection.py b/AI/WxMini/Milvus/X1_create_collection.py index 48de3b46..5de6417d 100644 --- a/AI/WxMini/Milvus/X1_create_collection.py +++ b/AI/WxMini/Milvus/X1_create_collection.py @@ -23,8 
@@ -23,8 +23,8 @@ if utility.has_collection(collection_name):
 fields = [
     FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),  # primary key, auto-generated ID
     FieldSchema(name="session_id", dtype=DataType.VARCHAR, max_length=64),  # session ID
-    FieldSchema(name="user_input", dtype=DataType.VARCHAR, max_length=500),  # user question
-    FieldSchema(name="model_response", dtype=DataType.VARCHAR, max_length=500),  # LLM response
+    FieldSchema(name="user_input", dtype=DataType.VARCHAR, max_length=2048),  # user question
+    FieldSchema(name="model_response", dtype=DataType.VARCHAR, max_length=2048),  # LLM response
     FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=32),  # timestamp
     FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=MS_DIMENSION)  # vector field, dimension 200
 ]
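
For reference, a minimal client sketch for the new /reply endpoint introduced in BigStart.py (not part of the diff). It assumes the service is running locally on port 5600, as started by the uvicorn.run call above; the session_id and prompt values are placeholders, and any HTTP client that can send form data would work. The endpoint responds with the OSS URL of the generated MP3, the Milvus search time, and the model's text reply.

import requests  # assumed to be installed; used only for this illustration

resp = requests.post(
    "http://127.0.0.1:5600/reply",
    data={"session_id": "demo-session-001", "prompt": "What did we talk about last time?"},  # form fields expected by /reply
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
print(body["response"])     # text reply from the LLM
print(body["url"])          # OSS URL of the generated MP3 file
print(body["search_time"])  # Milvus search time in seconds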