|
|
|
@ -44,18 +44,18 @@ async def truncate_chat_log(mysql_pool):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 分页查询聊天记录
|
|
|
|
|
async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
|
|
|
|
|
async def get_chat_log_by_session(mysql_pool, session_id, page=None, page_size=10):
|
|
|
|
|
"""
|
|
|
|
|
根据 session_id 查询聊天记录,并按 id 降序分页
|
|
|
|
|
根据 session_id 查询聊天记录,并按 id 升序分页
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param session_id: 用户会话 ID
|
|
|
|
|
:param page: 当前页码
|
|
|
|
|
:param page: 当前页码(如果为 None,则默认跳转到最后一页)
|
|
|
|
|
:param page_size: 每页记录数
|
|
|
|
|
:return: 分页数据
|
|
|
|
|
"""
|
|
|
|
|
if not mysql_pool:
|
|
|
|
|
raise ValueError("MySQL 连接池未初始化")
|
|
|
|
|
|
|
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
# 查询总记录数
|
|
|
|
@ -65,10 +65,20 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
|
|
|
|
|
)
|
|
|
|
|
total = (await cur.fetchone())[0]
|
|
|
|
|
|
|
|
|
|
# 查询分页数据,按 id 降序排列
|
|
|
|
|
# 计算总页数
|
|
|
|
|
total_pages = (total + page_size - 1) // page_size
|
|
|
|
|
|
|
|
|
|
# 如果未指定页码,则默认跳转到最后一页
|
|
|
|
|
if page is None:
|
|
|
|
|
page = total_pages
|
|
|
|
|
|
|
|
|
|
# 计算偏移量
|
|
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
|
|
|
|
|
# 查询分页数据,按 id 升序排列
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"SELECT id, session_id, user_input, model_response, audio_url, duration, create_time "
|
|
|
|
|
"FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
|
|
|
|
|
"FROM t_chat_log WHERE session_id = %s ORDER BY id ASC LIMIT %s OFFSET %s",
|
|
|
|
|
(session_id, page_size, offset)
|
|
|
|
|
)
|
|
|
|
|
records = await cur.fetchall()
|
|
|
|
@ -88,26 +98,40 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"data": result,
|
|
|
|
|
"data": result, # 按 id 升序排列的数据
|
|
|
|
|
"total": total,
|
|
|
|
|
"page": page,
|
|
|
|
|
"page_size": page_size
|
|
|
|
|
"page_size": page_size,
|
|
|
|
|
"total_pages": total_pages
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新为危险的记录
|
|
|
|
|
async def update_risk(mysql_pool, session_id, risk_memo):
|
|
|
|
|
# 获取指定会话的最后一条记录的 id
|
|
|
|
|
async def get_last_chat_log_id(mysql_pool, session_id):
|
|
|
|
|
"""
|
|
|
|
|
获取指定会话的最后一条记录的 id
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param session_id: 用户会话 ID
|
|
|
|
|
:return: 最后一条记录的 id,如果未找到则返回 None
|
|
|
|
|
"""
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
# 1. 获取此人员的最后一条记录 id
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"SELECT id FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT 1",
|
|
|
|
|
(session_id,)
|
|
|
|
|
)
|
|
|
|
|
result = await cur.fetchone()
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新为危险的记录
|
|
|
|
|
async def update_risk(mysql_pool, session_id, risk_memo):
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
# 1. 获取此人员的最后一条记录 id
|
|
|
|
|
last_id = await get_last_chat_log_id(mysql_pool, session_id)
|
|
|
|
|
|
|
|
|
|
if result:
|
|
|
|
|
last_id = result[0]
|
|
|
|
|
if last_id:
|
|
|
|
|
# 2. 更新 risk_flag 和 risk_memo
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
|
|
|
|
|