# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 129 lines
# 4.4 KiB

import logging
import os

from aiomysql import create_pool
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)
# Global MySQL connection pool. It is None until init_mysql_pool() assigns it.
mysql_pool = None

# MySQL connection settings, unpacked as keyword arguments into
# aiomysql.create_pool() by init_mysql_pool().
# Each connection value can be overridden with an environment variable, so a
# deployment does not have to edit this file. The fallbacks are the values
# that used to be hard-coded here, so existing deployments keep working.
# NOTE(review): a real credential is committed below as the fallback password.
# Rotate it, and drop the fallback once every environment sets MYSQL_PASSWORD.
MYSQL_CONFIG = {
    "host": os.environ.get("MYSQL_HOST", "10.10.14.203"),
    "port": int(os.environ.get("MYSQL_PORT", "3306")),
    "user": os.environ.get("MYSQL_USER", "root"),
    "password": os.environ.get("MYSQL_PASSWORD", "Password123@mysql"),
    "db": os.environ.get("MYSQL_DB", "xiaozhi_esp32_server"),
    "minsize": 1,
    "maxsize": 20,
    # Writes are committed automatically; the query helpers never call commit().
    "autocommit": True,
    # utf8mb4 so that 4-byte characters (e.g. emoji) round-trip correctly.
    "charset": "utf8mb4",
}
async def init_mysql_pool():
    """Create the MySQL connection pool from MYSQL_CONFIG.

    The new pool is stored in the module-global ``mysql_pool``. If creation
    fails, the error is logged and the original exception is re-raised.
    """
    global mysql_pool
    try:
        pool = await create_pool(**MYSQL_CONFIG)
        mysql_pool = pool
        logger.info("MySQL连接池创建成功")
    except Exception as exc:
        logger.error(f"创建MySQL连接池失败: {str(exc)}")
        raise
async def close_mysql_pool():
    """Close the module-global MySQL pool, if one exists.

    Calls ``close()`` and then awaits ``wait_closed()`` on the pool. The global
    ``mysql_pool`` is reset to None, so calling this again, or calling it when
    no pool was ever created, is a no-op.
    """
    global mysql_pool
    if mysql_pool is None:
        return
    # Detach the global before awaiting. A concurrent or repeated call then sees
    # None and returns, instead of closing the same pool a second time.
    # Clearing the global also avoids leaving a closed pool behind.
    pool, mysql_pool = mysql_pool, None
    pool.close()
    await pool.wait_closed()
    logger.info("MySQL连接池已关闭")
async def save_chat_history(chat_text, chat_wav, ai_text, person_id, person_name):
    """Insert one chat record into ``t_chat_history``.

    Args:
        chat_text: the user's chat text.
        chat_wav: value stored in the ``chat_wav`` column.
        ai_text: the AI reply text.
        person_id: value stored in the ``person_id`` column.
        person_name: value stored in the ``person_name`` column (also logged).

    Returns:
        ``lastrowid`` of the cursor after the INSERT (presumably the new row's
        auto-increment id).

    Raises:
        Any error from the pool or driver, after it has been logged.
    """
    insert_sql = """
INSERT INTO t_chat_history
(chat_text, chat_wav, ai_text, person_id, person_name)
VALUES (%s, %s, %s, %s, %s)
"""
    row = (chat_text, chat_wav, ai_text, person_id, person_name)
    try:
        async with mysql_pool.acquire() as connection:
            # Ping before use. This presumably guards against a stale pooled
            # connection; confirm against aiomysql's ping() reconnect default.
            await connection.ping()
            async with connection.cursor() as cursor:
                await cursor.execute(insert_sql, row)
                logger.info(f"成功保存聊天记录: {person_name}")
                return cursor.lastrowid
    except Exception as exc:
        logger.error(f"保存聊天记录失败: {str(exc)}")
        raise
async def update_chat_history(chat_id, chat_text=None, chat_wav=None, ai_text=None, person_name=None):
    """Update selected columns of one ``t_chat_history`` row.

    Only arguments that are not None are written. Passing None leaves that
    column unchanged, so this helper cannot set a column to NULL.

    Args:
        chat_id: ``id`` of the row to update.
        chat_text: new ``chat_text`` value, or None to leave it unchanged.
        chat_wav: new ``chat_wav`` value, or None to leave it unchanged.
        ai_text: new ``ai_text`` value, or None to leave it unchanged.
        person_name: new ``person_name`` value, or None to leave it unchanged.

    Returns:
        None. When no column values are supplied, a warning is logged and
        the database is not contacted at all.

    Raises:
        Any error from the pool or driver, after it has been logged.
    """
    # Column names come only from this fixed mapping, never from callers, so
    # interpolating them into the SQL text is safe. Values are always bound as
    # %s parameters. Dict order keeps the original column order in SET.
    candidates = {
        "chat_text": chat_text,
        "chat_wav": chat_wav,
        "ai_text": ai_text,
        "person_name": person_name,
    }
    changes = {column: value for column, value in candidates.items() if value is not None}
    # Check for a no-op before touching the pool. An empty update should not
    # check out a connection or ping the server.
    if not changes:
        logger.warning("没有提供要更新的字段")
        return
    set_clause = ", ".join(f"{column} = %s" for column in changes)
    sql = f"UPDATE t_chat_history SET {set_clause} WHERE id = %s"
    params = [*changes.values(), chat_id]
    try:
        async with mysql_pool.acquire() as conn:
            await conn.ping()
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                logger.info(f"成功更新聊天记录 ID: {chat_id}")
    except Exception as e:
        logger.error(f"更新聊天记录失败: {str(e)}")
        raise
async def delete_chat_history(chat_id):
    """Delete the ``t_chat_history`` row whose ``id`` equals *chat_id*.

    The success message is logged unconditionally after the DELETE runs, so
    it is also logged when no row matched. Errors are logged and re-raised.
    """
    statement = "DELETE FROM t_chat_history WHERE id = %s"
    try:
        async with mysql_pool.acquire() as connection:
            await connection.ping()
            async with connection.cursor() as cursor:
                await cursor.execute(statement, (chat_id,))
                logger.info(f"成功删除聊天记录 ID: {chat_id}")
    except Exception as exc:
        logger.error(f"删除聊天记录失败: {str(exc)}")
        raise
async def get_chat_history(person_id=None, limit=10):
    """Fetch up to *limit* chat records, ordered by ``id`` descending.

    Args:
        person_id: when not None, only rows with this ``person_id`` are
            returned. None (the default) returns rows for every person.
        limit: maximum number of rows to return.

    Returns:
        The rows as returned by the cursor's ``fetchall()``.

    Raises:
        Any error from the pool or driver, after it has been logged.
    """
    # Use `is not None` rather than truthiness. A falsy but valid id such as 0
    # must still filter, instead of silently returning every person's history.
    if person_id is not None:
        sql = "SELECT * FROM t_chat_history WHERE person_id = %s ORDER BY id DESC LIMIT %s"
        params = (person_id, limit)
    else:
        sql = "SELECT * FROM t_chat_history ORDER BY id DESC LIMIT %s"
        params = (limit,)
    try:
        async with mysql_pool.acquire() as conn:
            await conn.ping()
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                result = await cur.fetchall()
                return result
    except Exception as e:
        logger.error(f"获取聊天记录失败: {str(e)}")
        raise