diff --git a/AI/WxMini/Start.py b/AI/WxMini/Start.py index ba99ec02..c8c7fcc8 100644 --- a/AI/WxMini/Start.py +++ b/AI/WxMini/Start.py @@ -17,7 +17,7 @@ from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.EmbeddingUtil import text_to_embedding from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, get_user_by_login_name, \ - get_chat_logs_by_risk_flag + get_chat_logs_by_risk_flag, get_chat_logs_summary from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, get_sts_token from WxMini.Utils.TtsUtil import TTS @@ -402,7 +402,7 @@ async def get_risk_chat_logs( offset = (page - 1) * page_size # 调用 get_chat_logs_by_risk_flag 方法 - logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, offset, page_size) + logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag,current_user["person_id"], offset, page_size) if not logs: raise HTTPException(status_code=404, detail="未找到符合条件的记录") @@ -414,13 +414,70 @@ async def get_risk_chat_logs( "total": total, "page": page, "page_size": page_size, + "logs": logs + } + } + + + +# 获取风险统计接口 +@app.get("/aichat/chat_logs_summary") +async def chat_logs_summary( + risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险 ,2:处理完毕)"), + page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"), + page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)"), + current_user: dict = Depends(get_current_user) +): + """ + 获取风险统计接口,支持分页和风险标志过滤 + :param risk_flag: 风险标志 + :param page: 当前页码 + :param page_size: 每页记录数 + :param current_user: 当前用户信息 + :return: 分页数据 + """ + # 验证 risk_flag 的值 + if risk_flag not in {0, 1, 2}: + raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2") + + # 计算分页偏移量 + offset = (page - 1) * page_size + + # 调用 get_chat_logs_summary 方法 + logs, total = await get_chat_logs_summary(app.state.mysql_pool, risk_flag, offset, page_size) + + # 如果未找到记录,返回友好提示 + if not logs: + return { + "success": True, + "message": "未找到符合条件的记录", + "data": { + "total": 0, + "page": page, + "page_size": page_size, + "total_pages": 0, + "logs": [] + } + } + + # 计算总页数 + total_pages = (total + page_size - 1) // page_size + + # 返回分页数据 + return { + "success": True, + "message": "查询成功", + "data": { + "total": total, + "page": page, + "page_size": page_size, + "total_pages": total_pages, "logs": logs, "login_name": current_user["login_name"], "person_name": current_user["person_name"] } } - # 获取上传OSS的授权Token @app.get("/aichat/get_oss_upload_token") async def get_oss_upload_token(current_user: dict = Depends(get_current_user)): diff --git a/AI/WxMini/Utils/MySQLUtil.py b/AI/WxMini/Utils/MySQLUtil.py index 25b9b368..6215b79b 100644 --- a/AI/WxMini/Utils/MySQLUtil.py +++ b/AI/WxMini/Utils/MySQLUtil.py @@ -177,7 +177,53 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]: return dict(zip(columns, row)) -async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int): +# 显示统计分析页面 +async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int): + """ + 获取聊天记录的统计分析结果 + :param mysql_pool: MySQL 连接池 + :param risk_flag: 风险标志 + :param offset: 偏移量 + :param page_size: 每页记录数 + :return: 日志列表和总记录数 + """ + async with mysql_pool.acquire() as conn: + await conn.ping() # 重置连接 + async with conn.cursor() as cursor: + # 查询符合条件的记录 + sql = """ + SELECT tbp.*, COUNT(*) AS cnt + FROM t_chat_log AS tcl + INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id + WHERE tcl.risk_flag = %s + GROUP BY tcl.person_id + ORDER BY COUNT(*) DESC + LIMIT %s OFFSET %s + """ + await cursor.execute(sql, (risk_flag, page_size, offset)) + rows = await cursor.fetchall() + + # 获取列名 + columns = [column[0] for column in cursor.description] + + # 查询总记录数 + count_sql = """ + SELECT COUNT(DISTINCT tcl.person_id) + FROM t_chat_log AS tcl + INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id + WHERE tcl.risk_flag = %s + """ + await cursor.execute(count_sql, (risk_flag,)) + total = (await cursor.fetchone())[0] + + # 将元组转换为字典 + logs = [dict(zip(columns, row)) for row in rows] if rows else [] + + return logs, total + + +async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> ( + List[Dict], int): """ 根据风险标志查询聊天记录 :param mysql_pool: MySQL 连接池 @@ -196,10 +242,10 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa tbp.person_id, tbp.login_name, tbp.person_name FROM t_chat_log AS tcl INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id - WHERE tcl.risk_flag = %s ORDER BY TCL.ID DESC + WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY TCL.ID DESC LIMIT %s OFFSET %s """ - await cursor.execute(sql, (risk_flag, page_size, offset)) + await cursor.execute(sql, (risk_flag, person_id, page_size, offset)) rows = await cursor.fetchall() # 在 count_sql 执行前获取列名 @@ -210,9 +256,9 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa SELECT COUNT(*) FROM t_chat_log AS tcl INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id - WHERE tcl.risk_flag = %s + WHERE tcl.risk_flag = %s and tcl.person_id=%s """ - await cursor.execute(count_sql, (risk_flag,)) + await cursor.execute(count_sql, (risk_flag,person_id)) total = (await cursor.fetchone())[0] # 将元组转换为字典,并格式化 create_time @@ -225,4 +271,4 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S") logs.append(log) return logs, total - return [], 0 \ No newline at end of file + return [], 0