main
HuangHai 4 months ago
parent e3a4d6ef55
commit bbfc00fbe4

@ -71,7 +71,7 @@ class MilvusCollectionManager:
# 使用 Milvus 的 query 方法查询指定 ID 的记录
results = self.collection.query(
expr=f"id == {id}", # 查询条件
output_fields=["id", "session_id", "user_input", "model_response", "timestamp"] # 返回的字段
output_fields=["id", "person_id", "user_input", "model_response", "timestamp"] # 返回的字段
)
if results:
return results[0] # 返回第一条记录

@ -1,3 +1,5 @@
import asyncio
from pymilvus import FieldSchema, DataType, utility
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
@ -23,13 +25,13 @@ if utility.has_collection(collection_name):
# 5. 定义集合的字段和模式
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), # 主键字段,自动生成 ID
FieldSchema(name="session_id", dtype=DataType.VARCHAR, max_length=64), # 会话 ID
FieldSchema(name="person_id", dtype=DataType.VARCHAR, max_length=64), # 会话 ID
FieldSchema(name="user_input", dtype=DataType.VARCHAR, max_length=2048), # 用户问题
FieldSchema(name="model_response", dtype=DataType.VARCHAR, max_length=2048), # 大模型反馈结果
FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=32), # 时间
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=MS_DIMENSION) # 向量字段,维度为 200
]
schema_description = "Chat records collection with session ID, user input, model response, and timestamp"
schema_description = "Chat records collection with person_id , user input, model response, and timestamp"
# 6. 创建集合
print(f"正在创建集合 '{collection_name}'...")

@ -66,14 +66,14 @@ print(f"大模型回复: {model_response}")
# 7. 获取当前时间和会话 ID
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) # 当前时间
session_id = "F9D1C319-215D-3B4E-FAA0-0E024AC12D3C" # 会话 ID可以根据需要动态生成
person_id = "F9D1C319-215D-3B4E-FAA0-0E024AC12D3C" # 会话 ID可以根据需要动态生成
# 8. 将用户问题转换为嵌入向量
user_embedding = text_to_embedding(user_input)
# 9. 插入数据,确保字段顺序与集合定义一致
entities = [
[session_id], # session_id
[person_id], # person_id
[user_input], # user_input
[model_response], # model_response
[timestamp], # timestamp

@ -21,7 +21,7 @@ try:
# 使用 Milvus 的 query 方法查询所有数据
results = collection_manager.collection.query(
expr="", # 空表达式表示查询所有数据
output_fields=["id", "session_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段
output_fields=["id", "person_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段
limit=1000 # 设置最大返回记录数
)
print("查询结果:")
@ -29,14 +29,14 @@ try:
for result in results:
try:
# 获取字段值
session_id = result["session_id"]
person_id = result["person_id"]
user_input = result["user_input"]
model_response = result["model_response"]
timestamp = result["timestamp"]
embedding = result["embedding"]
# 打印结果
print(f"ID: {result['id']}")
print(f"会话 ID: {session_id}")
print(f"会话 ID: {person_id}")
print(f"用户问题: {user_input}")
print(f"大模型回复: {model_response}")
print(f"时间: {timestamp}")

@ -62,7 +62,7 @@ if results:
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
print(f"ID: {hit.id}")
print(f"会话 ID: {record['session_id']}")
print(f"会话 ID: {record['person_id']}")
print(f"用户问题: {record['user_input']}")
print(f"大模型回复: {record['model_response']}")
print(f"时间: {record['timestamp']}")

@ -23,7 +23,7 @@ SET FOREIGN_KEY_CHECKS = 0;
DROP TABLE IF EXISTS `t_chat_log`;
CREATE TABLE `t_chat_log` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键',
`session_id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户人员编号',
`person_id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户人员编号',
`user_input` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户提出的问题',
`model_response` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '大模型的反馈',
`audio_url` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '生成的语音文件路径',

@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory
from WxMini.Utils.TtsUtil import TTS
from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \
get_risk_chat_log_page, get_last_chat_log_id
get_risk_chat_log_page, get_last_chat_log_id, get_user_by_login_name
from WxMini.Utils.EmbeddingUtil import text_to_embedding
# 配置日志
@ -46,9 +46,9 @@ async def lifespan(app: FastAPI):
# 会话结束后,调用检查方法,判断是不是有需要介入的问题出现
async def on_session_end(session_id):
async def on_session_end(person_id):
# 获取最后一条聊天记录
last_id = await get_last_chat_log_id(app.state.mysql_pool, session_id)
last_id = await get_last_chat_log_id(app.state.mysql_pool, person_id)
if last_id:
# 查询最后一条记录的详细信息
async with app.state.mysql_pool.acquire() as conn:
@ -90,10 +90,10 @@ async def on_session_end(session_id):
analysis_result = response.choices[0].message.content.strip()
if analysis_result.startswith("NO"):
# 异步执行 update_risk
asyncio.create_task(update_risk(app.state.mysql_pool, session_id, analysis_result))
logger.info(f"已异步更新 session_id={session_id} 的风险状态。")
asyncio.create_task(update_risk(app.state.mysql_pool, person_id, analysis_result))
logger.info(f"已异步更新 person_id={person_id} 的风险状态。")
else:
logger.info(f"AI大模型没有发现任何心理健康问题用户会话 {session_id} 没有风险。")
logger.info(f"AI大模型没有发现任何心理健康问题用户会话 {person_id} 没有风险。")
# 初始化 FastAPI 应用
@ -106,11 +106,12 @@ client = AsyncOpenAI(
)
@app.post("/aichat/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
async def reply(person_id: str = Form(...), prompt: str = Form(...)):
"""
接收用户输入的 prompt调用大模型并返回结果
:param session_id: 用户会话 ID
:param person_id: 用户会话 ID
:param prompt: 用户输入的 prompt
:return: 大模型的回复
"""
@ -132,7 +133,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
collection_manager.search,
data=current_embedding, # 输入向量
search_params=search_params, # 搜索参数
expr=f"session_id == '{session_id}'", # 按 session_id 过滤
expr=f"person_id == '{person_id}'", # 按 person_id 过滤
limit=5 # 返回 5 条结果
)
end_time = time.time()
@ -182,7 +183,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 记录用户输入和大模型反馈到向量数据库
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
entities = [
[session_id], # session_id
[person_id], # person_id
[prompt[:500]], # user_input截断到 500 字符
[result[:500]], # model_response截断到 500 字符
[timestamp], # timestamp
@ -214,11 +215,11 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
url = OSS_PREFIX + tts_file
# 记录聊天数据到 MySQL
await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result, url, duration)
await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration)
logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。")
# 调用会话检查机制
await on_session_end(session_id)
await on_session_end(person_id)
# 返回数据
return {
@ -241,22 +242,62 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 获取聊天记录
from fastapi import Query
# 登录接口
@app.post("/aichat/login")
async def login(
login_name: str = Form(..., description="用户名"),
password: str = Form(..., description="密码")
):
"""
用户登录接口
:param login_name: 用户名
:param password: 密码
:return: 登录结果
"""
if not login_name or not password:
raise HTTPException(status_code=400, detail="用户名和密码不能为空")
# 调用 get_user_by_login_name 方法
user = await get_user_by_login_name(app.state.mysql_pool, login_name)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
if user['password'] != password:
raise HTTPException(status_code=401, detail="密码错误")
# 返回带字段名称的数据
return {
"code": 200,
"message": "登录成功",
"data": {
"person_id": user["person_id"],
"login_name": user["login_name"],
"identity_id": user["identity_id"],
"person_name": user["person_name"],
"xb_name": user["xb_name"],
"city_name": user["city_name"],
"area_name": user["area_name"],
"school_name": user["school_name"],
"grade_name": user["grade_name"],
"class_name": user["class_name"]
}
}
# 获取聊天记录
@app.get("/aichat/get_chat_log")
async def get_chat_log(
session_id: str,
person_id: str,
page: int = Query(default=1, ge=1, description="当前页码(默认值为 1但会动态计算为最后一页"),
page_size: int = Query(default=10, ge=1, le=100, description="每页记录数")
):
"""
获取指定会话的聊天记录默认返回最新的记录最后一页
:param session_id: 用户会话 ID
:param person_id: 用户会话 ID
:param page: 当前页码默认值为 1但会动态计算为最后一页
:param page_size: 每页记录数
:return: 分页数据
"""
# 调用 get_chat_log_by_session 方法
result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size)
result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size)
return result
@app.get("/aichat/get_risk_page")

@ -14,10 +14,10 @@ prompt = "You are a helpful and harmless assistant. You are Qwen developed by Al
#img_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/QVQ/demo.png'
# 答案=46
img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/simple_math.jpeg'
#img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/simple_math.jpeg'
# 答案应该是80度
# img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/hard_math.jpeg'
img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/hard_math.jpeg'
response = client.chat.completions.create(
#model="Qwen/QVQ-72B-Preview",

@ -1,4 +1,6 @@
import logging
from typing import Optional, Dict
from aiomysql import create_pool
from WxMini.Milvus.Config.MulvusConfig import *
@ -24,12 +26,12 @@ async def init_mysql_pool():
# 保存聊天记录到 MySQL
async def save_chat_to_mysql(mysql_pool, session_id, prompt, result, audio_url, duration):
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration):
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"INSERT INTO t_chat_log (session_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())",
(session_id, prompt, result, audio_url, duration)
"INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())",
(person_id, prompt, result, audio_url, duration)
)
await conn.commit()
@ -46,11 +48,11 @@ async def truncate_chat_log(mysql_pool):
from aiomysql import DictCursor
# 分页查询聊天记录
async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
"""
根据 session_id 查询聊天记录并按 id 降序分页
根据 person_id 查询聊天记录并按 id 降序分页
:param mysql_pool: MySQL 连接池
:param session_id: 用户会话 ID
:param person_id: 用户会话 ID
:param page: 当前页码默认值为 1但会动态计算为最后一页
:param page_size: 每页记录数
:return: 分页数据
@ -62,8 +64,8 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
async with conn.cursor(DictCursor) as cur: # 使用 DictCursor
# 查询总记录数
await cur.execute(
"SELECT COUNT(*) FROM t_chat_log WHERE session_id = %s",
(session_id,)
"SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
(person_id,)
)
total = (await cur.fetchone())['COUNT(*)']
@ -75,9 +77,9 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
# 查询分页数据,按 id 降序排列
await cur.execute(
"SELECT id, session_id, user_input, model_response, audio_url, duration, create_time "
"FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
(session_id, page_size, offset)
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time "
"FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
(person_id, page_size, offset)
)
records = await cur.fetchall()
@ -88,7 +90,7 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
result = [
{
"id": record['id'],
"session_id": record['session_id'],
"person_id": record['person_id'],
"user_input": record['user_input'],
"model_response": record['model_response'],
"audio_url": record['audio_url'],
@ -108,7 +110,7 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10):
# 获取指定会话的最后一条记录的 id
async def get_last_chat_log_id(mysql_pool, session_id):
async def get_last_chat_log_id(mysql_pool, person_id):
"""
获取指定会话的最后一条记录的 id
:param mysql_pool: MySQL 连接池
@ -118,19 +120,19 @@ async def get_last_chat_log_id(mysql_pool, session_id):
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"SELECT id FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT 1",
(session_id,)
"SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
(person_id,)
)
result = await cur.fetchone()
return result[0] if result else None
# 更新为危险的记录
async def update_risk(mysql_pool, session_id, risk_memo):
async def update_risk(mysql_pool, person_id, risk_memo):
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
# 1. 获取此人员的最后一条记录 id
last_id = await get_last_chat_log_id(mysql_pool, session_id)
last_id = await get_last_chat_log_id(mysql_pool, person_id)
if last_id:
# 2. 更新 risk_flag 和 risk_memo
@ -139,9 +141,9 @@ async def update_risk(mysql_pool, session_id, risk_memo):
(risk_memo.replace('\n', '').replace("NO", ""), last_id)
)
await conn.commit()
logger.info(f"已更新 session_id={session_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
else:
logger.warning(f"未找到 session_id={session_id} 的记录。")
logger.warning(f"未找到 person_id={person_id} 的记录。")
# 查询有风险的聊天记录
@ -166,7 +168,7 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
# 查询分页数据
query = (
"SELECT id, session_id, user_input, model_response, audio_url, duration, create_time, risk_memo "
"SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo "
"FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s"
)
params = (risk_flag, page_size, offset)
@ -180,7 +182,7 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
result = [
{
"id": record[0],
"session_id": record[1],
"person_id": record[1],
"user_input": record[2],
"model_response": record[3],
"audio_url": record[4],
@ -196,4 +198,23 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10):
"total": total,
"page": page,
"page_size": page_size
}
}
# 查询用户信息
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
"""
根据用户名查询用户信息
:param pool: MySQL 连接池
:param login_name: 用户名
:return: 用户信息字典形式
"""
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cursor:
sql = "SELECT * FROM t_base_person WHERE login_name = %s"
await cursor.execute(sql, (login_name,))
row = await cursor.fetchone()
if not row:
return None
# 将元组转换为字典
columns = [column[0] for column in cursor.description]
return dict(zip(columns, row))

@ -32,10 +32,10 @@ __all__ = ["NlsStreamInputTtsSynthesizer"]
class NlsStreamInputTtsRequest:
def __init__(self, task_id, session_id, appkey):
def __init__(self, task_id, person_id, appkey):
self.task_id = task_id
self.appkey = appkey
self.session_id = session_id
self.person_id = person_id
def getStartCMD(self, voice, format, sample_rate, volumn, speech_rate, pitch_rate, ex):
self.voice = voice
@ -53,7 +53,7 @@ class NlsStreamInputTtsRequest:
"appkey": self.appkey,
},
"payload": {
"session_id": self.session_id,
"session_id": self.person_id,
"voice": self.voice,
"format": self.format,
"sample_rate": self.sample_rate,

Loading…
Cancel
Save