diff --git a/AI/WxMini/Start.py b/AI/WxMini/Start.py index 58f9165f..f3ef07ff 100644 --- a/AI/WxMini/Start.py +++ b/AI/WxMini/Start.py @@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory from WxMini.Utils.TtsUtil import TTS from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \ - get_risk_chat_log_page + get_risk_chat_log_page, get_last_chat_log_id from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 @@ -47,13 +47,26 @@ async def lifespan(app: FastAPI): # 会话结束后,调用检查方法,判断是不是有需要介入的问题出现 async def on_session_end(session_id): - # 获取聊天记录 - result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page=1, page_size=1) + # 获取最后一条聊天记录 + last_id = await get_last_chat_log_id(app.state.mysql_pool, session_id) + if last_id: + # 查询最后一条记录的详细信息 + async with app.state.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "SELECT user_input, model_response FROM t_chat_log WHERE id = %s", + (last_id,) + ) + last_record = await cur.fetchone() + + if last_record: + # 拼接历史聊天记录 + history = f"问题:{last_record['user_input']}\n回答:{last_record['model_response']}" + else: + history = "无聊天记录" + else: + history = "无聊天记录" - # 拼接历史聊天记录 - history = "" - for row in result['data']: - history = f"{history}\n问题:{row['user_input']}\n回答:{row['model_response']}" # 将历史聊天记录发给大模型,让它帮我分析一下 prompt = ( @@ -225,27 +238,25 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): milvus_pool.release_connection(connection) +# 获取聊天记录 + # 获取聊天记录 @app.get("/aichat/get_chat_log") async def get_chat_log( session_id: str, - page: int = Query(default=1, ge=1, description="当前页码"), + page: int = Query(default=None, ge=1, description="当前页码(如果为 None,则默认跳转到最后一页)"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") ): """ - 根据 session_id 查询聊天记录,并按 id 降序分页 + 获取指定会话的聊天记录,默认返回最新的记录(最后一页) :param session_id: 用户会话 ID - :param page: 当前页码 + :param page: 当前页码(如果为 None,则默认跳转到最后一页) :param page_size: 每页记录数 :return: 分页数据 """ - try: - result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size) - return result - except Exception as e: - logger.error(f"查询聊天记录失败: {str(e)}") - raise HTTPException(status_code=500, detail=f"查询聊天记录失败: {str(e)}") - + # 调用 get_chat_log_by_session 方法 + result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size) + return result @app.get("/aichat/get_risk_page") async def get_risk_page( diff --git a/AI/WxMini/Test/TongJi.py b/AI/WxMini/Test/TongJi.py deleted file mode 100644 index 628714f3..00000000 --- a/AI/WxMini/Test/TongJi.py +++ /dev/null @@ -1,67 +0,0 @@ -import asyncio -from openai import AsyncOpenAI -from WxMini.Milvus.Config.MulvusConfig import * -from WxMini.Utils.MySQLUtil import init_mysql_pool, get_chat_log_by_session - -# 初始化异步 OpenAI 客户端 -client = AsyncOpenAI( - api_key=MODEL_API_KEY, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", -) - -async def main(): - # 哪个人员 - session_id = 1 - # 哪一页 - page = 1 - # 一页多少个 - page_size = 100 - - # 初始化 MySQL 连接池 - mysql_pool = await init_mysql_pool() - - # 调用 - result = await get_chat_log_by_session(mysql_pool, session_id, page, page_size) - - # 我只关心 user_input 与 model_response - # 把这些拼接出问题与回答 - history = "" - for row in result['data']: # 注意:result 是一个字典,包含 'data' 字段 - user_input = row['user_input'] - model_response = row['model_response'] - history = f"{history}\n问题:{user_input}\n回答:{model_response}" - - # 将历史聊天记录发给大模型,让它帮我分析一下 - prompt = ( - "我将把用户与AI大模型交流的记录发给你,帮我分析一下这个用户是否存在心理健康方面的问题," - "参考:1、PHQ-9抑郁症筛查量表和2、Beck自杀意念评量表(BSI-CV)。" - "如果没有健康问题请回复: OK;否则回复:NO,换行后再输出是什么问题。" - f"\n\n历史聊天记录:{history}" - ) - - # 调用大模型进行分析 - try: - response = await client.chat.completions.create( - model=MODEL_NAME, - messages=[ - {"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"}, - {"role": "user", "content": prompt} - ], - max_tokens=1000 - ) - # 提取大模型的回复 - if response.choices and response.choices[0].message.content: - analysis_result = response.choices[0].message.content.strip() - print("大模型分析结果:") - print(analysis_result) - else: - print("大模型未返回有效结果。") - except Exception as e: - print(f"调用大模型失败: {str(e)}") - - # 关闭连接池 - mysql_pool.close() - await mysql_pool.wait_closed() - -if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file diff --git a/AI/WxMini/Utils/MySQLUtil.py b/AI/WxMini/Utils/MySQLUtil.py index 3559fd95..b24bd5d6 100644 --- a/AI/WxMini/Utils/MySQLUtil.py +++ b/AI/WxMini/Utils/MySQLUtil.py @@ -44,18 +44,18 @@ async def truncate_chat_log(mysql_pool): # 分页查询聊天记录 -async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): +async def get_chat_log_by_session(mysql_pool, session_id, page=None, page_size=10): """ - 根据 session_id 查询聊天记录,并按 id 降序分页 + 根据 session_id 查询聊天记录,并按 id 升序分页 + :param mysql_pool: MySQL 连接池 :param session_id: 用户会话 ID - :param page: 当前页码 + :param page: 当前页码(如果为 None,则默认跳转到最后一页) :param page_size: 每页记录数 :return: 分页数据 """ if not mysql_pool: raise ValueError("MySQL 连接池未初始化") - offset = (page - 1) * page_size async with mysql_pool.acquire() as conn: async with conn.cursor() as cur: # 查询总记录数 @@ -65,10 +65,20 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): ) total = (await cur.fetchone())[0] - # 查询分页数据,按 id 降序排列 + # 计算总页数 + total_pages = (total + page_size - 1) // page_size + + # 如果未指定页码,则默认跳转到最后一页 + if page is None: + page = total_pages + + # 计算偏移量 + offset = (page - 1) * page_size + + # 查询分页数据,按 id 升序排列 await cur.execute( "SELECT id, session_id, user_input, model_response, audio_url, duration, create_time " - "FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s", + "FROM t_chat_log WHERE session_id = %s ORDER BY id ASC LIMIT %s OFFSET %s", (session_id, page_size, offset) ) records = await cur.fetchall() @@ -88,26 +98,40 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): ] return { - "data": result, + "data": result, # 按 id 升序排列的数据 "total": total, "page": page, - "page_size": page_size + "page_size": page_size, + "total_pages": total_pages } -# 更新为危险的记录 -async def update_risk(mysql_pool, session_id, risk_memo): +# 获取指定会话的最后一条记录的 id +async def get_last_chat_log_id(mysql_pool, session_id): + """ + 获取指定会话的最后一条记录的 id + :param mysql_pool: MySQL 连接池 + :param session_id: 用户会话 ID + :return: 最后一条记录的 id,如果未找到则返回 None + """ async with mysql_pool.acquire() as conn: async with conn.cursor() as cur: - # 1. 获取此人员的最后一条记录 id await cur.execute( "SELECT id FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT 1", (session_id,) ) result = await cur.fetchone() + return result[0] if result else None + + +# 更新为危险的记录 +async def update_risk(mysql_pool, session_id, risk_memo): + async with mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + # 1. 获取此人员的最后一条记录 id + last_id = await get_last_chat_log_id(mysql_pool, session_id) - if result: - last_id = result[0] + if last_id: # 2. 更新 risk_flag 和 risk_memo await cur.execute( "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s", diff --git a/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc index 90ecbe7d..bef1a193 100644 Binary files a/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc and b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc differ diff --git a/AI/WxMini/__pycache__/Start.cpython-310.pyc b/AI/WxMini/__pycache__/Start.cpython-310.pyc index 56a81c81..58368920 100644 Binary files a/AI/WxMini/__pycache__/Start.cpython-310.pyc and b/AI/WxMini/__pycache__/Start.cpython-310.pyc differ