diff --git a/AI/WxMini/Milvus/Utils/MilvusCollectionManager.py b/AI/WxMini/Milvus/Utils/MilvusCollectionManager.py index 1d10aa19..39234eaa 100644 --- a/AI/WxMini/Milvus/Utils/MilvusCollectionManager.py +++ b/AI/WxMini/Milvus/Utils/MilvusCollectionManager.py @@ -71,7 +71,7 @@ class MilvusCollectionManager: # 使用 Milvus 的 query 方法查询指定 ID 的记录 results = self.collection.query( expr=f"id == {id}", # 查询条件 - output_fields=["id", "session_id", "user_input", "model_response", "timestamp"] # 返回的字段 + output_fields=["id", "person_id", "user_input", "model_response", "timestamp"] # 返回的字段 ) if results: return results[0] # 返回第一条记录 diff --git a/AI/WxMini/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc b/AI/WxMini/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc index 22ad788f..b7382af3 100644 Binary files a/AI/WxMini/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc and b/AI/WxMini/Milvus/Utils/__pycache__/MilvusCollectionManager.cpython-310.pyc differ diff --git a/AI/WxMini/Milvus/X1_create_collection.py b/AI/WxMini/Milvus/X1_create_collection.py index 9c513aa5..0318dea3 100644 --- a/AI/WxMini/Milvus/X1_create_collection.py +++ b/AI/WxMini/Milvus/X1_create_collection.py @@ -1,3 +1,5 @@ +import asyncio + from pymilvus import FieldSchema, DataType, utility from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager @@ -23,13 +25,13 @@ if utility.has_collection(collection_name): # 5. 定义集合的字段和模式 fields = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), # 主键字段,自动生成 ID - FieldSchema(name="session_id", dtype=DataType.VARCHAR, max_length=64), # 会话 ID + FieldSchema(name="person_id", dtype=DataType.VARCHAR, max_length=64), # 会话 ID FieldSchema(name="user_input", dtype=DataType.VARCHAR, max_length=2048), # 用户问题 FieldSchema(name="model_response", dtype=DataType.VARCHAR, max_length=2048), # 大模型反馈结果 FieldSchema(name="timestamp", dtype=DataType.VARCHAR, max_length=32), # 时间 FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=MS_DIMENSION) # 向量字段,维度为 200 ] -schema_description = "Chat records collection with session ID, user input, model response, and timestamp" +schema_description = "Chat records collection with person_id , user input, model response, and timestamp" # 6. 创建集合 print(f"正在创建集合 '{collection_name}'...") diff --git a/AI/WxMini/Milvus/X3_insert_data.py b/AI/WxMini/Milvus/X3_insert_data.py index b5e614f6..c8b68618 100644 --- a/AI/WxMini/Milvus/X3_insert_data.py +++ b/AI/WxMini/Milvus/X3_insert_data.py @@ -66,14 +66,14 @@ print(f"大模型回复: {model_response}") # 7. 获取当前时间和会话 ID timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) # 当前时间 -session_id = "F9D1C319-215D-3B4E-FAA0-0E024AC12D3C" # 会话 ID(可以根据需要动态生成) +person_id = "F9D1C319-215D-3B4E-FAA0-0E024AC12D3C" # 会话 ID(可以根据需要动态生成) # 8. 将用户问题转换为嵌入向量 user_embedding = text_to_embedding(user_input) # 9. 插入数据,确保字段顺序与集合定义一致 entities = [ - [session_id], # session_id + [person_id], # person_id [user_input], # user_input [model_response], # model_response [timestamp], # timestamp diff --git a/AI/WxMini/Milvus/X4_select_all_data.py b/AI/WxMini/Milvus/X4_select_all_data.py index 290c89ec..0a8395ba 100644 --- a/AI/WxMini/Milvus/X4_select_all_data.py +++ b/AI/WxMini/Milvus/X4_select_all_data.py @@ -21,7 +21,7 @@ try: # 使用 Milvus 的 query 方法查询所有数据 results = collection_manager.collection.query( expr="", # 空表达式表示查询所有数据 - output_fields=["id", "session_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段 + output_fields=["id", "person_id", "user_input", "model_response", "timestamp", "embedding"], # 指定返回的字段 limit=1000 # 设置最大返回记录数 ) print("查询结果:") @@ -29,14 +29,14 @@ try: for result in results: try: # 获取字段值 - session_id = result["session_id"] + person_id = result["person_id"] user_input = result["user_input"] model_response = result["model_response"] timestamp = result["timestamp"] embedding = result["embedding"] # 打印结果 print(f"ID: {result['id']}") - print(f"会话 ID: {session_id}") + print(f"会话 ID: {person_id}") print(f"用户问题: {user_input}") print(f"大模型回复: {model_response}") print(f"时间: {timestamp}") diff --git a/AI/WxMini/Milvus/X5_search_near_data.py b/AI/WxMini/Milvus/X5_search_near_data.py index 39ca8bd7..f75dc867 100644 --- a/AI/WxMini/Milvus/X5_search_near_data.py +++ b/AI/WxMini/Milvus/X5_search_near_data.py @@ -62,7 +62,7 @@ if results: # 查询非向量字段 record = collection_manager.query_by_id(hit.id) print(f"ID: {hit.id}") - print(f"会话 ID: {record['session_id']}") + print(f"会话 ID: {record['person_id']}") print(f"用户问题: {record['user_input']}") print(f"大模型回复: {record['model_response']}") print(f"时间: {record['timestamp']}") diff --git a/AI/WxMini/Sql/t_chat_log.sql b/AI/WxMini/Sql/t_chat_log.sql index cbb24698..17b9e98c 100644 --- a/AI/WxMini/Sql/t_chat_log.sql +++ b/AI/WxMini/Sql/t_chat_log.sql @@ -23,7 +23,7 @@ SET FOREIGN_KEY_CHECKS = 0; DROP TABLE IF EXISTS `t_chat_log`; CREATE TABLE `t_chat_log` ( `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键', - `session_id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户人员编号', + `person_id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户人员编号', `user_input` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '用户提出的问题', `model_response` varchar(2000) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '大模型的反馈', `audio_url` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT NULL COMMENT '生成的语音文件路径', diff --git a/AI/WxMini/Start.py b/AI/WxMini/Start.py index e770c91b..76c9dafd 100644 --- a/AI/WxMini/Start.py +++ b/AI/WxMini/Start.py @@ -13,7 +13,7 @@ from WxMini.Milvus.Utils.MilvusConnectionPool import * from WxMini.Utils.OssUtil import upload_mp3_to_oss_from_memory from WxMini.Utils.TtsUtil import TTS from WxMini.Utils.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session, update_risk, \ - get_risk_chat_log_page, get_last_chat_log_id + get_risk_chat_log_page, get_last_chat_log_id, get_user_by_login_name from WxMini.Utils.EmbeddingUtil import text_to_embedding # 配置日志 @@ -46,9 +46,9 @@ async def lifespan(app: FastAPI): # 会话结束后,调用检查方法,判断是不是有需要介入的问题出现 -async def on_session_end(session_id): +async def on_session_end(person_id): # 获取最后一条聊天记录 - last_id = await get_last_chat_log_id(app.state.mysql_pool, session_id) + last_id = await get_last_chat_log_id(app.state.mysql_pool, person_id) if last_id: # 查询最后一条记录的详细信息 async with app.state.mysql_pool.acquire() as conn: @@ -90,10 +90,10 @@ async def on_session_end(session_id): analysis_result = response.choices[0].message.content.strip() if analysis_result.startswith("NO"): # 异步执行 update_risk - asyncio.create_task(update_risk(app.state.mysql_pool, session_id, analysis_result)) - logger.info(f"已异步更新 session_id={session_id} 的风险状态。") + asyncio.create_task(update_risk(app.state.mysql_pool, person_id, analysis_result)) + logger.info(f"已异步更新 person_id={person_id} 的风险状态。") else: - logger.info(f"AI大模型没有发现任何心理健康问题,用户会话 {session_id} 没有风险。") + logger.info(f"AI大模型没有发现任何心理健康问题,用户会话 {person_id} 没有风险。") # 初始化 FastAPI 应用 @@ -106,11 +106,12 @@ client = AsyncOpenAI( ) + @app.post("/aichat/reply") -async def reply(session_id: str = Form(...), prompt: str = Form(...)): +async def reply(person_id: str = Form(...), prompt: str = Form(...)): """ 接收用户输入的 prompt,调用大模型并返回结果 - :param session_id: 用户会话 ID + :param person_id: 用户会话 ID :param prompt: 用户输入的 prompt :return: 大模型的回复 """ @@ -132,7 +133,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): collection_manager.search, data=current_embedding, # 输入向量 search_params=search_params, # 搜索参数 - expr=f"session_id == '{session_id}'", # 按 session_id 过滤 + expr=f"person_id == '{person_id}'", # 按 person_id 过滤 limit=5 # 返回 5 条结果 ) end_time = time.time() @@ -182,7 +183,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): # 记录用户输入和大模型反馈到向量数据库 timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) entities = [ - [session_id], # session_id + [person_id], # person_id [prompt[:500]], # user_input,截断到 500 字符 [result[:500]], # model_response,截断到 500 字符 [timestamp], # timestamp @@ -214,11 +215,11 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): url = OSS_PREFIX + tts_file # 记录聊天数据到 MySQL - await save_chat_to_mysql(app.state.mysql_pool, session_id, prompt, result, url, duration) + await save_chat_to_mysql(app.state.mysql_pool, person_id, prompt, result, url, duration) logger.info("用户输入和大模型反馈已记录到 MySQL 数据库。") # 调用会话检查机制 - await on_session_end(session_id) + await on_session_end(person_id) # 返回数据 return { @@ -241,22 +242,62 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)): # 获取聊天记录 from fastapi import Query +# 登录接口 +@app.post("/aichat/login") +async def login( + login_name: str = Form(..., description="用户名"), + password: str = Form(..., description="密码") +): + """ + 用户登录接口 + :param login_name: 用户名 + :param password: 密码 + :return: 登录结果 + """ + if not login_name or not password: + raise HTTPException(status_code=400, detail="用户名和密码不能为空") + + # 调用 get_user_by_login_name 方法 + user = await get_user_by_login_name(app.state.mysql_pool, login_name) + if not user: + raise HTTPException(status_code=404, detail="用户不存在") + if user['password'] != password: + raise HTTPException(status_code=401, detail="密码错误") + + # 返回带字段名称的数据 + return { + "code": 200, + "message": "登录成功", + "data": { + "person_id": user["person_id"], + "login_name": user["login_name"], + "identity_id": user["identity_id"], + "person_name": user["person_name"], + "xb_name": user["xb_name"], + "city_name": user["city_name"], + "area_name": user["area_name"], + "school_name": user["school_name"], + "grade_name": user["grade_name"], + "class_name": user["class_name"] + } + } + # 获取聊天记录 @app.get("/aichat/get_chat_log") async def get_chat_log( - session_id: str, + person_id: str, page: int = Query(default=1, ge=1, description="当前页码(默认值为 1,但会动态计算为最后一页)"), page_size: int = Query(default=10, ge=1, le=100, description="每页记录数") ): """ 获取指定会话的聊天记录,默认返回最新的记录(最后一页) - :param session_id: 用户会话 ID + :param person_id: 用户会话 ID :param page: 当前页码(默认值为 1,但会动态计算为最后一页) :param page_size: 每页记录数 :return: 分页数据 """ # 调用 get_chat_log_by_session 方法 - result = await get_chat_log_by_session(app.state.mysql_pool, session_id, page, page_size) + result = await get_chat_log_by_session(app.state.mysql_pool, person_id, page, page_size) return result @app.get("/aichat/get_risk_page") diff --git a/AI/WxMini/Test/Math1.py b/AI/WxMini/Test/Math1.py index e6ad4749..85b9574c 100644 --- a/AI/WxMini/Test/Math1.py +++ b/AI/WxMini/Test/Math1.py @@ -14,10 +14,10 @@ prompt = "You are a helpful and harmless assistant. You are Qwen developed by Al #img_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/QVQ/demo.png' # 答案=46 -img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/simple_math.jpeg' +#img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/simple_math.jpeg' # 答案应该是80度 -# img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/hard_math.jpeg' +img_url='https://hzkc.oss-cn-beijing.aliyuncs.com/Temp/hard_math.jpeg' response = client.chat.completions.create( #model="Qwen/QVQ-72B-Preview", diff --git a/AI/WxMini/Utils/MySQLUtil.py b/AI/WxMini/Utils/MySQLUtil.py index 4105a265..50647a2c 100644 --- a/AI/WxMini/Utils/MySQLUtil.py +++ b/AI/WxMini/Utils/MySQLUtil.py @@ -1,4 +1,6 @@ import logging +from typing import Optional, Dict + from aiomysql import create_pool from WxMini.Milvus.Config.MulvusConfig import * @@ -24,12 +26,12 @@ async def init_mysql_pool(): # 保存聊天记录到 MySQL -async def save_chat_to_mysql(mysql_pool, session_id, prompt, result, audio_url, duration): +async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration): async with mysql_pool.acquire() as conn: async with conn.cursor() as cur: await cur.execute( - "INSERT INTO t_chat_log (session_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())", - (session_id, prompt, result, audio_url, duration) + "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,create_time) VALUES (%s, %s, %s, %s, %s,NOW())", + (person_id, prompt, result, audio_url, duration) ) await conn.commit() @@ -46,11 +48,11 @@ async def truncate_chat_log(mysql_pool): from aiomysql import DictCursor # 分页查询聊天记录 -async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): +async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10): """ - 根据 session_id 查询聊天记录,并按 id 降序分页 + 根据 person_id 查询聊天记录,并按 id 降序分页 :param mysql_pool: MySQL 连接池 - :param session_id: 用户会话 ID + :param person_id: 用户会话 ID :param page: 当前页码(默认值为 1,但会动态计算为最后一页) :param page_size: 每页记录数 :return: 分页数据 @@ -62,8 +64,8 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): async with conn.cursor(DictCursor) as cur: # 使用 DictCursor # 查询总记录数 await cur.execute( - "SELECT COUNT(*) FROM t_chat_log WHERE session_id = %s", - (session_id,) + "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s", + (person_id,) ) total = (await cur.fetchone())['COUNT(*)'] @@ -75,9 +77,9 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): # 查询分页数据,按 id 降序排列 await cur.execute( - "SELECT id, session_id, user_input, model_response, audio_url, duration, create_time " - "FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT %s OFFSET %s", - (session_id, page_size, offset) + "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time " + "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s", + (person_id, page_size, offset) ) records = await cur.fetchall() @@ -88,7 +90,7 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): result = [ { "id": record['id'], - "session_id": record['session_id'], + "person_id": record['person_id'], "user_input": record['user_input'], "model_response": record['model_response'], "audio_url": record['audio_url'], @@ -108,7 +110,7 @@ async def get_chat_log_by_session(mysql_pool, session_id, page=1, page_size=10): # 获取指定会话的最后一条记录的 id -async def get_last_chat_log_id(mysql_pool, session_id): +async def get_last_chat_log_id(mysql_pool, person_id): """ 获取指定会话的最后一条记录的 id :param mysql_pool: MySQL 连接池 @@ -118,19 +120,19 @@ async def get_last_chat_log_id(mysql_pool, session_id): async with mysql_pool.acquire() as conn: async with conn.cursor() as cur: await cur.execute( - "SELECT id FROM t_chat_log WHERE session_id = %s ORDER BY id DESC LIMIT 1", - (session_id,) + "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1", + (person_id,) ) result = await cur.fetchone() return result[0] if result else None # 更新为危险的记录 -async def update_risk(mysql_pool, session_id, risk_memo): +async def update_risk(mysql_pool, person_id, risk_memo): async with mysql_pool.acquire() as conn: async with conn.cursor() as cur: # 1. 获取此人员的最后一条记录 id - last_id = await get_last_chat_log_id(mysql_pool, session_id) + last_id = await get_last_chat_log_id(mysql_pool, person_id) if last_id: # 2. 更新 risk_flag 和 risk_memo @@ -139,9 +141,9 @@ async def update_risk(mysql_pool, session_id, risk_memo): (risk_memo.replace('\n', '').replace("NO", ""), last_id) ) await conn.commit() - logger.info(f"已更新 session_id={session_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。") + logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。") else: - logger.warning(f"未找到 session_id={session_id} 的记录。") + logger.warning(f"未找到 person_id={person_id} 的记录。") # 查询有风险的聊天记录 @@ -166,7 +168,7 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10): # 查询分页数据 query = ( - "SELECT id, session_id, user_input, model_response, audio_url, duration, create_time, risk_memo " + "SELECT id, person_id, user_input, model_response, audio_url, duration, create_time, risk_memo " "FROM t_chat_log WHERE risk_flag = %s ORDER BY id DESC LIMIT %s OFFSET %s" ) params = (risk_flag, page_size, offset) @@ -180,7 +182,7 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10): result = [ { "id": record[0], - "session_id": record[1], + "person_id": record[1], "user_input": record[2], "model_response": record[3], "audio_url": record[4], @@ -196,4 +198,23 @@ async def get_risk_chat_log_page(mysql_pool, risk_flag, page=1, page_size=10): "total": total, "page": page, "page_size": page_size - } \ No newline at end of file + } +# 查询用户信息 +async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]: + """ + 根据用户名查询用户信息 + :param pool: MySQL 连接池 + :param login_name: 用户名 + :return: 用户信息(字典形式) + """ + async with mysql_pool.acquire() as conn: + async with conn.cursor() as cursor: + sql = "SELECT * FROM t_base_person WHERE login_name = %s" + await cursor.execute(sql, (login_name,)) + row = await cursor.fetchone() + if not row: + return None + + # 将元组转换为字典 + columns = [column[0] for column in cursor.description] + return dict(zip(columns, row)) \ No newline at end of file diff --git a/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc index bef1a193..4fb6dc19 100644 Binary files a/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc and b/AI/WxMini/Utils/__pycache__/MySQLUtil.cpython-310.pyc differ diff --git a/AI/WxMini/Utils/__pycache__/RedisUtil.cpython-310.pyc b/AI/WxMini/Utils/__pycache__/RedisUtil.cpython-310.pyc index 507cb7b3..4a927147 100644 Binary files a/AI/WxMini/Utils/__pycache__/RedisUtil.cpython-310.pyc and b/AI/WxMini/Utils/__pycache__/RedisUtil.cpython-310.pyc differ diff --git a/AI/WxMini/__pycache__/Start.cpython-310.pyc b/AI/WxMini/__pycache__/Start.cpython-310.pyc index 58368920..d4d62029 100644 Binary files a/AI/WxMini/__pycache__/Start.cpython-310.pyc and b/AI/WxMini/__pycache__/Start.cpython-310.pyc differ diff --git a/AI/WxMini/alibabacloud-nls-python-sdk-dev/build/lib/nls/stream_input_tts.py b/AI/WxMini/alibabacloud-nls-python-sdk-dev/build/lib/nls/stream_input_tts.py index bf63d48b..3ec09e49 100644 --- a/AI/WxMini/alibabacloud-nls-python-sdk-dev/build/lib/nls/stream_input_tts.py +++ b/AI/WxMini/alibabacloud-nls-python-sdk-dev/build/lib/nls/stream_input_tts.py @@ -32,10 +32,10 @@ __all__ = ["NlsStreamInputTtsSynthesizer"] class NlsStreamInputTtsRequest: - def __init__(self, task_id, session_id, appkey): + def __init__(self, task_id, person_id, appkey): self.task_id = task_id self.appkey = appkey - self.session_id = session_id + self.person_id = person_id def getStartCMD(self, voice, format, sample_rate, volumn, speech_rate, pitch_rate, ex): self.voice = voice @@ -53,7 +53,7 @@ class NlsStreamInputTtsRequest: "appkey": self.appkey, }, "payload": { - "session_id": self.session_id, + "session_id": self.person_id, "voice": self.voice, "format": self.format, "sample_rate": self.sample_rate,