main
HuangHai 4 months ago
parent cf6dd8b28e
commit f09d0809a1

@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
get_risk_chat_log_page, get_last_chat_log_id, get_user_by_login_name
get_last_chat_log_id, get_user_by_login_name, get_chat_logs_by_risk_flag
from WxMini.Utils.EmbeddingUtil import text_to_embedding
# 配置日志
@ -309,25 +309,39 @@ async def get_chat_log(
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
return result
@app.get("/aichat/get_risk_page")
async def get_risk_page(
risk_flag: int = Query(default=1, ge=1, description="1有风险0无风险2:有风险但已处理"),
page: int = Query(default=1, ge=1, description="当前页码"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
# 获取风险聊天记录接口
@app.get("/aichat/get_risk_chat_logs")
async def get_risk_chat_logs(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险)"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100")
):
"""
查询有风险的聊天记录并按 id 降序分页
获取聊天记录支持分页和风险标志过滤
:param risk_flag: 风险标志
:param page: 当前页码
:param page_size: 每页记录数
:return: 分页数据
"""
try:
result = await get_risk_chat_log_page(app.state.mysql_pool, risk_flag, page, page_size)
return result
except Exception as e:
logger.error(f"查询有风险的聊天记录失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"查询有风险的聊天记录失败: {str(e)}")
# 计算分页偏移量
offset = (page - 1) * page_size
# 调用 get_chat_logs_by_risk_flag 方法
logs, total = await get_chat_logs_by_risk_flag(app.state.mysql_pool, risk_flag, offset, page_size)
if not logs:
raise HTTPException(status_code=404, detail="未找到符合条件的记录")
# 返回分页数据
return {
"code": 200,
"message": "查询成功",
"data": {
"total": total,
"page": page,
"page_size": page_size,
"logs": logs
}
}
# 运行 FastAPI 应用

@ -1,5 +1,5 @@
import logging
from typing import Optional, Dict
from typing import Optional, Dict, List
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
@ -147,59 +147,6 @@ async def update_risk(mysql_pool, person_id, risk_memo):
logger.warning(f"未找到 person_id={person_id} 的记录。")
# 查询有风险的聊天记录
async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
"""
查询有风险的聊天记录并按 id 降序分页
:param mysql_pool: MySQL 连接池
:param risk_flag: 风险标志
:param page: 当前页码
:param page_size: 每页记录数
:return: 分页数据
"""
offset = (page - 1) * page_size
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
# 查询总记录数
await cur.execute(
"SELECT COUNT(*) FROM t_chat_log WHERE risk_flag = %s", (risk_flag,)
)
total = (await cur.fetchone())[0]
logger.info(f"总记录数: {total}")
# 查询分页数据
query = (
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo "
"FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s"
)
params = (risk_flag, page_size, offset)
logger.debug(f"执行查询: {query % params}") # 打印 SQL 查询
await cur.execute(query, params)
records = await cur.fetchall()
logger.debug(f"查询结果: {records}") # 打印查询结果
# 将查询结果转换为字典列表
result = [
{
"id": record[0],
"person_id": record[1],
"user_input": record[2],
"model_response": record[3],
"audio_url": record[4],
"duration": record[5],
"create_time": record[6].strftime("%Y-%m-%d %H:%M:%S"),
"risk_memo": record[7]
}
for record in records
]
return {
"data": result,
"total": total,
"page": page,
"page_size": page_size
}
# 查询用户信息
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
"""
@ -218,4 +165,47 @@ async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
# 将元组转换为字典
columns = [column[0] for column in cursor.description]
return dict(zip(columns, row))
return dict(zip(columns, row))
# 查询聊天记录
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
"""
根据风险标志查询聊天记录
:param pool: MySQL 连接池
:param risk_flag: 风险标志
:param offset: 分页偏移量
:param page_size: 每页记录数
:return: 聊天记录列表和总记录数
"""
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cursor:
# 查询符合条件的记录
sql = """
SELECT
tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
tbp.*
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s
LIMIT %s OFFSET %s
"""
await cursor.execute(sql, (risk_flag, page_size, offset))
rows = await cursor.fetchall()
# 查询总记录数
count_sql = """
SELECT COUNT(*)
FROM t_chat_log AS tcl
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s
"""
await cursor.execute(count_sql, (risk_flag,))
total = (await cursor.fetchone())[0]
# 将元组转换为字典
if rows:
columns = [column[0] for column in cursor.description]
logs = [dict(zip(columns, row)) for row in rows]
return logs, total
return [], 0
Loading…
Cancel
Save