You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

275 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncio
import logging
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, Form, HTTPException, Query
from openai import AsyncOpenAI
from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
get_risk_chat_log_page
from WxMini.Utils.EmbeddingUtil import text_to_embedding
# Configure logging once for the whole process, then grab this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Shared Milvus connection pool; the MS_* settings come from the star imports above.
milvus_pool = MilvusConnectionPool(
    host=MS_HOST,
    port=MS_PORT,
    max_connections=MS_MAX_CONNECTIONS,
)

# Manager object for the configured Milvus collection.
collection_name = MS_COLLECTION_NAME
collection_manager = MilvusCollectionManager(collection_name)
# Application startup/shutdown is handled through FastAPI's lifespan hook.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare shared resources on startup and release them on shutdown.

    Startup loads the Milvus collection and creates the MySQL pool, which is
    stored on ``app.state.mysql_pool``. Shutdown closes the Milvus pool and
    the MySQL pool.

    The teardown lives in a ``finally`` block so that it also runs when an
    exception (including cancellation) is thrown into the generator at the
    ``yield``, or when creating the MySQL pool fails.
    """
    # Load the collection before serving any request.
    collection_manager.load_collection()
    logger.info(f"集合 '{collection_name}' 已加载到内存。")
    try:
        # Create the MySQL pool and expose it to the request handlers.
        app.state.mysql_pool = await init_mysql_pool()
        logger.info("MySQL 连接池已初始化。")
        yield
    finally:
        milvus_pool.close()
        # The MySQL pool is absent if init_mysql_pool() itself raised.
        mysql_pool = getattr(app.state, "mysql_pool", None)
        if mysql_pool is not None:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Milvus 和 MySQL 连接池已关闭。")
# Strong references to in-flight background tasks. The event loop only keeps
# weak references to tasks, so a task whose result is discarded may be
# garbage-collected before it finishes.
_background_tasks = set()


def _on_background_task_done(task):
    """Drop the finished task from the registry and log its failure, if any."""
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        logger.error(f"后台任务执行失败: {task.exception()}", exc_info=task.exception())


# Risk check run after an exchange: asks the model whether intervention is needed.
async def on_session_end(session_id):
    """Ask the model to screen a session's chat log for mental-health risk.

    The chat log is read from MySQL, embedded into a screening prompt and sent
    to the model. When the answer starts with ``"NO"`` the full answer is
    written via ``update_risk`` in a background task.

    NOTE(review): only one record is requested (``page=1, page_size=1``)
    although the prompt speaks of chat history — confirm this is intended.

    Args:
        session_id: Identifier of the session to analyse.
    """
    # Fetch the chat log for this session.
    result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page=1, page_size=1)

    # Render every row as a question/answer pair.
    history = "".join(
        f"\n问题:{row['user_input']}\n回答:{row['model_response']}"
        for row in result['data']
    )

    # Screening prompt sent to the model.
    prompt = (
        "我将把用户与AI大模型交流的记录发给你帮我分析一下这个用户是否存在心理健康方面的问题"
        "参考1、PHQ-9抑郁症筛查量表和2、Beck自杀意念评量表BSI-CV"
        "如果没有健康问题请回复: OK否则回复NO换行后再输出是什么问题。"
        f"\n\n历史聊天记录:{history}"
    )
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"},
            {"role": "user", "content": prompt}
        ],
        max_tokens=1000
    )

    # Without usable content there is nothing to classify; say so in the log
    # instead of returning silently.
    if not (response.choices and response.choices[0].message.content):
        logger.warning(f"心理健康分析未返回有效结果,session_id={session_id}")
        return

    analysis_result = response.choices[0].message.content.strip()
    if analysis_result.startswith("NO"):
        # Persist the risk flag without blocking the caller; keep a reference
        # so the task cannot disappear mid-execution.
        task = asyncio.create_task(update_risk(app.state.mysql_pool, session_id, analysis_result))
        _background_tasks.add(task)
        task.add_done_callback(_on_background_task_done)
        logger.info(f"已异步更新 session_id={session_id} 的风险状态。")
    else:
        logger.info(f"AI大模型没有发现任何心理健康问题用户会话 {session_id} 没有风险。")
# FastAPI application; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Async OpenAI client pointed at the DashScope compatible-mode endpoint.
client = AsyncOpenAI(
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=MODEL_API_KEY,
)
def _session_filter_expr(session_id):
    """Build the Milvus filter expression restricting a search to one session.

    ``session_id`` arrives from form input and is embedded in a single-quoted
    literal, so values that could terminate or escape that literal are
    rejected rather than interpolated.

    Raises:
        HTTPException: 400 if ``session_id`` contains ``'`` or a backslash.
    """
    if "'" in session_id or "\\" in session_id:
        raise HTTPException(status_code=400, detail="session_id 包含非法字符")
    return f"session_id == '{session_id}'"


async def _build_history_prompt(results):
    """Render search hits as a "user / model" transcript capped at 2000 characters.

    Each hit is resolved through ``collection_manager.query_by_id``. A hit
    that fails to resolve is logged and skipped (best effort).
    """
    pieces = []
    for hits in results or []:
        for hit in hits:
            try:
                # Fetch the non-vector fields of the hit.
                record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
                if record:
                    logger.info(f"查询到的记录: {record}")
                    pieces.append(f"用户: {record['user_input']}\n大模型: {record['model_response']}\n")
            except Exception as e:
                logger.error(f"查询失败: {e}")
    # Keep the prompt bounded.
    return "".join(pieces)[:2000]


async def _synthesize_and_upload(text):
    """Generate speech for ``text``, upload it to OSS and return ``(url, duration)``."""
    object_key = f"audio/{uuid.uuid4()}_{int(time.time())}.mp3"
    # None: per the original author's note, the audio is not saved to a local file.
    tts = TTS(None)
    audio_data, duration = await asyncio.to_thread(tts.generate_audio, text)
    logger.info(f"音频时长: {duration}")
    # Upload the in-memory audio straight to OSS.
    await asyncio.to_thread(upload_mp3_to_oss_from_memory, object_key, audio_data)
    logger.info(f"TTS 文件已直接上传到 OSS: {object_key}")
    return 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + object_key, duration


@app.post("/aichat/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
    """Answer a user prompt with the model, using related history as context.

    Steps: embed the prompt, search Milvus for related exchanges in the same
    session, call the model, store the exchange in Milvus, synthesize the
    answer to MP3 and upload it, store the exchange in MySQL, then run the
    session risk check.

    Args:
        session_id: User session ID.
        prompt: Text entered by the user.

    Returns:
        A dict with ``success``, ``url`` (audio URL), ``search_time`` (seconds
        spent in the vector search), ``duration`` (value reported by the TTS
        step) and ``response`` (model answer).

    Raises:
        HTTPException: 400 for a ``session_id`` with illegal characters; 500
            on model timeout, empty model answer, or any other failure.
    """
    filter_expr = _session_filter_expr(session_id)
    connection = None
    acquired = False  # release only what was really acquired
    try:
        logger.info(f"收到用户输入: {prompt}")
        # Take a connection from the pool.
        connection = milvus_pool.get_connection()
        acquired = True
        # Embed the user input.
        # NOTE(review): called synchronously inside the event loop — confirm it is cheap.
        current_embedding = text_to_embedding(prompt)
        search_params = {
            "metric_type": "L2",  # L2 distance
            "params": {"nprobe": MS_NPROBE}  # nprobe for the IVF_FLAT index
        }
        search_started = time.perf_counter()
        # The search call blocks, so it runs in a worker thread.
        results = await asyncio.to_thread(
            collection_manager.search,
            data=current_embedding,
            search_params=search_params,
            expr=filter_expr,  # restrict to this session
            limit=5
        )
        search_time = time.perf_counter() - search_started

        history_prompt = await _build_history_prompt(results)
        logger.info(f"历史交互提示词: {history_prompt}")

        # Call the model with the history as context.
        try:
            response = await asyncio.wait_for(
                client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system",
                         "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容,回复内容不要超过90字。"},
                        {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"}
                    ],
                    max_tokens=100
                ),
                timeout=60  # seconds
            )
        except asyncio.TimeoutError:
            logger.error("大模型调用超时")
            raise HTTPException(status_code=500, detail="大模型调用超时")

        if not (response.choices and response.choices[0].message.content):
            raise HTTPException(status_code=500, detail="大模型未返回有效结果")

        result = response.choices[0].message.content.strip()
        logger.info(f"大模型回复: {result}")

        # Store the exchange in the vector database; text fields are cut to 500 chars.
        recorded_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        entities = [
            [session_id],  # session_id
            [prompt[:500]],  # user_input
            [result[:500]],  # model_response
            [recorded_at],  # timestamp
            [current_embedding]  # embedding
        ]
        if len(prompt) > 500:
            logger.warning(f"用户输入被截断,原始长度: {len(prompt)}")
        if len(result) > 500:
            logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
        await asyncio.to_thread(collection_manager.insert_data, entities)
        logger.info("用户输入和大模型反馈已记录到向量数据库。")

        # Text-to-speech and upload.
        url, duration = await _synthesize_and_upload(result)

        # Store the exchange in MySQL.
        await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result, url, duration)
        logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")

        # Run the session risk check.
        # NOTE(review): awaited inline, so its model call delays the response
        # and its failure fails the request — confirm this is intended.
        await on_session_end(session_id)

        return {
            "success": True,
            "url": url,
            "search_time": search_time,  # vector search time in seconds
            "duration": duration,  # value reported by the TTS step
            "response": result  # model answer
        }
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged.
        raise
    except Exception as e:
        logger.exception(f"调用大模型失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") from e
    finally:
        # Return the connection to the pool.
        if acquired:
            milvus_pool.release_connection(connection)
# Paged chat-log lookup for a single session.
@app.get("/aichat/get_chat_log")
async def get_chat_log(
    session_id: str,
    page: int = Query(default=1, ge=1, description="当前页码"),
    page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
):
    """
    Return one page of chat records for a session.

    The lookup, and any ordering, is done by ``get_chat_log_by_session``; the
    original note describes the order as id-descending.

    :param session_id: user session ID
    :param page: page number, starting at 1
    :param page_size: records per page, between 1 and 100
    :return: whatever ``get_chat_log_by_session`` returns for that page
    """
    try:
        return await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size)
    except Exception as exc:
        failure = f"查询聊天记录失败: {str(exc)}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
@app.get("/aichat/get_risk_page")
async def get_risk_page(
    # ge=0 (was ge=1): the description lists 0 as a valid flag, which the old
    # bound rejected during request validation.
    risk_flag: int = Query(default=1, ge=0, description="1有风险0无风险2:有风险但已处理"),
    page: int = Query(default=1, ge=1, description="当前页码"),
    page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
):
    """
    Return one page of chat records filtered by risk flag.

    The lookup, and any ordering, is done by ``get_risk_chat_log_page``; the
    original note describes the order as id-descending.

    :param risk_flag: 1 = at risk (default), 0 = no risk, 2 = at risk but already handled
    :param page: page number, starting at 1
    :param page_size: records per page, between 1 and 100
    :return: whatever ``get_risk_chat_log_page`` returns for that page
    :raises HTTPException: 500 if the lookup fails
    """
    try:
        result = await get_risk_chat_log_page(app.state.mysql_pool, risk_flag, page, page_size)
        return result
    except Exception as e:
        logger.error(f"查询有风险的聊天记录失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"查询有风险的聊天记录失败: {str(e)}")
# Script entry point: serve the application with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn

    # "Start:app" is the import string handed to uvicorn for the `app` object.
    server_options = {"host": "0.0.0.0", "port": 5600, "workers": 1}
    uvicorn.run("Start:app", **server_options)