main
HuangHai 4 months ago
parent 9690580a76
commit 657ec149a8

@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
get_risk_chat_log_page
get_risk_chat_log_page, get_last_chat_log_id
from WxMini.Utils.EmbeddingUtil import text_to_embedding
# 配置日志
@ -47,13 +47,26 @@ async def lifespan(app: FastAPI):
# 会话结束后,调用检查方法,判断是不是有需要介入的问题出现
async def on_session_end(session_id):
# 获取聊天记录
result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page=1, page_size=1)
# 获取最后一条聊天记录
last_id = await get_last_chat_log_id(app.state.mysql_pool, session_id)
if last_id:
# 查询最后一条记录的详细信息
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"SELECT user_input, model_response FROM t_chat_log WHERE id = %s",
(last_id,)
)
last_record = await cur.fetchone()
if last_record:
# 拼接历史聊天记录
history = f"问题:{last_record['user_input']}\n回答:{last_record['model_response']}"
else:
history = "无聊天记录"
else:
history = "无聊天记录"
# 拼接历史聊天记录
history = ""
for row in result['data']:
history = f"{history}\n问题:{row['user_input']}\n回答:{row['model_response']}"
# 将历史聊天记录发给大模型,让它帮我分析一下
prompt = (
@ -225,27 +238,25 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
milvus_pool.release_connection(connection)
# 获取聊天记录
# 获取聊天记录
@app.get("/aichat/get_chat_log")
async def get_chat_log(
session_id: str,
page: int = Query(default=1, ge=1, description="当前页码"),
page: int = Query(default=None, ge=1, description="当前页码(如果为 None则默认跳转到最后一页"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
):
"""
根据 session_id 查询聊天记录并按 id 降序分页
获取指定会话的聊天记录默认返回最新的记录最后一页
:param session_id: 用户会话 ID
:param page: 当前页码
:param page: 当前页码如果为 None则默认跳转到最后一页
:param page_size: 每页记录数
:return: 分页数据
"""
try:
result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size)
return result
except Exception as e:
logger.error(f"查询聊天记录失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"查询聊天记录失败: {str(e)}")
# 调用 get_chat_log_by_session 方法
result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size)
return result
@app.get("/aichat/get_risk_page")
async def get_risk_page(

@ -1,67 +0,0 @@
import asyncio
from openai import AsyncOpenAI
from WxMini.Milvus.Config.MulvusConfig import *
from WxMini.Utils.MySQLUtil import init_mysql_pool, get_chat_log_by_session
# 初始化异步 OpenAI 客户端
client = AsyncOpenAI(
api_key=MODEL_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
async def main():
# 哪个人员
session_id = 1
# 哪一页
page = 1
# 一页多少个
page_size = 100
# 初始化 MySQL 连接池
mysql_pool = await init_mysql_pool()
# 调用
result = await get_chat_log_by_session(mysql_pool, session_id, page, page_size)
# 我只关心 user_input 与 model_response
# 把这些拼接出问题与回答
history = ""
for row in result['data']: # 注意result 是一个字典,包含 'data' 字段
user_input = row['user_input']
model_response = row['model_response']
history = f"{history}\n问题:{user_input}\n回答:{model_response}"
# 将历史聊天记录发给大模型,让它帮我分析一下
prompt = (
"我将把用户与AI大模型交流的记录发给你帮我分析一下这个用户是否存在心理健康方面的问题"
"参考1、PHQ-9抑郁症筛查量表和2、Beck自杀意念评量表BSI-CV"
"如果没有健康问题请回复: OK否则回复NO换行后再输出是什么问题。"
f"\n\n历史聊天记录:{history}"
)
# 调用大模型进行分析
try:
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": "你是一个心理健康分析助手,负责分析用户的心理健康状况。"},
{"role": "user", "content": prompt}
],
max_tokens=1000
)
# 提取大模型的回复
if response.choices and response.choices[0].message.content:
analysis_result = response.choices[0].message.content.strip()
print("大模型分析结果:")
print(analysis_result)
else:
print("大模型未返回有效结果。")
except Exception as e:
print(f"调用大模型失败: {str(e)}")
# 关闭连接池
mysql_pool.close()
await mysql_pool.wait_closed()
if __name__ == '__main__':
asyncio.run(main())

@ -44,18 +44,18 @@ async def truncate_chat_log(mysql_pool):
# 分页查询聊天记录
async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
async def get_chat_log_by_session(mysql_pool, session_id, page=None, page_size=10):
"""
根据 session_id 查询聊天记录并按 id 降序分页
根据 session_id 查询聊天记录并按 id 升序分页
:param mysql_pool: MySQL 连接池
:param session_id: 用户会话 ID
:param page: 当前页码
:param page: 当前页码如果为 None则默认跳转到最后一页
:param page_size: 每页记录数
:return: 分页数据
"""
if not mysql_pool:
raise ValueError("MySQL 连接池未初始化")
offset = (page - 1) * page_size
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
# 查询总记录数
@ -65,10 +65,20 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
)
total = (await cur.fetchone())[0]
# 查询分页数据,按 id 降序排列
# 计算总页数
total_pages = (total + page_size - 1) // page_size
# 如果未指定页码,则默认跳转到最后一页
if page is None:
page = total_pages
# 计算偏移量
offset = (page - 1) * page_size
# 查询分页数据,按 id 升序排列
await cur.execute(
"SELECT id, session_id, user_input, model_response, audio_url, duration, create_time "
"FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
"FROM t_chat_log WHERE session_id = %s ORDER BY id ASC LIMIT %s OFFSET %s",
(session_id, page_size, offset)
)
records = await cur.fetchall()
@ -88,26 +98,40 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
]
return {
"data": result,
"data": result, # 按 id 升序排列的数据
"total": total,
"page": page,
"page_size": page_size
"page_size": page_size,
"total_pages": total_pages
}
# 更新为危险的记录
async def update_risk(mysql_pool, session_id, risk_memo):
# 获取指定会话的最后一条记录的 id
async def get_last_chat_log_id(mysql_pool, session_id):
"""
获取指定会话的最后一条记录的 id
:param mysql_pool: MySQL 连接池
:param session_id: 用户会话 ID
:return: 最后一条记录的 id如果未找到则返回 None
"""
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
# 1. 获取此人员的最后一条记录 id
await cur.execute(
"SELECT id FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT 1",
(session_id,)
)
result = await cur.fetchone()
return result[0] if result else None
# 更新为危险的记录
async def update_risk(mysql_pool, session_id, risk_memo):
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
# 1. 获取此人员的最后一条记录 id
last_id = await get_last_chat_log_id(mysql_pool, session_id)
if result:
last_id = result[0]
if last_id:
# 2. 更新 risk_flag 和 risk_memo
await cur.execute(
"UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",

Loading…
Cancel
Save