import asyncio import logging import time import uuid from contextlib import asynccontextmanager from fastapi import FastAPI, Form, HTTPException, Query from openai import AsyncOpenAI from WxMini.Milvus.Config.MulvusConfig import * from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory from WxMini.Utils.TtsUtil import TTS from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \ get_risk_chat_log_page, get_last_chat_log_id from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) # 初始化 Milvus 连接池 milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS) # 初始化集合管理器 collection_name = MS_COLLECTION_NAME collection_manager = MilvusCollectionManager(collection_name) # 使用 Lifespan Events 处理应用启动和关闭逻辑 @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时加载集合到内存 collection_manager.load_collection() logger.info(f"集合 '{collection_name}' 已加载到内存。") # 初始化 MySQL 连接池 app.state.mysql_pool = await init_mysql_pool() logger.info("MySQL 连接池已初始化。") yield # 应用关闭时释放连接池 milvus_pool.close() app.state.mysql_pool.close() await app.state.mysql_pool.wait_closed() logger.info("Milvus 和 MySQL 连接池已关闭。") # 会话结束后,调用检查方法,判断是不是有需要介入的问题出现 async def on_session_end(session_id): # 获取最后一条聊天记录 last_id = await get_last_chat_log_id(app.state.mysql_pool, session_id) if last_id: # 查询最后一条记录的详细信息 async with app.state.mysql_pool.acquire() as conn: async with conn.cursor() as cur: await cur.execute( "SELECT user_input, model_response FROM t_chat_log WHERE id = %s", (last_id,) ) last_record = await cur.fetchone() if last_record: # 拼接历史聊天记录 history = f"问题:{last_record['user_input']}\n回答:{last_record['model_response']}" else: history = "无聊天记录" else: history = "无聊天记录" # 将历史聊天记录发给大模型,让它帮我分析一下 prompt = ( "我将把用户与AI大模型交流的记录发给你,帮我分析一下这个用户是否存在心理健康方面的问题," "参考:1、PHQ-9抑郁症筛查量表和2、Beck自杀意念评量表(BSI-CV)。" "如果没有健康问题请回复: OK;否则回复:NO,换行后再输出是什么问题。" f"\n\n历史聊天记录:{history}" ) response = await client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"}, {"role": "user", "content": prompt} ], max_tokens=1000 ) # 处理分析结果 if response.choices and response.choices[0].message.content: analysis_result = response.choices[0].message.content.strip() if analysis_result.startswith("NO"): # 异步执行 update_risk asyncio.create_task(update_risk(app.state.mysql_pool, session_id, analysis_result)) logger.info(f"已异步更新 session_id={session_id} 的风险状态。") else: logger.info(f"AI大模型没有发现任何心理健康问题,用户会话 {session_id} 没有风险。") # 初始化 FastAPI 应用 app = FastAPI(lifespan=lifespan) # 初始化异步 OpenAI 客户端 client = AsyncOpenAI( api_key=MODEL_API_KEY, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", ) @app.post("/aichat/reply") async def reply(session_id: str = Form(...), prompt: str = Form(...)): """ 接收用户输入的 prompt,调用大模型并返回结果 :param session_id: 用户会话 ID :param prompt: 用户输入的 prompt :return: 大模型的回复 """ try: logger.info(f"收到用户输入: {prompt}") # 从连接池中获取一个连接 connection = milvus_pool.get_connection() # 将用户输入转换为嵌入向量 current_embedding = text_to_embedding(prompt) # 查询与当前对话最相关的历史交互 search_params = { "metric_type": "L2", # 使用 L2 距离度量方式 "params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数 } start_time = time.time() results = await asyncio.to_thread( # 将阻塞操作放到线程池中执行 collection_manager.search, data=current_embedding, # 输入向量 search_params=search_params, # 搜索参数 expr=f"session_id == '{session_id}'", # 按 session_id 过滤 limit=5 # 返回 5 条结果 ) end_time = time.time() # 构建历史交互提示词 history_prompt = "" if results: for hits in results: for hit in hits: try: # 查询非向量字段 record = await asyncio.to_thread(collection_manager.query_by_id, hit.id) if record: logger.info(f"查询到的记录: {record}") # 添加历史交互 history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n" except Exception as e: logger.error(f"查询失败: {e}") # 限制历史交互提示词长度 history_prompt = history_prompt[:2000] logger.info(f"历史交互提示词: {history_prompt}") # 调用大模型,将历史交互作为提示词 try: response = await asyncio.wait_for( client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": "你是一个私人助理,负责回答用户的问题。请根据用户的历史对话和当前问题,提供准确且简洁的回答。不要提及你是通义千问或其他无关信息,也不可以回复与本次用户问题不相关的历史对话记录内容,回复内容不要超过90字。"}, {"role": "user", "content": f"历史对话记录:{history_prompt},本次用户问题: {prompt}"} ], max_tokens=100 ), timeout=60 # 设置超时时间为 60 秒 ) except asyncio.TimeoutError: logger.error("大模型调用超时") raise HTTPException(status_code=500, detail="大模型调用超时") # 提取生成的回复 if response.choices and response.choices[0].message.content: result = response.choices[0].message.content.strip() logger.info(f"大模型回复: {result}") # 记录用户输入和大模型反馈到向量数据库 timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) entities = [ [session_id], # session_id [prompt[:500]], # user_input,截断到 500 字符 [result[:500]], # model_response,截断到 500 字符 [timestamp], # timestamp [current_embedding] # embedding ] if len(prompt) > 500: logger.warning(f"用户输入被截断,原始长度: {len(prompt)}") if len(result) > 500: logger.warning(f"大模型回复被截断,原始长度: {len(result)}") await asyncio.to_thread(collection_manager.insert_data, entities) logger.info("用户输入和大模型反馈已记录到向量数据库。") # 调用 TTS 生成 MP3 uuid_str = str(uuid.uuid4()) timestamp = int(time.time()) tts_file = f"audio/{uuid_str}_{timestamp}.mp3" # 生成 TTS 音频数据(不落盘) t = TTS(None) # 传入 None 表示不保存到本地文件 audio_data, duration = await asyncio.to_thread(t.generate_audio, result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据 print(f"音频时长: {duration} 秒") # 将音频数据直接上传到 OSS await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data) logger.info(f"TTS 文件已直接上传到 OSS: {tts_file}") # 完整的 URL url = 'https://ylt.oss-cn-hangzhou.aliyuncs.com/' + tts_file # 记录聊天数据到 MySQL await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result, url, duration) logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。") # 调用会话检查机制 await on_session_end(session_id) # 返回数据 return { "success": True, "url": url, "search_time": end_time - start_time, # 返回查询耗时 "duration": duration, # 返回大模型的回复时长 "response": result # 返回大模型的回复 } else: raise HTTPException(status_code=500, detail="大模型未返回有效结果") except Exception as e: logger.error(f"调用大模型失败: {str(e)}") raise HTTPException(status_code=500, detail=f"调用大模型失败: {str(e)}") finally: # 释放连接 milvus_pool.release_connection(connection) # 获取聊天记录 # 获取聊天记录 @app.get("/aichat/get_chat_log") async def get_chat_log( session_id: str, page: int = Query(default=None, ge=1, description="当前页码(如果为 None,则默认跳转到最后一页)"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") ): """ 获取指定会话的聊天记录,默认返回最新的记录(最后一页) :param session_id: 用户会话 ID :param page: 当前页码(如果为 None,则默认跳转到最后一页) :param page_size: 每页记录数 :return: 分页数据 """ # 调用 get_chat_log_by_session 方法 result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size) return result @app.get("/aichat/get_risk_page") async def get_risk_page( risk_flag: int = Query(default=1, ge=1, description="1:有风险,0:无风险,2:有风险但已处理"), page: int = Query(default=1, ge=1, description="当前页码"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") ): """ 查询有风险的聊天记录,并按 id 降序分页 :param page: 当前页码 :param page_size: 每页记录数 :return: 分页数据 """ try: result = await get_risk_chat_log_page(app.state.mysql_pool, risk_flag, page, page_size) return result except Exception as e: logger.error(f"查询有风险的聊天记录失败: {str(e)}") raise HTTPException(status_code=500, detail=f"查询有风险的聊天记录失败: {str(e)}") # 运行 FastAPI 应用 if __name__ == "__main__": import uvicorn uvicorn.run("Start:app", host="0.0.0.0", port=5600, workers=1)