main
HuangHai 4 months ago
parent a5a141bd25
commit 9fe67cfa52

@ -83,7 +83,7 @@ async def on_session_end(person_id):
input_word = file.read()
prompt = (
"分析用户是否存在心理健康方面的问题:"
f"参考分类文档内容如下:{input_word},"
f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
"如果没有健康问题请回复: OK否则回复NO换行后输出问题类型的名称"
f"\n\n聊天记录:{history}"
)
@ -268,7 +268,7 @@ async def reply(person_id: str = Form(...),
# 查询非向量字段
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
if record:
logger.info(f"查询到的记录: {record}")
#logger.info(f"查询到的记录: {record}")
# 添加历史交互
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
@ -276,7 +276,7 @@ async def reply(person_id: str = Form(...),
# 限制历史交互提示词长度
history_prompt = history_prompt[:2000]
logger.info(f"历史交互提示词: {history_prompt}")
#logger.info(f"历史交互提示词: {history_prompt}")
# 调用大模型,将历史交互作为提示词
try:
@ -315,7 +315,7 @@ async def reply(person_id: str = Form(...),
if len(result) > 500:
logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
await asyncio.to_thread(collection_manager.insert_data, entities)
logger.info("用户输入和大模型反馈已记录到向量数据库。")
#logger.info("用户输入和大模型反馈已记录到向量数据库。")
# 调用 TTS 生成 MP3
uuid_str = str(uuid.uuid4())
@ -328,7 +328,7 @@ async def reply(person_id: str = Form(...),
t = TTS(None) # 传入 None 表示不保存到本地文件
audio_data, duration = await asyncio.to_thread(t.generate_audio,
result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
print(f"音频时长: {duration}")
#print(f"音频时长: {duration} 秒")
# 将音频数据直接上传到 OSS
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
@ -339,7 +339,7 @@ async def reply(person_id: str = Form(...),
# 记录聊天数据到 MySQL
await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
#logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
# 调用会话检查机制
await on_session_end(person_id)
@ -368,7 +368,7 @@ async def reply(person_id: str = Form(...),
@app.get("/aichat/get_chat_log")
async def get_chat_log(
person_id: str,
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"),
page: int = Query(default=1, ge=1, description="当前页码"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
current_user: dict = Depends(get_current_user)
):
@ -387,7 +387,7 @@ async def get_chat_log(
# 获取风险聊天记录接口
@app.get("/aichat/get_risk_chat_logs")
async def get_risk_chat_logs(
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险"),
risk_flag: int = Query(..., description="风险标志1 表示有风险0 表示无风险 ,2:处理完毕"),
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10最大值为 100"),
current_user: dict = Depends(get_current_user)

@ -78,7 +78,7 @@ async def get_chat_log_by_session(mysql_pool, current_user, person_id, page=1, p
# 查询分页数据,按 id 降序排列
await cur.execute(
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time "
"SELECT SQL_NO_CACHE id, person_id, user_input, model_response, audio_url, duration, create_time "
"FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
(person_id, page_size, offset)
)
@ -184,7 +184,7 @@ async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, offset: int, pa
async with conn.cursor() as cursor:
# 查询符合条件的记录
sql = """
SELECT
SELECT SQL_NO_CACHE
tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
tbp.person_id, tbp.login_name, tbp.person_name

Loading…
Cancel
Save