|
|
|
@ -83,7 +83,7 @@ async def on_session_end(person_id):
|
|
|
|
|
input_word = file.read()
|
|
|
|
|
prompt = (
|
|
|
|
|
"分析用户是否存在心理健康方面的问题:"
|
|
|
|
|
f"参考分类文档内容如下:{input_word},"
|
|
|
|
|
f"参考分类文档内容如下:{input_word},注意:只有情节比较严重的才认为有健康问题,轻微的不算。"
|
|
|
|
|
"如果没有健康问题请回复: OK;否则回复:NO,换行后输出问题类型的名称"
|
|
|
|
|
f"\n\n聊天记录:{history}"
|
|
|
|
|
)
|
|
|
|
@ -268,7 +268,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
# 查询非向量字段
|
|
|
|
|
record = await asyncio.to_thread(collection_manager.query_by_id, hit.id)
|
|
|
|
|
if record:
|
|
|
|
|
logger.info(f"查询到的记录: {record}")
|
|
|
|
|
#logger.info(f"查询到的记录: {record}")
|
|
|
|
|
# 添加历史交互
|
|
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
|
|
except Exception as e:
|
|
|
|
@ -276,7 +276,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
|
|
|
|
|
# 限制历史交互提示词长度
|
|
|
|
|
history_prompt = history_prompt[:2000]
|
|
|
|
|
logger.info(f"历史交互提示词: {history_prompt}")
|
|
|
|
|
#logger.info(f"历史交互提示词: {history_prompt}")
|
|
|
|
|
|
|
|
|
|
# 调用大模型,将历史交互作为提示词
|
|
|
|
|
try:
|
|
|
|
@ -315,7 +315,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
if len(result) > 500:
|
|
|
|
|
logger.warning(f"大模型回复被截断,原始长度: {len(result)}")
|
|
|
|
|
await asyncio.to_thread(collection_manager.insert_data, entities)
|
|
|
|
|
logger.info("用户输入和大模型反馈已记录到向量数据库。")
|
|
|
|
|
#logger.info("用户输入和大模型反馈已记录到向量数据库。")
|
|
|
|
|
|
|
|
|
|
# 调用 TTS 生成 MP3
|
|
|
|
|
uuid_str = str(uuid.uuid4())
|
|
|
|
@ -328,7 +328,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
t = TTS(None) # 传入 None 表示不保存到本地文件
|
|
|
|
|
audio_data, duration = await asyncio.to_thread(t.generate_audio,
|
|
|
|
|
result) # 假设 TTS 类有一个 generate_audio 方法返回音频数据
|
|
|
|
|
print(f"音频时长: {duration} 秒")
|
|
|
|
|
#print(f"音频时长: {duration} 秒")
|
|
|
|
|
|
|
|
|
|
# 将音频数据直接上传到 OSS
|
|
|
|
|
await asyncio.to_thread(upload_mp3_to_oss_from_memory, tts_file, audio_data)
|
|
|
|
@ -339,7 +339,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
|
|
|
|
|
# 记录聊天数据到 MySQL
|
|
|
|
|
await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
|
|
|
|
|
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
|
|
|
|
|
#logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
|
|
|
|
|
|
|
|
|
|
# 调用会话检查机制
|
|
|
|
|
await on_session_end(person_id)
|
|
|
|
@ -368,7 +368,7 @@ async def reply(person_id: str = Form(...),
|
|
|
|
|
@app.get("/aichat/get_chat_log")
|
|
|
|
|
async def get_chat_log(
|
|
|
|
|
person_id: str,
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1,但会动态计算为最后一页)"),
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数"),
|
|
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
|
|
):
|
|
|
|
@ -387,7 +387,7 @@ async def get_chat_log(
|
|
|
|
|
# 获取风险聊天记录接口
|
|
|
|
|
@app.get("/aichat/get_risk_chat_logs")
|
|
|
|
|
async def get_risk_chat_logs(
|
|
|
|
|
risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险)"),
|
|
|
|
|
risk_flag: int = Query(..., description="风险标志(1 表示有风险,0 表示无风险 ,2:处理完毕)"),
|
|
|
|
|
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1)"),
|
|
|
|
|
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数(默认值为 10,最大值为 100)"),
|
|
|
|
|
current_user: dict = Depends(get_current_user)
|
|
|
|
|