main
HuangHai 4 months ago
parent a5a141bd25
commit 9fe67cfa52

@ -83,7 +83,7 @@ async def on_session_end(person_id):
input_word = file.read() input_word = file.read()
prompt = ( prompt = (
"分析用户是否存在心理健康方面的问题:" "分析用户是否存在心理健康方面的问题:"
f"参考分类文档内容如下:{input_word}," f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
"如果没有健康问题请回复: OK否则回复NO换行后输出问题类型的名称" "如果没有健康问题请回复: OK否则回复NO换行后输出问题类型的名称"
f"\n\n聊天记录:{history}" f"\n\n聊天记录:{history}"
) )
@ -268,7 +268,7 @@ async def reply(person_id: str = Form(...),
# 查询非向量字段 # 查询非向量字段
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id) record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
if record: if record:
logger.info(f"查询到的记录: {record}") #logger.info(f"查询到的记录: {record}")
# 添加历史交互 # 添加历史交互
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n" history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e: except Exception as e:
@ -276,7 +276,7 @@ async def reply(person_id: str = Form(...),
# 限制历史交互提示词长度 # 限制历史交互提示词长度
history_prompt = history_prompt[:2000] history_prompt = history_prompt[:2000]
logger.info(f"历史交互提示词: {history_prompt}") #logger.info(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词 # 调用大模型,将历史交互作为提示词
try: try:
@ -315,7 +315,7 @@ async def reply(person_id: str = Form(...),
if len(result) > 500: if len(result) > 500:
logger.warning(f"大模型回复被截断,原始长度: {len(result)}") logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
await asyncio.to_thread(collection_manager.insert_data, entities) await asyncio.to_thread(collection_manager.insert_data, entities)
logger.info("用户输入和大模型反馈已记录到向量数据库。") #logger.info("用户输入和大模型反馈已记录到向量数据库。")
# 调用 TTS 生成 MP3 # 调用 TTS 生成 MP3
uuid_str = str(uuid.uuid4()) uuid_str = str(uuid.uuid4())
@ -328,7 +328,7 @@ async def reply(person_id: str = Form(...),
t = TTS(None) # 传入 None 表示不保存到本地文件 t = TTS(None) # 传入 None 表示不保存到本地文件
audio_data, duration = await asyncio.to_thread(t.generate_audio, audio_data, duration = await asyncio.to_thread(t.generate_audio,
result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据 result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
print(f"音频时长: {duration}") #print(f"音频时长: {duration} 秒")
# 将音频数据直接上传到 OSS # 将音频数据直接上传到 OSS
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data) await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
@ -339,7 +339,7 @@ async def reply(person_id: str = Form(...),
# 记录聊天数据到 MySQL # 记录聊天数据到 MySQL
await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration) await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。") #logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
# 调用会话检查机制 # 调用会话检查机制
await on_session_end(person_id) await on_session_end(person_id)
@ -368,7 +368,7 @@ async def reply(person_id: str = Form(...),
@app.get("/aichat/get_chat_log") @app.get("/aichat/get_chat_log")
async def get_chat_log( async def get_chat_log(
person_id: str, person_id: str,
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"), page: int = Query(default=1, ge=1, description="当前页码"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
current_user: dict = Depends(get_current_user) current_user: dict = Depends(get_current_user)
): ):
@ -387,7 +387,7 @@ async def get_chat_log(
# 获取风险聊天记录接口 # 获取风险聊天记录接口
@app.get("/aichat/get_risk_chat_logs") @app.get("/aichat/get_risk_chat_logs")
async def get_risk_chat_logs( async def get_risk_chat_logs(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险"), risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"), page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
current_user: dict = Depends(get_current_user) current_user: dict = Depends(get_current_user)

@ -78,7 +78,7 @@ async def get_chat_log_by_session(mysql_pool, current_user, person_id, page=1, p
# 查询分页数据,按 id 降序排列 # 查询分页数据,按 id 降序排列
await cur.execute( await cur.execute(
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time " "SELECT SQL_NO_CACHE id, person_id, user_input, model_response, audio_url, duration, create_time "
"FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s", "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
(person_id, page_size, offset) (person_id, page_size, offset)
) )
@ -184,7 +184,7 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
# 查询符合条件的记录 # 查询符合条件的记录
sql = """ sql = """
SELECT SELECT SQL_NO_CACHE
tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration, tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result, tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
tbp.person_id, tbp.login_name, tbp.person_name tbp.person_id, tbp.login_name, tbp.person_name

Loading…
Cancel
Save