|
|
import asyncio
|
|
|
import logging
|
|
|
import time
|
|
|
import uuid
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
from fastapi import FastAPI, Form, HTTPException
|
|
|
from openai import AsyncOpenAI
|
|
|
|
|
|
from WxMini.Milvus.Config.MulvusConfig import *
|
|
|
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
|
|
|
from WxMini.Milvus.Utils.MilvusConnectionPool import *
|
|
|
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
|
|
|
from WxMini.Utils.TtsUtil import TTS
|
|
|
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql
|
|
|
from WxMini.Utils.EmbeddingUtil import text_to_embedding
|
|
|
|
|
|
# Configure logging (timestamp - level - message)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Initialize the Milvus connection pool (host/port/size come from MulvusConfig)
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)

# Initialize the collection manager for the configured collection
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
|
|
|
|
|
|
# Handle application startup/shutdown via FastAPI lifespan events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager.

    Startup: loads the Milvus collection into memory and creates the MySQL
    connection pool (stored on ``app.state.mysql_pool`` for request handlers).
    Shutdown (after ``yield``): closes the Milvus pool, then closes the MySQL
    pool and waits for it to drain.

    :param app: the FastAPI application instance
    """
    # Load the collection into memory when the app starts
    collection_manager.load_collection()
    logger.info(f"集合 '{collection_name}' 已加载到内存。")
    # Initialize the MySQL connection pool
    app.state.mysql_pool = await init_mysql_pool()
    logger.info("MySQL 连接池已初始化。")
    yield
    # Release both pools when the app shuts down
    milvus_pool.close()
    app.state.mysql_pool.close()
    await app.state.mysql_pool.wait_closed()
    logger.info("Milvus 和 MySQL 连接池已关闭。")
|
|
|
|
|
|
# Create the FastAPI application wired to the lifespan handler above
app = FastAPI(lifespan=lifespan)

# Async OpenAI-compatible client pointed at DashScope's compatible-mode endpoint
client = AsyncOpenAI(
    api_key=MODEL_API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
|
|
|
|
|
|
@app.post("/reply")
|
|
|
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
"""
|
|
|
接收用户输入的 prompt,调用大模型并返回结果
|
|
|
:param session_id: 用户会话 ID
|
|
|
:param prompt: 用户输入的 prompt
|
|
|
:return: 大模型的回复
|
|
|
"""
|
|
|
try:
|
|
|
logger.info(f"收到用户输入: {prompt}")
|
|
|
# 从连接池中获取一个连接
|
|
|
connection = milvus_pool.get_connection()
|
|
|
|
|
|
# 将用户输入转换为嵌入向量
|
|
|
current_embedding = text_to_embedding(prompt)
|
|
|
|
|
|
# 查询与当前对话最相关的历史交互
|
|
|
search_params = {
|
|
|
"metric_type": "L2", # 使用 L2 距离度量方式
|
|
|
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
|
|
|
}
|
|
|
start_time = time.time()
|
|
|
results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行
|
|
|
collection_manager.search,
|
|
|
data=current_embedding, # 输入向量
|
|
|
search_params=search_params, # 搜索参数
|
|
|
expr=f"session_id == '{session_id}'", # 按 session_id 过滤
|
|
|
limit=5 # 返回 5 条结果
|
|
|
)
|
|
|
end_time = time.time()
|
|
|
|
|
|
# 构建历史交互提示词
|
|
|
history_prompt = ""
|
|
|
if results:
|
|
|
for hits in results:
|
|
|
for hit in hits:
|
|
|
try:
|
|
|
# 查询非向量字段
|
|
|
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
|
|
|
if record:
|
|
|
logger.info(f"查询到的记录: {record}")
|
|
|
# 添加历史交互
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
except Exception as e:
|
|
|
logger.error(f"查询失败: {e}")
|
|
|
|
|
|
# 限制历史交互提示词长度
|
|
|
history_prompt = history_prompt[:2000]
|
|
|
logger.info(f"历史交互提示词: {history_prompt}")
|
|
|
|
|
|
# 调用大模型,将历史交互作为提示词
|
|
|
try:
|
|
|
response = await asyncio.wait_for(
|
|
|
client.chat.completions.create(
|
|
|
model=MODEL_NAME,
|
|
|
messages=[
|
|
|
{"role": "system",
|
|
|
"content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容。"},
|
|
|
{"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
|
|
|
],
|
|
|
max_tokens=500
|
|
|
),
|
|
|
timeout=60 # 设置超时时间为 60 秒
|
|
|
)
|
|
|
except asyncio.TimeoutError:
|
|
|
logger.error("大模型调用超时")
|
|
|
raise HTTPException(status_code=500, detail="大模型调用超时")
|
|
|
|
|
|
# 提取生成的回复
|
|
|
if response.choices and response.choices[0].message.content:
|
|
|
result = response.choices[0].message.content.strip()
|
|
|
logger.info(f"大模型回复: {result}")
|
|
|
|
|
|
# 记录用户输入和大模型反馈到向量数据库
|
|
|
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
|
|
entities = [
|
|
|
[session_id], # session_id
|
|
|
[prompt[:500]], # user_input,截断到 500 字符
|
|
|
[result[:500]], # model_response,截断到 500 字符
|
|
|
[timestamp], # timestamp
|
|
|
[current_embedding] # embedding
|
|
|
]
|
|
|
if len(prompt) > 500:
|
|
|
logger.warning(f"用户输入被截断,原始长度: {len(prompt)}")
|
|
|
if len(result) > 500:
|
|
|
logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
|
|
|
await asyncio.to_thread(collection_manager.insert_data, entities)
|
|
|
logger.info("用户输入和大模型反馈已记录到向量数据库。")
|
|
|
|
|
|
# 记录聊天数据到 MySQL
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result)
|
|
|
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
|
|
|
|
|
|
# 调用 TTS 生成 MP3
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
timestamp = int(time.time())
|
|
|
tts_file = f"audio/{uuid_str}_{timestamp}.mp3"
|
|
|
|
|
|
# 生成 TTS 音频数据(不落盘)
|
|
|
t = TTS(None) # 传入 None 表示不保存到本地文件
|
|
|
audio_data = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
|
|
|
|
|
|
# 将音频数据直接上传到 OSS
|
|
|
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
|
|
|
logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}")
|
|
|
|
|
|
# 完整的 URL
|
|
|
url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file
|
|
|
return {
|
|
|
"success": True,
|
|
|
"url": url,
|
|
|
"search_time": end_time - start_time, # 返回查询耗时
|
|
|
"response": result # 返回大模型的回复
|
|
|
}
|
|
|
else:
|
|
|
raise HTTPException(status_code=500, detail="大模型未返回有效结果")
|
|
|
except Exception as e:
|
|
|
logger.error(f"调用大模型失败: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}")
|
|
|
finally:
|
|
|
# 释放连接
|
|
|
milvus_pool.release_connection(connection)
|
|
|
|
|
|
# Run the FastAPI application with uvicorn when executed directly
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): "Start:app" assumes this module is named Start.py — confirm
    # the filename matches, otherwise uvicorn cannot import the app.
    uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)