|
|
|
@ -0,0 +1,281 @@
|
|
|
|
|
"""
|
|
|
|
|
pip install aiomysql
|
|
|
|
|
"""
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Optional, Dict, List
|
|
|
|
|
|
|
|
|
|
from aiomysql import create_pool
|
|
|
|
|
from Config.Config import *
|
|
|
|
|
# 配置日志
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# MySQL 配置
|
|
|
|
|
MYSQL_CONFIG = {
|
|
|
|
|
"host": MYSQL_HOST,
|
|
|
|
|
"port": MYSQL_PORT,
|
|
|
|
|
"user": MYSQL_USER,
|
|
|
|
|
"password": MYSQL_PASSWORD,
|
|
|
|
|
"db": MYSQL_DB_NAME,
|
|
|
|
|
"minsize": 1,
|
|
|
|
|
"maxsize": 20,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化 MySQL 连接池
|
|
|
|
|
async def init_mysql_pool():
|
|
|
|
|
return await create_pool(**MYSQL_CONFIG)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 保存聊天记录到 MySQL
|
|
|
|
|
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1, output_type=1,
|
|
|
|
|
input_image_type=0, image_width=0, image_height=0):
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,input_type,output_type,input_image_type,image_width,image_height,create_time) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,NOW())",
|
|
|
|
|
(person_id, prompt, result, audio_url, duration, input_type, output_type, input_image_type, image_width,
|
|
|
|
|
image_height)
|
|
|
|
|
)
|
|
|
|
|
await conn.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 清空表
|
|
|
|
|
async def truncate_chat_log(mysql_pool):
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
await cur.execute("TRUNCATE TABLE t_chat_log")
|
|
|
|
|
await conn.commit()
|
|
|
|
|
logger.info("表 t_chat_log 已清空。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from aiomysql import DictCursor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 分页查询聊天记录
|
|
|
|
|
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
|
|
|
|
|
"""
|
|
|
|
|
根据 person_id 查询聊天记录,并按 id 降序分页
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param person_id: 用户会话 ID
|
|
|
|
|
:param page: 当前页码(默认值为 1,但会动态计算为最后一页)
|
|
|
|
|
:param page_size: 每页记录数
|
|
|
|
|
:return: 分页数据
|
|
|
|
|
"""
|
|
|
|
|
if not mysql_pool:
|
|
|
|
|
raise ValueError("MySQL 连接池未初始化")
|
|
|
|
|
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor(DictCursor) as cur: # 使用 DictCursor
|
|
|
|
|
# 查询总记录数
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
|
|
|
|
|
(person_id,)
|
|
|
|
|
)
|
|
|
|
|
total = (await cur.fetchone())['COUNT(*)']
|
|
|
|
|
|
|
|
|
|
# 计算总页数
|
|
|
|
|
total_pages = (total + page_size - 1) // page_size
|
|
|
|
|
|
|
|
|
|
# 计算偏移量
|
|
|
|
|
offset = (page - 1) * page_size
|
|
|
|
|
|
|
|
|
|
# 查询分页数据,按 id 降序排列
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type,image_width,image_height, create_time "
|
|
|
|
|
"FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
|
|
|
|
|
(person_id, page_size, offset)
|
|
|
|
|
)
|
|
|
|
|
records = await cur.fetchall()
|
|
|
|
|
|
|
|
|
|
# 将查询结果反转,确保最新消息显示在最后
|
|
|
|
|
if page==1 and records:
|
|
|
|
|
records.reverse()
|
|
|
|
|
|
|
|
|
|
# 将查询结果转换为字典列表
|
|
|
|
|
result = [
|
|
|
|
|
{
|
|
|
|
|
"id": record['id'],
|
|
|
|
|
"person_id": record['person_id'],
|
|
|
|
|
"user_input": record['user_input'],
|
|
|
|
|
"model_response": record['model_response'],
|
|
|
|
|
"audio_url": record['audio_url'],
|
|
|
|
|
"duration": record['duration'],
|
|
|
|
|
"input_type": record['input_type'],
|
|
|
|
|
"output_type": record['output_type'],
|
|
|
|
|
"image_width": record['image_width'],
|
|
|
|
|
"image_height": record['image_height'],
|
|
|
|
|
"input_image_type": record['input_image_type'],
|
|
|
|
|
"create_time": record['create_time'].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
}
|
|
|
|
|
for record in records
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
"data": result, # 按 id 升序排列的数据
|
|
|
|
|
"total": total,
|
|
|
|
|
"page": page,
|
|
|
|
|
"page_size": page_size,
|
|
|
|
|
"total_pages": total_pages
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取指定会话的最后一条记录的 id
|
|
|
|
|
async def get_last_chat_log_id(mysql_pool, person_id):
|
|
|
|
|
"""
|
|
|
|
|
获取指定会话的最后一条记录的 id
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param session_id: 用户会话 ID
|
|
|
|
|
:return: 最后一条记录的 id,如果未找到则返回 None
|
|
|
|
|
"""
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
|
|
|
|
|
(person_id,)
|
|
|
|
|
)
|
|
|
|
|
result = await cur.fetchone()
|
|
|
|
|
return result[0] if result else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 更新为危险的记录
|
|
|
|
|
async def update_risk(mysql_pool, person_id, risk_memo):
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cur:
|
|
|
|
|
# 1. 获取此人员的最后一条记录 id
|
|
|
|
|
last_id = await get_last_chat_log_id(mysql_pool, person_id)
|
|
|
|
|
|
|
|
|
|
if last_id:
|
|
|
|
|
# 2. 更新 risk_flag 和 risk_memo
|
|
|
|
|
await cur.execute(
|
|
|
|
|
"UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
|
|
|
|
|
(risk_memo.replace('\n', '').replace("NO", ""), last_id)
|
|
|
|
|
)
|
|
|
|
|
await conn.commit()
|
|
|
|
|
logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
|
|
|
|
|
else:
|
|
|
|
|
logger.warning(f"未找到 person_id={person_id} 的记录。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 查询用户信息
|
|
|
|
|
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
|
|
|
|
|
"""
|
|
|
|
|
根据用户名查询用户信息
|
|
|
|
|
:param pool: MySQL 连接池
|
|
|
|
|
:param login_name: 用户名
|
|
|
|
|
:return: 用户信息(字典形式)
|
|
|
|
|
"""
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cursor:
|
|
|
|
|
sql = "SELECT * FROM t_base_person WHERE login_name = %s"
|
|
|
|
|
await cursor.execute(sql, (login_name,))
|
|
|
|
|
row = await cursor.fetchone()
|
|
|
|
|
if not row:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# 将元组转换为字典
|
|
|
|
|
columns = [column[0] for column in cursor.description]
|
|
|
|
|
return dict(zip(columns, row))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 显示统计分析页面
|
|
|
|
|
async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
|
|
|
|
|
"""
|
|
|
|
|
获取聊天记录的统计分析结果
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param risk_flag: 风险标志
|
|
|
|
|
:param offset: 偏移量
|
|
|
|
|
:param page_size: 每页记录数
|
|
|
|
|
:return: 日志列表和总记录数
|
|
|
|
|
"""
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cursor:
|
|
|
|
|
# 查询符合条件的记录
|
|
|
|
|
sql = """
|
|
|
|
|
SELECT tbp.*, COUNT(*) AS cnt
|
|
|
|
|
FROM t_chat_log AS tcl
|
|
|
|
|
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
|
|
|
|
|
WHERE tcl.risk_flag = %s
|
|
|
|
|
GROUP BY tcl.person_id
|
|
|
|
|
ORDER BY COUNT(*) DESC
|
|
|
|
|
LIMIT %s OFFSET %s
|
|
|
|
|
"""
|
|
|
|
|
await cursor.execute(sql, (risk_flag, page_size, offset))
|
|
|
|
|
rows = await cursor.fetchall()
|
|
|
|
|
|
|
|
|
|
# 获取列名
|
|
|
|
|
columns = [column[0] for column in cursor.description]
|
|
|
|
|
|
|
|
|
|
# 查询总记录数
|
|
|
|
|
count_sql = """
|
|
|
|
|
SELECT COUNT(DISTINCT tcl.person_id)
|
|
|
|
|
FROM t_chat_log AS tcl
|
|
|
|
|
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
|
|
|
|
|
WHERE tcl.risk_flag = %s
|
|
|
|
|
"""
|
|
|
|
|
await cursor.execute(count_sql, (risk_flag,))
|
|
|
|
|
total = (await cursor.fetchone())[0]
|
|
|
|
|
|
|
|
|
|
# 将元组转换为字典
|
|
|
|
|
logs = [dict(zip(columns, row)) for row in rows] if rows else []
|
|
|
|
|
|
|
|
|
|
return logs, total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> (
|
|
|
|
|
List[Dict], int):
|
|
|
|
|
"""
|
|
|
|
|
根据风险标志查询聊天记录
|
|
|
|
|
:param mysql_pool: MySQL 连接池
|
|
|
|
|
:param risk_flag: 风险标志
|
|
|
|
|
:param offset: 分页偏移量
|
|
|
|
|
:param page_size: 每页记录数
|
|
|
|
|
:return: 聊天记录列表和总记录数
|
|
|
|
|
"""
|
|
|
|
|
async with mysql_pool.acquire() as conn:
|
|
|
|
|
await conn.ping() # 重置连接
|
|
|
|
|
async with conn.cursor() as cursor:
|
|
|
|
|
# 查询符合条件的记录
|
|
|
|
|
sql = """
|
|
|
|
|
SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
|
|
|
|
|
tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
|
|
|
|
|
tbp.person_id, tbp.login_name, tbp.person_name
|
|
|
|
|
FROM t_chat_log AS tcl
|
|
|
|
|
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
|
|
|
|
|
WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY TCL.ID DESC
|
|
|
|
|
LIMIT %s OFFSET %s
|
|
|
|
|
"""
|
|
|
|
|
await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
|
|
|
|
|
rows = await cursor.fetchall()
|
|
|
|
|
|
|
|
|
|
# 在 count_sql 执行前获取列名
|
|
|
|
|
columns = [column[0] for column in cursor.description]
|
|
|
|
|
|
|
|
|
|
# 查询总记录数
|
|
|
|
|
count_sql = """
|
|
|
|
|
SELECT COUNT(*)
|
|
|
|
|
FROM t_chat_log AS tcl
|
|
|
|
|
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
|
|
|
|
|
WHERE tcl.risk_flag = %s and tcl.person_id=%s
|
|
|
|
|
"""
|
|
|
|
|
await cursor.execute(count_sql, (risk_flag, person_id))
|
|
|
|
|
total = (await cursor.fetchone())[0]
|
|
|
|
|
|
|
|
|
|
# 将元组转换为字典,并格式化 create_time
|
|
|
|
|
if rows:
|
|
|
|
|
logs = []
|
|
|
|
|
for row in rows:
|
|
|
|
|
log = dict(zip(columns, row))
|
|
|
|
|
# 格式化 create_time
|
|
|
|
|
if log["create_time"]:
|
|
|
|
|
log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
logs.append(log)
|
|
|
|
|
return logs, total
|
|
|
|
|
return [], 0
|