You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

95 lines
3.2 KiB

import asyncio
import logging
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
# Logging setup: configure the root logger for INFO and above with
# "<asctime> - <LEVEL> - <message>" lines (logging.basicConfig is a no-op
# if the root logger already has handlers), then grab this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Keyword arguments for aiomysql.create_pool. The MYSQL_* values are
# presumably supplied by the star import of the project config module.
MYSQL_CONFIG = dict(
    host=MYSQL_HOST,
    port=MYSQL_PORT,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    db=MYSQL_DB_NAME,
    minsize=1,   # minimum pool size
    maxsize=20,  # maximum pool size
)
# Build the MySQL connection pool used by the helpers below.
async def init_mysql_pool():
    """Create an aiomysql connection pool from ``MYSQL_CONFIG`` and return it."""
    pool = await create_pool(**MYSQL_CONFIG)
    return pool
# Persist one chat exchange to MySQL.
async def save_chat_to_mysql(mysql_pool, session_id, prompt, result, audio_url, duration):
    """Insert one chat exchange into ``t_chat_log`` and commit it.

    ``create_time`` is filled server-side with ``NOW()``.

    :param mysql_pool: connection pool providing ``acquire()`` (e.g. the one
        returned by ``init_mysql_pool``).
    :param session_id: session the exchange belongs to.
    :param prompt: user input, stored in the ``user_input`` column.
    :param result: model reply, stored in the ``model_response`` column.
    :param audio_url: value stored in the ``audio_url`` column.
    :param duration: value stored in the ``duration`` column.
    """
    insert_sql = "INSERT INTO t_chat_log (session_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())"
    row_values = (session_id, prompt, result, audio_url, duration)
    async with mysql_pool.acquire() as connection, connection.cursor() as cursor:
        await cursor.execute(insert_sql, row_values)
        await connection.commit()
    logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
# Wipe the chat log table.
async def truncate_chat_log(mysql_pool):
    """Remove every row from ``t_chat_log`` with ``TRUNCATE TABLE``.

    Destructive: records for all sessions are deleted.

    :param mysql_pool: connection pool providing ``acquire()``.
    """
    statement = "TRUNCATE TABLE t_chat_log"
    async with mysql_pool.acquire() as connection, connection.cursor() as cursor:
        await cursor.execute(statement)
        # NOTE: MySQL's TRUNCATE commits implicitly; the explicit commit is
        # kept for symmetry with the other write helpers.
        await connection.commit()
    logger.info("表 t_chat_log 已清空。")
# Paginated lookup of chat records for one session.
def _format_chat_record(record):
    """Convert one ``t_chat_log`` row tuple into a dict.

    Tuple positions follow the SELECT list in ``get_chat_log_by_session``:
    id, session_id, user_input, model_response, audio_url, duration,
    create_time. ``create_time`` is rendered as ``"%Y-%m-%d %H:%M:%S"``; a
    NULL value (``None``) is passed through as ``None`` instead of raising
    ``AttributeError`` on ``None.strftime``.
    """
    create_time = record[6]
    return {
        "id": record[0],
        "session_id": record[1],
        "user_input": record[2],
        "model_response": record[3],
        "audio_url": record[4],
        "duration": record[5],
        "create_time": (
            create_time.strftime("%Y-%m-%d %H:%M:%S")
            if create_time is not None
            else None
        ),
    }


async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
    """Fetch one page of chat records for ``session_id``, highest id first.

    :param mysql_pool: connection pool providing ``acquire()``.
    :param session_id: user session whose records are queried.
    :param page: 1-based page number.
    :param page_size: number of records per page (0 returns no rows but
        still reports ``total``).
    :return: dict with ``data`` (list of record dicts), ``total`` (record
        count for the session), ``page`` and ``page_size``.
    :raises ValueError: if the pool is missing, ``page`` is less than 1, or
        ``page_size`` is negative.
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    # Validate before touching the database: page < 1 yields a negative
    # OFFSET and a negative page_size a negative LIMIT, both of which MySQL
    # rejects with a syntax error that hides the real cause.
    if page < 1 or page_size < 0:
        raise ValueError(f"分页参数无效: page={page}, page_size={page_size}")
    offset = (page - 1) * page_size

    async with mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            # Total number of records for this session (for the pager).
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE session_id = %s",
                (session_id,),
            )
            total = (await cur.fetchone())[0]
            # The requested page, newest (highest id) first.
            await cur.execute(
                "SELECT id, session_id, user_input, model_response, audio_url, duration, create_time "
                "FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (session_id, page_size, offset),
            )
            records = await cur.fetchall()

    return {
        "data": [_format_chat_record(record) for record in records],
        "total": total,
        "page": page,
        "page_size": page_size,
    }