import logging
from typing import Dict, List, Optional, Tuple

from aiomysql import DictCursor, create_pool

# Star import is expected to supply MYSQL_HOST, MYSQL_PORT, MYSQL_USER,
# MYSQL_PASSWORD and MYSQL_DB_NAME, which are used below.
from WxMini.Milvus.Config.MulvusConfig import *

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# MySQL connection-pool settings passed verbatim to aiomysql.create_pool().
MYSQL_CONFIG = {
    "host": MYSQL_HOST,
    "port": MYSQL_PORT,
    "user": MYSQL_USER,
    "password": MYSQL_PASSWORD,
    "db": MYSQL_DB_NAME,
    "minsize": 1,
    "maxsize": 20,
}

# Format used when serialising create_time values for callers.
_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"


def _format_time(value):
    """Format a datetime with _DATETIME_FORMAT; falsy values (e.g. NULL -> None) pass through unchanged."""
    return value.strftime(_DATETIME_FORMAT) if value else value


async def init_mysql_pool():
    """Create and return an aiomysql connection pool configured from MYSQL_CONFIG."""
    return await create_pool(**MYSQL_CONFIG)


async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1,
                             output_type=1, input_image_type=0, image_width=0, image_height=0):
    """
    Insert one chat exchange into t_chat_log and commit it.

    create_time is filled in by the database with NOW().

    :param mysql_pool: aiomysql connection pool
    :param person_id: ID of the user the exchange belongs to
    :param prompt: user input, stored in user_input
    :param result: model output, stored in model_response
    :param audio_url: stored in audio_url
    :param duration: stored in duration
    :param input_type: stored in input_type (default 1)
    :param output_type: stored in output_type (default 1)
    :param input_image_type: stored in input_image_type (default 0)
    :param image_width: stored in image_width (default 0)
    :param image_height: stored in image_height (default 0)
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cur:
            await cur.execute(
                "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,input_type,output_type,input_image_type,image_width,image_height,create_time) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,NOW())",
                (person_id, prompt, result, audio_url, duration, input_type, output_type,
                 input_image_type, image_width, image_height)
            )
        await conn.commit()


async def truncate_chat_log(mysql_pool):
    """
    Remove every row from t_chat_log with TRUNCATE TABLE.

    :param mysql_pool: aiomysql connection pool
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cur:
            await cur.execute("TRUNCATE TABLE t_chat_log")
        await conn.commit()
    logger.info("表 t_chat_log 已清空。")


async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """
    Fetch one page of a user's chat log.

    Rows are paged in descending id order (page 1 holds the newest rows). Each
    page is then reversed so its records come back in ascending id order, with
    the newest message last.

    :param mysql_pool: aiomysql connection pool
    :param person_id: ID of the user whose chat log is read
    :param page: 1-based page number (default 1, i.e. the newest page)
    :param page_size: number of records per page; must be positive
    :return: dict with keys data, total, page, page_size, total_pages
    :raises ValueError: if the pool is missing, page < 1 or page_size < 1
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    # Reject values that would otherwise surface as ZeroDivisionError
    # (page_size == 0) or as a MySQL syntax error (negative OFFSET).
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")
    if page < 1:
        raise ValueError(f"page must be >= 1, got {page}")

    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor(DictCursor) as cur:  # rows come back as dicts keyed by column name
            # Total number of records for this user
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
                (person_id,)
            )
            total = (await cur.fetchone())['COUNT(*)']

            # Ceiling division: number of pages needed to hold `total` rows
            total_pages = (total + page_size - 1) // page_size
            offset = (page - 1) * page_size

            # Fetch the requested page, newest first
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type,image_width,image_height, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, offset)
            )
            records = await cur.fetchall()

    # Reverse the page so the most recent message is shown last.
    result = [
        {
            "id": record['id'],
            "person_id": record['person_id'],
            "user_input": record['user_input'],
            "model_response": record['model_response'],
            "audio_url": record['audio_url'],
            "duration": record['duration'],
            "input_type": record['input_type'],
            "output_type": record['output_type'],
            "image_width": record['image_width'],
            "image_height": record['image_height'],
            "input_image_type": record['input_image_type'],
            # Guard against NULL create_time, as get_chat_logs_by_risk_flag does
            "create_time": _format_time(record['create_time']),
        }
        for record in reversed(records)
    ]

    return {
        "data": result,  # ascending id order within the page
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages
    }


async def _fetch_last_chat_log_id(cur, person_id):
    """
    Return the highest t_chat_log.id for person_id using an already-open tuple cursor.

    :param cur: open (non-dict) aiomysql cursor
    :param person_id: ID of the user
    :return: the id, or None if the user has no records
    """
    await cur.execute(
        "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
        (person_id,)
    )
    row = await cur.fetchone()
    return row[0] if row else None


async def get_last_chat_log_id(mysql_pool, person_id):
    """
    Return the id of the most recent t_chat_log record for a user.

    :param mysql_pool: aiomysql connection pool
    :param person_id: ID of the user
    :return: the highest id for that user, or None if no record exists
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cur:
            return await _fetch_last_chat_log_id(cur, person_id)


async def update_risk(mysql_pool, person_id, risk_memo):
    """
    Flag a user's most recent chat record as risky and store a memo on it.

    Sets risk_flag = 1 and risk_memo on the record with the highest id for
    person_id. If the user has no records, only a warning is logged.

    :param mysql_pool: aiomysql connection pool
    :param person_id: ID of the user
    :param risk_memo: memo text; newlines and every "NO" substring are stripped before storing
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cur:
            # 1. Find this user's latest record on the SAME connection. Calling
            #    get_last_chat_log_id() here would check out a second pooled
            #    connection while this one is still held, which can exhaust the pool.
            last_id = await _fetch_last_chat_log_id(cur, person_id)
            if last_id:
                # NOTE(review): replace("NO", "") removes every "NO" substring
                # anywhere in the memo (e.g. "NOTE" -> "TE"), not only a leading
                # "NO" answer. Kept as-is; confirm this is intended.
                cleaned_memo = risk_memo.replace('\n', '').replace("NO", "")
                # 2. Update risk_flag and risk_memo
                await cur.execute(
                    "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                    (cleaned_memo, last_id)
                )
                await conn.commit()
                logger.info("已更新 person_id=%s 的最后一条记录 (id=%s) 的 risk_flag 和 risk_memo。", person_id, last_id)
            else:
                logger.warning("未找到 person_id=%s 的记录。", person_id)


async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """
    Look up a user in t_base_person by login name.

    :param mysql_pool: aiomysql connection pool
    :param login_name: login name to match exactly
    :return: the row as a {column: value} dict, or None if not found
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cursor:
            sql = "SELECT * FROM t_base_person WHERE login_name = %s"
            await cursor.execute(sql, (login_name,))
            row = await cursor.fetchone()
            if not row:
                return None
            # Zip column names from cursor.description with the tuple row
            columns = [column[0] for column in cursor.description]
            return dict(zip(columns, row))


async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> Tuple[List[Dict], int]:
    """
    Per-user summary of chat records with a given risk flag (statistics page).

    Returns one row per user: all t_base_person columns plus `cnt`, the
    number of matching chat records. Users are ordered by cnt descending.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: risk_flag value to filter on
    :param offset: number of users to skip
    :param page_size: maximum number of users to return
    :return: (list of user dicts, total number of distinct matching users)
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cursor:
            # NOTE(review): selecting tbp.* while grouping by tcl.person_id relies on
            # tbp.person_id being unique (so tbp.* is functionally dependent under
            # ONLY_FULL_GROUP_BY) — confirm against the schema.
            sql = """
                SELECT tbp.*, COUNT(*) AS cnt
                FROM t_chat_log AS tcl
                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
                WHERE tcl.risk_flag = %s
                GROUP BY tcl.person_id
                ORDER BY COUNT(*) DESC
                LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names before the next execute() replaces cursor.description
            columns = [column[0] for column in cursor.description]

            # Total number of distinct users matching the filter
            count_sql = """
                SELECT COUNT(DISTINCT tcl.person_id)
                FROM t_chat_log AS tcl
                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
                WHERE tcl.risk_flag = %s
            """
            await cursor.execute(count_sql, (risk_flag,))
            total = (await cursor.fetchone())[0]

    logs = [dict(zip(columns, row)) for row in rows]
    return logs, total


async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int,
                                     page_size: int) -> Tuple[List[Dict], int]:
    """
    Page through one user's chat records that carry a given risk flag.

    Records are ordered by id descending. create_time is formatted as
    "%Y-%m-%d %H:%M:%S" when present.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: risk_flag value to filter on
    :param person_id: ID of the user
    :param offset: number of records to skip
    :param page_size: maximum number of records to return
    :return: (list of record dicts, total number of matching records). total
             is reported even when the requested page itself is empty.
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # make sure the pooled connection is still alive before using it
        async with conn.cursor() as cursor:
            # Lower-case alias in ORDER BY: table aliases are case-sensitive on
            # Unix MySQL by default, so "TCL.ID" would not resolve to "tcl".
            sql = """
                SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
                       tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
                       tbp.person_id, tbp.login_name, tbp.person_name
                FROM t_chat_log AS tcl
                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
                WHERE tcl.risk_flag = %s and tcl.person_id=%s
                ORDER BY tcl.id DESC
                LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names before count_sql replaces cursor.description
            columns = [column[0] for column in cursor.description]

            # Total number of matching records
            count_sql = """
                SELECT COUNT(*)
                FROM t_chat_log AS tcl
                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
                WHERE tcl.risk_flag = %s and tcl.person_id=%s
            """
            await cursor.execute(count_sql, (risk_flag, person_id))
            total = (await cursor.fetchone())[0]

    logs = []
    for row in rows:
        log = dict(zip(columns, row))
        log["create_time"] = _format_time(log["create_time"])
        logs.append(log)
    # Always report the real total; previously an empty page (offset past the
    # end) returned total=0 even when matching records existed.
    return logs, total