main
HuangHai 4 months ago
parent 8365bb3527
commit 3a6d2735cf

@ -17,7 +17,7 @@ from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.EmbeddingUtil import text_to_embedding
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, get_user_by_login_name, \
get_chat_logs_by_risk_flag
get_chat_logs_by_risk_flag, get_chat_logs_summary
from WxMini.Utils.MySQLUtil import update_risk, get_last_chat_log_id
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory, get_sts_token
from WxMini.Utils.TtsUtil import TTS
@ -402,7 +402,7 @@ async def get_risk_chat_logs(
offset = (page - 1) * page_size
# 调用 get_chat_logs_by_risk_flag 方法
logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, offset, page_size)
logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag,current_user["person_id"], offset, page_size)
if not logs:
raise HTTPException(status_code=404, detail="未找到符合条件的记录")
@ -414,13 +414,70 @@ async def get_risk_chat_logs(
"total": total,
"page": page,
"page_size": page_size,
"logs": logs
}
}
# 获取风险统计接口
@app.get("/aichat/chat_logs_summary")
async def chat_logs_summary(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕)"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
current_user: dict = Depends(get_current_user)
):
"""
获取风险统计接口支持分页和风险标志过滤
:param risk_flag: 风险标志
:param page: 当前页码
:param page_size: 每页记录数
:param current_user: 当前用户信息
:return: 分页数据
"""
# 验证 risk_flag 的值
if risk_flag not in {0, 1, 2}:
raise HTTPException(status_code=400, detail="risk_flag 的值必须是 0、1 或 2")
# 计算分页偏移量
offset = (page - 1) * page_size
# 调用 get_chat_logs_summary 方法
logs, total = await get_chat_logs_summary(app.state.mysql_pool, risk_flag, offset, page_size)
# 如果未找到记录,返回友好提示
if not logs:
return {
"success": True,
"message": "未找到符合条件的记录",
"data": {
"total": 0,
"page": page,
"page_size": page_size,
"total_pages": 0,
"logs": []
}
}
# 计算总页数
total_pages = (total + page_size - 1) // page_size
# 返回分页数据
return {
"success": True,
"message": "查询成功",
"data": {
"total": total,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
"logs": logs,
"login_name": current_user["login_name"],
"person_name": current_user["person_name"]
}
}
# 获取上传OSS的授权Token
@app.get("/aichat/get_oss_upload_token")
async def get_oss_upload_token(current_user: dict = Depends(get_current_user)):

@ -177,7 +177,53 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
return dict(zip(columns, row))
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
# 显示统计分析页面
async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
"""
获取聊天记录的统计分析结果
:param mysql_pool: MySQL 连接池
:param risk_flag: 风险标志
:param offset: 偏移量
:param page_size: 每页记录数
:return: 日志列表和总记录数
"""
async with mysql_pool.acquire() as conn:
await conn.ping() # 重置连接
async with conn.cursor() as cursor:
# 查询符合条件的记录
sql = """
SELECT tbp.*, COUNT(*) AS cnt
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s
GROUP BY tcl.person_id
ORDER BY COUNT(*) DESC
LIMIT %s OFFSET %s
"""
await cursor.execute(sql, (risk_flag, page_size, offset))
rows = await cursor.fetchall()
# 获取列名
columns = [column[0] for column in cursor.description]
# 查询总记录数
count_sql = """
SELECT COUNT(DISTINCT tcl.person_id)
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s
"""
await cursor.execute(count_sql, (risk_flag,))
total = (await cursor.fetchone())[0]
# 将元组转换为字典
logs = [dict(zip(columns, row)) for row in rows] if rows else []
return logs, total
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> (
List[Dict], int):
"""
根据风险标志查询聊天记录
:param mysql_pool: MySQL 连接池
@ -196,10 +242,10 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
tbp.person_id, tbp.login_name, tbp.person_name
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s ORDER BY TCL.ID DESC
WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY TCL.ID DESC
LIMIT %s OFFSET %s
"""
await cursor.execute(sql, (risk_flag, page_size, offset))
await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
rows = await cursor.fetchall()
# 在 count_sql 执行前获取列名
@ -210,9 +256,9 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
SELECT COUNT(*)
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s
WHERE tcl.risk_flag = %s and tcl.person_id=%s
"""
await cursor.execute(count_sql, (risk_flag,))
await cursor.execute(count_sql, (risk_flag,person_id))
total = (await cursor.fetchone())[0]
# 将元组转换为字典,并格式化 create_time
@ -225,4 +271,4 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
logs.append(log)
return logs, total
return [], 0
return [], 0

Loading…
Cancel
Save