You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

220 lines
8.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
from typing import Optional, Dict
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# MySQL settings. Every key is forwarded as a keyword argument to
# aiomysql's create_pool() by init_mysql_pool().
# NOTE(review): the MYSQL_* names are not defined in this file; presumably
# they come from the star import of MulvusConfig above -- confirm.
MYSQL_CONFIG = dict(
    host=MYSQL_HOST,
    port=MYSQL_PORT,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    db=MYSQL_DB_NAME,
    minsize=1,
    maxsize=20,
)
# Initialize the MySQL connection pool
async def init_mysql_pool():
    """Create a connection pool from ``MYSQL_CONFIG`` and return it."""
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
# Save a chat record to MySQL
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration):
    """
    Insert one chat exchange into ``t_chat_log`` and commit it.

    ``create_time`` is filled in by the database via ``NOW()``.

    :param mysql_pool: MySQL connection pool
    :param person_id: value stored in the ``person_id`` column
    :param prompt: value stored in the ``user_input`` column
    :param result: value stored in the ``model_response`` column
    :param audio_url: value stored in the ``audio_url`` column
    :param duration: value stored in the ``duration`` column
    """
    insert_sql = "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())"
    row_values = (person_id, prompt, result, audio_url, duration)
    async with mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(insert_sql, row_values)
        await conn.commit()
# Empty the table
async def truncate_chat_log(mysql_pool):
    """
    Remove every row from ``t_chat_log`` with ``TRUNCATE TABLE`` and log it.

    :param mysql_pool: MySQL connection pool
    """
    truncate_sql = "TRUNCATE TABLE t_chat_log"
    async with mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(truncate_sql)
        await conn.commit()
        logger.info("表 t_chat_log 已清空。")
from aiomysql import DictCursor
# Paginated query of chat records
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """
    Return one page of the chat log for ``person_id``.

    Rows are selected in descending ``id`` order, so page 1 holds the most
    recent records. The rows of the selected page are then flipped so that,
    within the returned page, the newest message comes last.

    :param mysql_pool: MySQL connection pool
    :param person_id: value matched against the ``person_id`` column
    :param page: 1-based page number (default 1)
    :param page_size: number of records per page (default 10)
    :return: dict with ``data`` (list of row dicts in ascending ``id`` order),
             ``total``, ``page``, ``page_size`` and ``total_pages``
    :raises ValueError: if the pool is not initialized, or if ``page`` or
                        ``page_size`` is smaller than 1
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    # Reject bad paging arguments up front: page < 1 would yield a negative
    # OFFSET and page_size == 0 would divide by zero in the page count below.
    if page < 1 or page_size < 1:
        raise ValueError(
            f"page and page_size must be >= 1 (got page={page}, page_size={page_size})"
        )

    offset = (page - 1) * page_size

    async with mysql_pool.acquire() as conn:
        async with conn.cursor(DictCursor) as cur:  # rows come back as dicts
            # Total number of records for this person
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
                (person_id,)
            )
            total = (await cur.fetchone())['COUNT(*)']

            # Ceiling division: number of pages needed for `total` rows
            total_pages = (total + page_size - 1) // page_size

            # Requested page, newest rows first
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, offset)
            )
            records = await cur.fetchall()

    # Iterate in reverse instead of calling records.reverse(): the driver may
    # return an immutable tuple (e.g. for an empty result set), on which an
    # in-place reverse would raise AttributeError.
    result = [
        {
            "id": record['id'],
            "person_id": record['person_id'],
            "user_input": record['user_input'],
            "model_response": record['model_response'],
            "audio_url": record['audio_url'],
            "duration": record['duration'],
            "create_time": record['create_time'].strftime("%Y-%m-%d %H:%M:%S")
        }
        for record in reversed(records)
    ]

    return {
        "data": result,  # ascending id order within the page
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages
    }
# Get the id of the last record for the given person
async def get_last_chat_log_id(mysql_pool, person_id):
    """
    Look up the highest ``id`` in ``t_chat_log`` for the given ``person_id``.

    :param mysql_pool: MySQL connection pool
    :param person_id: value matched against the ``person_id`` column
    :return: id of the newest matching row, or None if no row matches
    """
    latest_id_sql = "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1"
    async with mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(latest_id_sql, (person_id,))
        row = await cur.fetchone()
        if not row:
            return None
        return row[0]
# Mark a record as risky
async def update_risk(mysql_pool, person_id, risk_memo):
    """
    Flag the newest chat-log row of ``person_id`` as risky and store a memo.

    Sets ``risk_flag = 1`` and ``risk_memo`` on the row returned by
    ``get_last_chat_log_id``. If the person has no rows, only a warning is
    logged and nothing is written.

    :param mysql_pool: MySQL connection pool
    :param person_id: person whose latest record is to be flagged
    :param risk_memo: memo text; newlines and every literal ``"NO"``
                      substring are removed before it is stored
    """
    # 1. Resolve the target row BEFORE acquiring a connection here. The helper
    #    acquires its own pooled connection; calling it while already holding
    #    one makes every call need two connections at once, which can deadlock
    #    when the pool is saturated.
    last_id = await get_last_chat_log_id(mysql_pool, person_id)
    # Compare against None explicitly so that an id of 0 is not mistaken for
    # "no record".
    if last_id is None:
        logger.warning(f"未找到 person_id={person_id} 的记录。")
        return

    # NOTE(review): this removes "NO" anywhere in the memo, including inside
    # longer words -- confirm that is intended.
    cleaned_memo = risk_memo.replace('\n', '').replace("NO", "")

    # 2. Update risk_flag and risk_memo
    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                (cleaned_memo, last_id)
            )
            await conn.commit()
    logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
# Query chat records by risk flag
async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
    """
    Return one page of chat-log rows with the given ``risk_flag``, newest first.

    :param mysql_pool: MySQL connection pool
    :param risk_flag: value matched against the ``risk_flag`` column
    :param page: 1-based page number (default 1)
    :param page_size: number of records per page (default 10)
    :return: dict with ``data`` (list of row dicts in descending ``id`` order),
             ``total``, ``page``, ``page_size`` and ``total_pages``
    :raises ValueError: if ``page`` is smaller than 1 or ``page_size`` is negative
    """
    # Reject arguments that would produce a negative OFFSET or LIMIT before
    # touching the database.
    if page < 1 or page_size < 0:
        raise ValueError(
            f"page must be >= 1 and page_size must be >= 0 (got page={page}, page_size={page_size})"
        )

    offset = (page - 1) * page_size
    query = (
        "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo "
        "FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s"
    )
    params = (risk_flag, page_size, offset)

    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            # Total number of matching records
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE risk_flag = %s", (risk_flag,)
            )
            total = (await cur.fetchone())[0]
            logger.info(f"总记录数: {total}")

            # Requested page
            logger.debug("执行查询: %s", query % params)
            await cur.execute(query, params)
            records = await cur.fetchall()
            # Lazy logging args: the result set is only rendered to text when
            # DEBUG logging is actually enabled.
            logger.debug("查询结果: %s", records)

    # Ceiling division; a page_size of 0 (still accepted) yields 0 pages.
    total_pages = (total + page_size - 1) // page_size if page_size else 0

    # Positional columns follow the SELECT list above.
    result = [
        {
            "id": record[0],
            "person_id": record[1],
            "user_input": record[2],
            "model_response": record[3],
            "audio_url": record[4],
            "duration": record[5],
            "create_time": record[6].strftime("%Y-%m-%d %H:%M:%S"),
            "risk_memo": record[7]
        }
        for record in records
    ]

    return {
        "data": result,
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages
    }
# Query user information
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """
    Fetch the ``t_base_person`` row whose ``login_name`` matches.

    :param mysql_pool: MySQL connection pool
    :param login_name: login name to look up
    :return: the first matching row as a column-name -> value dict,
             or None if no row matches
    """
    lookup_sql = "SELECT * FROM t_base_person WHERE login_name = %s"
    async with mysql_pool.acquire() as conn, conn.cursor() as cursor:
        await cursor.execute(lookup_sql, (login_name,))
        row = await cursor.fetchone()
        if row:
            # cursor.description has one entry per result column; element 0
            # of each entry is the column name.
            return {meta[0]: value for meta, value in zip(cursor.description, row)}
        return None