You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

282 lines
11 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
pip install aiomysql
"""
import logging
from typing import Optional, Dict, List
from aiomysql import create_pool
from Config.Config import *
# Logging: INFO level; every line carries a timestamp, the level name and the message.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Keyword arguments handed to aiomysql's create_pool (see init_mysql_pool).
# NOTE(review): the MYSQL_* names are presumably supplied by the
# `from Config.Config import *` star import — confirm.
MYSQL_CONFIG = dict(
    host=MYSQL_HOST,
    port=MYSQL_PORT,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    db=MYSQL_DB_NAME,
    minsize=1,   # lower bound on the pool size
    maxsize=20,  # upper bound on the pool size (connections borrowed at once)
)
# Build the MySQL connection pool used by every helper in this module.
async def init_mysql_pool():
    """Create an aiomysql connection pool from ``MYSQL_CONFIG`` and return it.

    Nothing here closes the pool; presumably the caller owns its lifetime.
    """
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
# Persist one chat exchange to MySQL.
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1, output_type=1,
                             input_image_type=0, image_width=0, image_height=0):
    """
    Insert a single chat turn into ``t_chat_log`` and commit it.

    ``create_time`` is set by MySQL's ``NOW()``.

    :param mysql_pool: aiomysql connection pool
    :param person_id: user/session identifier
    :param prompt: user input, stored in ``user_input``
    :param result: model answer, stored in ``model_response``
    :param audio_url: stored in ``audio_url``
    :param duration: stored in ``duration``
    :param input_type: stored in ``input_type`` (default 1)
    :param output_type: stored in ``output_type`` (default 1)
    :param input_image_type: stored in ``input_image_type`` (default 0)
    :param image_width: stored in ``image_width`` (default 0)
    :param image_height: stored in ``image_height`` (default 0)
    """
    row_values = (
        person_id, prompt, result, audio_url, duration,
        input_type, output_type, input_image_type, image_width, image_height,
    )
    insert_sql = "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,input_type,output_type,input_image_type,image_width,image_height,create_time) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,NOW())"
    async with mysql_pool.acquire() as connection:
        await connection.ping()  # ping before use (original note: resets the connection)
        async with connection.cursor() as cursor:
            await cursor.execute(insert_sql, row_values)
            await connection.commit()
# Remove every row from t_chat_log.
async def truncate_chat_log(mysql_pool):
    """
    Empty ``t_chat_log`` via ``TRUNCATE TABLE`` and log that it happened.

    :param mysql_pool: aiomysql connection pool
    """
    statement = "TRUNCATE TABLE t_chat_log"
    async with mysql_pool.acquire() as connection:
        await connection.ping()  # ping before use (original note: resets the connection)
        async with connection.cursor() as cursor:
            await cursor.execute(statement)
            await connection.commit()
            logger.info("表 t_chat_log 已清空。")
from aiomysql import DictCursor
# Paginated chat history for one person.
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """
    Return one page of ``t_chat_log`` rows for ``person_id``.

    Rows are selected in descending ``id`` order with LIMIT/OFFSET, so page 1
    holds the newest records. Only page 1 is then reversed so it reads
    oldest -> newest; pages 2+ are returned in descending ``id`` order.

    :param mysql_pool: aiomysql connection pool
    :param person_id: user/session identifier
    :param page: 1-based page number; values below 1 are treated as 1
    :param page_size: rows per page; must be >= 1
    :return: dict with ``data`` (list of row dicts), ``total``, ``page``,
             ``page_size`` and ``total_pages``
    :raises ValueError: if the pool is missing or ``page_size`` < 1
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    # page_size == 0 used to crash with ZeroDivisionError in the page-count
    # arithmetic, and a negative value produced a negative LIMIT; reject both.
    if page_size < 1:
        raise ValueError("page_size 必须大于 0")
    # A page below 1 would give a negative OFFSET, which MySQL rejects.
    page = max(page, 1)
    offset = (page - 1) * page_size

    async with mysql_pool.acquire() as conn:
        await conn.ping()  # ping before use (original note: resets the connection)
        async with conn.cursor(DictCursor) as cur:  # rows come back as dicts
            # Total number of records for this person.
            await cur.execute(
                "SELECT COUNT(*) AS total FROM t_chat_log WHERE person_id = %s",
                (person_id,)
            )
            total = (await cur.fetchone())['total']
            # Requested page, newest first.
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type,image_width,image_height, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, offset)
            )
            records = list(await cur.fetchall())

    total_pages = (total + page_size - 1) // page_size

    # Page 1 is flipped so the newest message ends up last.
    # NOTE(review): pages 2+ stay in descending id order; confirm the client
    # expects that asymmetry before changing it.
    if page == 1:
        records.reverse()

    def _fmt_time(value):
        # Same guard as the risk-flag listing: NULL create_time stays None
        # instead of crashing on .strftime.
        return value.strftime("%Y-%m-%d %H:%M:%S") if value else None

    result = [
        {
            "id": record['id'],
            "person_id": record['person_id'],
            "user_input": record['user_input'],
            "model_response": record['model_response'],
            "audio_url": record['audio_url'],
            "duration": record['duration'],
            "input_type": record['input_type'],
            "output_type": record['output_type'],
            "image_width": record['image_width'],
            "image_height": record['image_height'],
            "input_image_type": record['input_image_type'],
            "create_time": _fmt_time(record['create_time']),
        }
        for record in records
    ]
    return {
        "data": result,  # ascending by id on page 1, descending on later pages
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages
    }
# Look up the id of the newest t_chat_log row for a person.
async def get_last_chat_log_id(mysql_pool, person_id):
    """
    Return the largest ``t_chat_log.id`` recorded for ``person_id``.

    :param mysql_pool: aiomysql connection pool
    :param person_id: user/session identifier to filter on
    :return: id of the newest record, or ``None`` if the person has no records
    """
    query = "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1"
    async with mysql_pool.acquire() as connection:
        await connection.ping()  # ping before use (original note: resets the connection)
        async with connection.cursor() as cursor:
            await cursor.execute(query, (person_id,))
            newest = await cursor.fetchone()
    if not newest:
        return None
    return newest[0]
# Flag a person's most recent chat record as risky.
async def update_risk(mysql_pool, person_id, risk_memo):
    """
    Set ``risk_flag = 1`` and ``risk_memo`` on the newest ``t_chat_log`` row of ``person_id``.

    Newlines and every ``"NO"`` substring are stripped from ``risk_memo`` before storing.
    If the person has no records, a warning is logged and nothing is updated.

    :param mysql_pool: aiomysql connection pool
    :param person_id: user/session identifier
    :param risk_memo: memo text to store (must be a str)
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # ping before use (original note: resets the connection)
        async with conn.cursor() as cur:
            # 1. Find the newest record id on THIS connection. The previous code
            #    called get_last_chat_log_id, which borrowed a second connection
            #    from the same pool while this one was still held; with enough
            #    concurrent callers the pool (maxsize 20) is exhausted and every
            #    caller waits forever for a free connection.
            await cur.execute(
                "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
                (person_id,)
            )
            row = await cur.fetchone()
            last_id = row[0] if row else None
            if last_id is not None:
                # 2. Store the flag and the cleaned memo.
                # NOTE(review): .replace("NO", "") removes "NO" anywhere in the
                # text (e.g. inside "NORMAL"); confirm this is intended.
                await cur.execute(
                    "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                    (risk_memo.replace('\n', '').replace("NO", ""), last_id)
                )
                await conn.commit()
                logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
            else:
                logger.warning(f"未找到 person_id={person_id} 的记录。")
# Fetch a user row by login name.
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """
    Look up a ``t_base_person`` row by ``login_name``.

    :param mysql_pool: aiomysql connection pool
    :param login_name: login name to match exactly
    :return: the row as a ``{column_name: value}`` dict, or ``None`` if no row matches
    """
    query = "SELECT * FROM t_base_person WHERE login_name = %s"
    async with mysql_pool.acquire() as connection:
        await connection.ping()  # ping before use (original note: resets the connection)
        async with connection.cursor() as cursor:
            await cursor.execute(query, (login_name,))
            user_row = await cursor.fetchone()
            if not user_row:
                return None
            # Tuple cursor: pair each value with its column name from the description.
            field_names = [desc[0] for desc in cursor.description]
    return {name: value for name, value in zip(field_names, user_row)}
# Statistics page: persons ranked by how many chat records carry a given risk flag.
async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
    """
    Rank persons by the number of their ``t_chat_log`` rows with ``risk_flag``.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: value of ``t_chat_log.risk_flag`` to match
    :param offset: number of ranked persons to skip (pagination offset)
    :param page_size: maximum number of ranked persons to return
    :return: ``(logs, total)``: each log is the person's ``t_base_person`` columns
             plus ``cnt`` (matching record count); ``total`` is the number of
             distinct matching persons, independent of pagination
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # ping before use (original note: resets the connection)
        async with conn.cursor() as cursor:
            # tcl.person_id breaks ties between equal counts. Without it MySQL
            # may order tied rows differently from one LIMIT/OFFSET query to the
            # next, so persons could repeat across pages or never be shown.
            sql = """
            SELECT tbp.*, COUNT(*) AS cnt
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s
            GROUP BY tcl.person_id
            ORDER BY COUNT(*) DESC, tcl.person_id
            LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names before the COUNT query replaces cursor.description.
            columns = [column[0] for column in cursor.description]

            count_sql = """
            SELECT COUNT(DISTINCT tcl.person_id)
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s
            """
            await cursor.execute(count_sql, (risk_flag,))
            total = (await cursor.fetchone())[0]

    logs = [dict(zip(columns, row)) for row in rows] if rows else []
    return logs, total
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> (
        List[Dict], int):
    """
    List one person's chat records carrying ``risk_flag``, newest first.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: value of ``t_chat_log.risk_flag`` to match
    :param person_id: person whose records are listed
    :param offset: pagination offset (rows to skip)
    :param page_size: maximum number of rows to return
    :return: ``(logs, total)``: row dicts (``create_time`` formatted as
             ``"%Y-%m-%d %H:%M:%S"`` when set) and the total number of matching
             rows regardless of pagination
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # ping before use (original note: resets the connection)
        async with conn.cursor() as cursor:
            # The alias is lower-case `tcl` throughout. MySQL table aliases are
            # case-sensitive on Unix by default, so the old `TCL.ID` could fail
            # there with "Unknown column 'TCL.ID'".
            sql = """
            SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
            tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
            tbp.person_id, tbp.login_name, tbp.person_name
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY tcl.id DESC
            LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names before the COUNT query replaces cursor.description.
            columns = [column[0] for column in cursor.description]

            count_sql = """
            SELECT COUNT(*)
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s and tcl.person_id=%s
            """
            await cursor.execute(count_sql, (risk_flag, person_id))
            total = (await cursor.fetchone())[0]

    logs = []
    for row in rows or ():
        log = dict(zip(columns, row))
        if log["create_time"]:
            log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
        logs.append(log)
    # Report the real total even when this page is empty (e.g. an offset past
    # the end); the old code returned 0 there, misreporting the total.
    return logs, total