You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

220 lines
8.0 KiB

4 months ago
import logging
4 months ago
from typing import Optional, Dict
4 months ago
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
# Logging setup: INFO level, timestamped "time - level - message" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Keyword arguments for aiomysql.create_pool. The MYSQL_* connection values
# are presumably supplied by the star-import of the MulvusConfig module
# (they are not defined in this file) — confirm there.
MYSQL_CONFIG = dict(
    host=MYSQL_HOST,
    port=MYSQL_PORT,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    db=MYSQL_DB_NAME,
    minsize=1,
    maxsize=20,
)
async def init_mysql_pool():
    """Create the shared MySQL connection pool.

    :return: the pool returned by ``aiomysql.create_pool`` for MYSQL_CONFIG
    """
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration):
    """Insert one chat exchange into ``t_chat_log`` and commit it.

    ``create_time`` is filled in by the database via ``NOW()``.

    :param mysql_pool: aiomysql connection pool
    :param person_id: id of the user who sent the message
    :param prompt: the user's input, stored in ``user_input``
    :param result: the model's reply, stored in ``model_response``
    :param audio_url: URL stored in ``audio_url``
    :param duration: value stored in ``duration``
    """
    insert_sql = (
        "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())"
    )
    values = (person_id, prompt, result, audio_url, duration)
    async with mysql_pool.acquire() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute(insert_sql, values)
            await connection.commit()
async def truncate_chat_log(mysql_pool):
    """Remove every row from ``t_chat_log`` with ``TRUNCATE TABLE``, then log it.

    :param mysql_pool: aiomysql connection pool
    """
    async with mysql_pool.acquire() as connection:
        async with connection.cursor() as cursor:
            await cursor.execute("TRUNCATE TABLE t_chat_log")
            await connection.commit()
    logger.info("表 t_chat_log 已清空。")
from aiomysql import DictCursor
4 months ago
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """Return one page of a person's chat log.

    Pages are counted from the newest record. Rows are selected in descending
    ``id`` order with ``LIMIT``/``OFFSET``, so ``page=1`` holds the latest
    ``page_size`` records. The page is then returned in ascending ``id``
    order, so the newest message of the page comes last.

    :param mysql_pool: aiomysql connection pool; must not be falsy
    :param person_id: id of the user whose chat log is queried
    :param page: 1-based page number, counted from the newest records
    :param page_size: number of records per page
    :return: dict with ``data`` (list of record dicts, ascending ``id``),
        ``total`` (number of matching rows), ``page``, ``page_size`` and
        ``total_pages`` (``ceil(total / page_size)``)
    :raises ValueError: if the pool is not initialised, or ``page`` /
        ``page_size`` is smaller than 1
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    # page < 1 would produce a negative OFFSET (SQL error) and page_size < 1 a
    # ZeroDivisionError in total_pages; reject both up front.
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")

    async with mysql_pool.acquire() as conn:
        async with conn.cursor(DictCursor) as cur:
            await cur.execute(
                "SELECT COUNT(*) AS total FROM t_chat_log WHERE person_id = %s",
                (person_id,),
            )
            total = (await cur.fetchone())["total"]

            # Newest first, so the page window is anchored at the latest record.
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, (page - 1) * page_size),
            )
            records = await cur.fetchall()

    # fetchall() is not guaranteed to return a list (aiomysql can hand back an
    # empty tuple when nothing matched), so iterate reversed() rather than
    # calling list.reverse() in place. Reversing puts the newest message last.
    data = [
        {
            "id": record["id"],
            "person_id": record["person_id"],
            "user_input": record["user_input"],
            "model_response": record["model_response"],
            "audio_url": record["audio_url"],
            "duration": record["duration"],
            "create_time": (
                record["create_time"].strftime("%Y-%m-%d %H:%M:%S")
                if record["create_time"] is not None
                else None
            ),
        }
        for record in reversed(records)
    ]
    return {
        "data": data,
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size,
    }
async def get_last_chat_log_id(mysql_pool, person_id):
    """Return the highest ``t_chat_log.id`` recorded for ``person_id``.

    :param mysql_pool: aiomysql connection pool
    :param person_id: id of the user whose latest record id is wanted
    :return: the largest ``id`` among that user's rows, or ``None`` if the
        user has no rows
    """
    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
                (person_id,),
            )
            result = await cur.fetchone()
            return result[0] if result else None
async def update_risk(mysql_pool, person_id, risk_memo):
    """Flag a person's latest chat record as risky.

    Finds the row with the highest ``id`` for ``person_id`` in ``t_chat_log``
    and sets ``risk_flag = 1`` and ``risk_memo`` on it. Before storing,
    newlines and every occurrence of the substring ``"NO"`` are removed from
    ``risk_memo``. If the person has no rows, a warning is logged and nothing
    is changed.

    :param mysql_pool: aiomysql connection pool
    :param person_id: id of the user whose latest record is flagged
    :param risk_memo: risk description text (must support ``str.replace``
        when a record exists)
    """
    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            # Look up the latest id on this same connection. Delegating to
            # get_last_chat_log_id would check out a second pooled connection
            # while this one is still held, which wastes a slot and can
            # deadlock once the pool is exhausted.
            await cur.execute(
                "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
                (person_id,),
            )
            row = await cur.fetchone()
            last_id = row[0] if row else None

            if last_id is None:
                logger.warning("未找到 person_id=%s 的记录。", person_id)
                return

            # NOTE(review): removing every "NO" presumably strips a verdict
            # token from model output — TODO confirm; it also removes "NO"
            # inside other uppercase words.
            cleaned_memo = risk_memo.replace("\n", "").replace("NO", "")
            await cur.execute(
                "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                (cleaned_memo, last_id),
            )
            await conn.commit()

    logger.info(
        "已更新 person_id=%s 的最后一条记录 (id=%s) 的 risk_flag 和 risk_memo。",
        person_id,
        last_id,
    )
async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
    """Return one page of chat records with the given ``risk_flag``, newest first.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: value matched against ``t_chat_log.risk_flag``
        (``update_risk`` in this module sets it to ``1``)
    :param page: 1-based page number
    :param page_size: number of records per page
    :return: dict with ``data`` (list of record dicts in descending ``id``
        order), ``total``, ``page``, ``page_size`` and ``total_pages``
        (``ceil(total / page_size)``, matching ``get_chat_log_by_session``)
    :raises ValueError: if ``page`` or ``page_size`` is smaller than 1
    """
    # page < 1 would produce a negative OFFSET; page_size < 1 breaks paging math.
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")

    offset = (page - 1) * page_size
    query = (
        "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo "
        "FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s"
    )
    params = (risk_flag, page_size, offset)

    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE risk_flag = %s", (risk_flag,)
            )
            total = (await cur.fetchone())[0]
            logger.info("总记录数: %s", total)

            # Build the interpolated SQL only when DEBUG is on. It is a naive
            # %-substitution for display, not the statement the driver sends.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("执行查询: %s", query % params)
            await cur.execute(query, params)
            records = await cur.fetchall()
            # Lazy %-args: the (possibly large) result set is only formatted
            # when DEBUG is enabled.
            logger.debug("查询结果: %s", records)

    # Rows are tuples in SELECT column order.
    result = [
        {
            "id": record[0],
            "person_id": record[1],
            "user_input": record[2],
            "model_response": record[3],
            "audio_url": record[4],
            "duration": record[5],
            "create_time": (
                record[6].strftime("%Y-%m-%d %H:%M:%S") if record[6] is not None else None
            ),
            "risk_memo": record[7],
        }
        for record in records
    ]
    return {
        "data": result,
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": (total + page_size - 1) // page_size,
    }
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """Look up a row of ``t_base_person`` by ``login_name``.

    :param mysql_pool: aiomysql connection pool
    :param login_name: login name to match exactly
    :return: the first matching row as a ``{column_name: value}`` dict built
        from ``cursor.description``, or ``None`` if no row matches
    """
    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cursor:
            # NOTE(review): SELECT * returns every column of t_base_person; if
            # it holds credentials, callers must not forward this dict as-is.
            sql = "SELECT * FROM t_base_person WHERE login_name = %s"
            await cursor.execute(sql, (login_name,))
            row = await cursor.fetchone()
            if not row:
                return None
            # Pair each column name (first field of a description entry) with
            # its value from the tuple row.
            columns = [column[0] for column in cursor.description]
            return dict(zip(columns, row))