diff --git a/AI/WxMini/Start.py b/AI/WxMini/Start.py index cac1a037..84fb3a74 100644 --- a/AI/WxMini/Start.py +++ b/AI/WxMini/Start.py @@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory from WxMini.Utils.TtsUtil import TTS from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \ - get_risk_chat_log_page, get_last_chat_log_id, get_user_by_login_name + get_last_chat_log_id, get_user_by_login_name, get_chat_logs_by_risk_flag from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 @@ -309,25 +309,39 @@ async def get_chat_log( result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size) return result - -@app.get("/aichat/get_risk_page") -async def get_risk_page( - risk_flag: int = Query(default=1, ge=1, description="1:有风险,0:无风险,2:有风险但已处理"), - page: int = Query(default=1, ge=1, description="当前页码"), - page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") +# 获取风险聊天记录接口 +@app.get("/aichat/get_risk_chat_logs") +async def get_risk_chat_logs( + risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险)"), + page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"), + page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)") ): """ - 查询有风险的聊天记录,并按 id 降序分页 + 获取聊天记录,支持分页和风险标志过滤 + :param risk_flag: 风险标志 :param page: 当前页码 :param page_size: 每页记录数 :return: 分页数据 """ - try: - result = await get_risk_chat_log_page(app.state.mysql_pool, risk_flag, page, page_size) - return result - except Exception as e: - logger.error(f"查询有风险的聊天记录失败: {str(e)}") - raise HTTPException(status_code=500, detail=f"查询有风险的聊天记录失败: {str(e)}") + # 计算分页偏移量 + offset = (page - 1) * page_size + + # 调用 get_chat_logs_by_risk_flag 方法 + logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, offset, page_size) + if not logs: + raise HTTPException(status_code=404, detail="未找到符合条件的记录") + + # 返回分页数据 + return { + "code": 200, + "message": "查询成功", + "data": { + "total": total, + "page": page, + "page_size": page_size, + "logs": logs + } + } # 运行 FastAPI 应用 diff --git a/AI/WxMini/Utils/MySQLUtil.py b/AI/WxMini/Utils/MySQLUtil.py index 91cd6b48..46704c25 100644 --- a/AI/WxMini/Utils/MySQLUtil.py +++ b/AI/WxMini/Utils/MySQLUtil.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Dict +from typing import Optional, Dict, List from aiomysql import create_pool from WxMini.Milvus.Config.MulvusConfig import * @@ -147,59 +147,6 @@ async def update_risk(mysql_pool, person_id, risk_memo): logger.warning(f"未找到 person_id={person_id} 的记录。") -# 查询有风险的聊天记录 -async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10): - """ - 查询有风险的聊天记录,并按 id 降序分页 - :param mysql_pool: MySQL 连接池 - :param risk_flag: 风险标志 - :param page: 当前页码 - :param page_size: 每页记录数 - :return: 分页数据 - """ - offset = (page - 1) * page_size - async with mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - # 查询总记录数 - await cur.execute( - "SELECT COUNT(*) FROM t_chat_log WHERE risk_flag = %s", (risk_flag,) - ) - total = (await cur.fetchone())[0] - logger.info(f"总记录数: {total}") - - # 查询分页数据 - query = ( - "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo " - "FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s" - ) - params = (risk_flag, page_size, offset) - logger.debug(f"执行查询: {query % params}") # 打印 SQL 查询 - - await cur.execute(query, params) - records = await cur.fetchall() - logger.debug(f"查询结果: {records}") # 打印查询结果 - - # 将查询结果转换为字典列表 - result = [ - { - "id": record[0], - "person_id": record[1], - "user_input": record[2], - "model_response": record[3], - "audio_url": record[4], - "duration": record[5], - "create_time": record[6].strftime("%Y-%m-%d %H:%M:%S"), - "risk_memo": record[7] - } - for record in records - ] - - return { - "data": result, - "total": total, - "page": page, - "page_size": page_size - } # 查询用户信息 async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]: """ @@ -218,4 +165,47 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]: # 将元组转换为字典 columns = [column[0] for column in cursor.description] - return dict(zip(columns, row)) \ No newline at end of file + return dict(zip(columns, row)) + +# 查询聊天记录 +async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int): + """ + 根据风险标志查询聊天记录 + :param pool: MySQL 连接池 + :param risk_flag: 风险标志 + :param offset: 分页偏移量 + :param page_size: 每页记录数 + :return: 聊天记录列表和总记录数 + """ + async with mysql_pool.acquire() as conn: + async with conn.cursor() as cursor: + # 查询符合条件的记录 + sql = """ + SELECT + tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration, + tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result, + tbp.* + FROM t_chat_log AS tcl + INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id + WHERE tcl.risk_flag = %s + LIMIT %s OFFSET %s + """ + await cursor.execute(sql, (risk_flag, page_size, offset)) + rows = await cursor.fetchall() + + # 查询总记录数 + count_sql = """ + SELECT COUNT(*) + FROM t_chat_log AS tcl + INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id + WHERE tcl.risk_flag = %s + """ + await cursor.execute(count_sql, (risk_flag,)) + total = (await cursor.fetchone())[0] + + # 将元组转换为字典 + if rows: + columns = [column[0] for column in cursor.description] + logs = [dict(zip(columns, row)) for row in rows] + return logs, total + return [], 0 \ No newline at end of file