You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

190 lines
7.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import uuid
import time
import jieba
from fastapi import FastAPI, Form, HTTPException
from openai import AsyncOpenAI # 使用异步客户端
from gensim.models import KeyedVectors
from contextlib import asynccontextmanager
from TtsConfig import *
from WxMini.OssUtil import upload_mp3_to_oss
from WxMini.TtsUtil import TTS
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
import asyncio # 引入异步支持
import logging # 增加日志记录
import jieba.analyse
# Configure logging: INFO level, with timestamp, level name and message on each line.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
# Extract keywords from the user's input
def extract_keywords(text, topK=3):
    """
    Extract keywords from the user's input.

    Thin wrapper around ``jieba.analyse.extract_tags``.

    :param text: the text entered by the user
    :param topK: number of keywords to return
    :return: list of keywords
    """
    return jieba.analyse.extract_tags(text, topK=topK)
# Initialize the Word2Vec model (loaded once at import time, in text format,
# reading at most MS_MODEL_LIMIT vectors).
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")
# Initialize the Milvus connection pool
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
# Initialize the collection manager
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
# Convert text into an embedding vector
def text_to_embedding(text):
    """
    Turn ``text`` into a single vector by averaging its word vectors.

    The text is segmented with jieba; tokens that are not present in the
    module-level Word2Vec ``model`` are ignored.

    :param text: the text to embed
    :return: the element-wise mean of the known tokens' vectors, or a list of
             ``model.vector_size`` zeros when no token is known to the model
    """
    tokens = jieba.lcut(text)  # segment with jieba
    logger.info(f"文本: {text}, 分词结果: {tokens}")

    vectors = []
    for token in tokens:
        if token in model:
            vectors.append(model[token])
    logger.info(f"有效词向量数量: {len(vectors)}")

    # Guard clause: nothing usable in the vocabulary -> zero vector.
    if not vectors:
        logger.warning("未找到有效词,返回零向量")
        return [0.0] * model.vector_size

    mean_vector = sum(vectors) / len(vectors)
    logger.info(f"生成的平均向量: {mean_vector[:5]}...")  # log the first 5 dimensions
    return mean_vector
# Use lifespan events to handle application startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler.

    On startup the Milvus collection is loaded into memory; on shutdown the
    Milvus connection pool is closed.

    The pool is closed in a ``finally`` block so that it is released even when
    startup fails or an exception is thrown into this generator at ``yield``,
    instead of only on a clean shutdown.

    :param app: the FastAPI application this lifespan is attached to
    """
    try:
        # Startup: load the collection into memory.
        collection_manager.load_collection()
        logger.info(f"集合 '{collection_name}' 已加载到内存。")
        yield
    finally:
        # Shutdown (or failed startup): release the connection pool.
        milvus_pool.close()
        logger.info("Milvus 连接池已关闭。")
# Initialize the FastAPI application
app = FastAPI(lifespan=lifespan)
# Initialize the asynchronous OpenAI client.
# NOTE(review): base_url points at a DashScope "compatible-mode" endpoint, so
# requests presumably go to that service rather than to OpenAI — confirm.
client = AsyncOpenAI(
    api_key=MODEL_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# Characters that would let a session_id break out of the quoted string
# literal in the Milvus filter expression built in reply().
_UNSAFE_SESSION_ID_CHARS = ("'", '"', "\\")


async def _build_history_prompt(results):
    """Render Milvus search hits as a "user / model" transcript string.

    Every hit is resolved to its stored record with
    ``collection_manager.query_by_id`` (run in a worker thread). A hit whose
    lookup fails is logged and skipped so one bad record does not discard the
    rest of the history.
    """
    turns = []
    for hits in results or []:
        for hit in hits:
            try:
                # Fetch the non-vector fields of the hit.
                record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
                if record:
                    logger.info(f"查询到的记录: {record}")
                    turns.append(f"用户: {record['user_input']}\n大模型: {record['model_response']}\n")
            except Exception as e:
                logger.error(f"查询失败: {e}")
    return "".join(turns)


async def _synthesize_and_upload(text):
    """Generate an MP3 for ``text``, upload it to OSS and return its public URL.

    The local file is removed whether or not synthesis and upload succeed, so
    failed requests do not leave audio files behind.
    """
    tts_file = "audio/" + str(uuid.uuid4()) + ".mp3"
    try:
        # Both steps are blocking, so they run in the thread pool.
        await asyncio.to_thread(TTS(tts_file).start, text)
        await asyncio.to_thread(upload_mp3_to_oss, tts_file, tts_file)
    finally:
        try:
            os.remove(tts_file)
            logger.info(f"临时文件 {tts_file} 已删除")
        except FileNotFoundError:
            # TTS failed before producing a file; nothing to clean up.
            pass
        except Exception as e:
            logger.error(f"删除临时文件失败: {e}")
    return 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file


@app.post("/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
    """
    Receive the user's prompt, call the large language model and return its answer.

    The prompt is embedded and used to search Milvus for up to 5 earlier
    exchanges of the same session; these are sent to the model together with
    the prompt. The new exchange is then stored in Milvus, the answer is
    converted to speech, and the MP3 is uploaded to OSS.

    :param session_id: the user's session ID
    :param prompt: the prompt entered by the user
    :return: dict with ``success``, ``url`` (uploaded MP3), ``search_time``
             (seconds spent in the vector search) and ``response`` (model answer)
    :raises HTTPException: 400 when ``session_id`` contains a quote or
             backslash; 500 when the model returns no content or any other
             step fails
    """
    # session_id is interpolated into the Milvus filter expression below, so a
    # quote or backslash could rewrite the filter and expose other sessions.
    if any(ch in session_id for ch in _UNSAFE_SESSION_ID_CHARS):
        raise HTTPException(status_code=400, detail="session_id 包含非法字符")

    # Bound before the try block so the finally clause can always inspect it,
    # even when acquiring the connection itself fails.
    connection = None
    try:
        logger.info(f"收到用户输入: {prompt}")
        # Take a connection from the pool.
        connection = milvus_pool.get_connection()
        # Convert the user's input into an embedding vector.
        current_embedding = text_to_embedding(prompt)
        # Search for the stored exchanges most relevant to the current prompt.
        search_params = {
            "metric_type": "L2",  # L2 distance metric
            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for IVF_FLAT
        }
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        results = await asyncio.to_thread(  # run the blocking search in the thread pool
            collection_manager.search,
            data=current_embedding,  # query vector
            search_params=search_params,  # search parameters
            expr=f"session_id == '{session_id}'",  # restrict to this session
            limit=5  # return 5 results
        )
        end_time = time.perf_counter()

        # Build the history part of the prompt.
        history_prompt = await _build_history_prompt(results)
        logger.info(f"历史交互提示词: {history_prompt}")

        # Call the model, passing the history along with the current prompt.
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"},
                {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
            ],
            max_tokens=500
        )

        # Guard clause: the model produced nothing usable.
        if not (response.choices and response.choices[0].message.content):
            raise HTTPException(status_code=500, detail="大模型未返回有效结果")

        result = response.choices[0].message.content.strip()
        logger.info(f"大模型回复: {result}")

        # Store the user input and the model's answer in the vector database.
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        entities = [
            [session_id],  # session_id
            [prompt[:500]],  # user_input, truncated to 500 characters
            [result[:500]],  # model_response, truncated to 500 characters
            [timestamp],  # timestamp
            [current_embedding]  # embedding
        ]
        await asyncio.to_thread(collection_manager.insert_data, entities)
        logger.info("用户输入和大模型反馈已记录到向量数据库。")

        # Generate the MP3 via TTS and upload it to OSS.
        url = await _synthesize_and_upload(result)

        return {
            "success": True,
            "url": url,
            "search_time": end_time - start_time,  # vector search duration
            "response": result  # the model's answer
        }
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it below.
        raise
    except Exception as e:
        logger.error(f"调用大模型失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") from e
    finally:
        # Return the connection to the pool, if one was acquired.
        if connection is not None:
            milvus_pool.release_connection(connection)
# Run the FastAPI application when this file is executed as a script:
# serve on all interfaces, port 5600.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=5600)