main
HuangHai 4 months ago
parent 950576e1e9
commit 1fc442a241

@ -276,13 +276,17 @@ async def reply(person_id: str = Form(...),
logger.error(f"查询失败: {e}")
# 在最后增加此人最近几条的交互记录数据
recent_logs = get_chat_log_by_session(app.state.mysql_pool, person_id)
for record in recent_logs:
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
try:
recent_logs = await get_chat_log_by_session(app.state.mysql_pool, person_id)
data = recent_logs["data"]
for record in data:
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
logger.error(f"获取交互记录时出错:{e}")
# 限制历史交互提示词长度
history_prompt = history_prompt[:3000]
# logger.info(f"历史交互提示词: {history_prompt}")
logger.info(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词
try:
@ -321,7 +325,7 @@ async def reply(person_id: str = Form(...),
if len(result) > 500:
logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
await asyncio.to_thread(collection_manager.insert_data, entities)
# logger.info("用户输入和大模型反馈已记录到向量数据库。")
logger.info("用户输入和大模型反馈已记录到向量数据库。")
# 调用 TTS 生成 MP3
uuid_str = str(uuid.uuid4())

@ -76,12 +76,11 @@ async def recognize_text(client, pool, person_id, image_url):
# 记录到数据库
try:
await save_chat_to_mysql(pool, person_id, f'![]({image_url})', full_text, "", 0)
await save_chat_to_mysql(pool, person_id, f'![]({image_url})', full_text, "", 0, 2, 2, 1)
except Exception as e:
print(f"记录到数据库时出错:{e}")
async def recognize_content(client, pool, person_id, image_url):
"""
识别图片中的内容流式输出
@ -105,6 +104,6 @@ async def recognize_content(client, pool, person_id, image_url):
# 记录到数据库
try:
await save_chat_to_mysql(pool, person_id, f'![]({image_url})', full_text, "", 0)
await save_chat_to_mysql(pool, person_id, f'![]({image_url})', full_text, "", 0, 2, 2, 2)
except Exception as e:
print(f"记录到数据库时出错:{e}")

@ -26,13 +26,14 @@ async def init_mysql_pool():
# 保存聊天记录到 MySQL
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration):
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1, output_type=1,
input_image_type=0):
async with mysql_pool.acquire() as conn:
await conn.ping() # 重置连接
async with conn.cursor() as cur:
await cur.execute(
"INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())",
(person_id, prompt, result, audio_url, duration)
"INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,input_type,output_type,input_image_type,create_time) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,NOW())",
(person_id, prompt, result, audio_url, duration, input_type, output_type, input_image_type)
)
await conn.commit()
@ -81,7 +82,7 @@ async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
# 查询分页数据,按 id 降序排列
await cur.execute(
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time "
"SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type, create_time "
"FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
(person_id, page_size, offset)
)
@ -256,7 +257,7 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str,
INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
WHERE tcl.risk_flag = %s and tcl.person_id=%s
"""
await cursor.execute(count_sql, (risk_flag,person_id))
await cursor.execute(count_sql, (risk_flag, person_id))
total = (await cursor.fetchone())[0]
# 将元组转换为字典,并格式化 create_time

Loading…
Cancel
Save