diff --git a/dsRag/Config/Config.py b/dsRag/Config/Config.py
index 520bee08..db0bf35c 100644
--- a/dsRag/Config/Config.py
+++ b/dsRag/Config/Config.py
@@ -13,3 +13,11 @@ WORD2VEC_MODEL_PATH = r"D:\Tencent_AILab_ChineseEmbedding\Tencent_AILab_ChineseE
 # DeepSeek
 DEEPSEEK_API_KEY = 'sk-44ae895eeb614aa1a9c6460579e322f1'
 DEEPSEEK_URL = 'https://api.deepseek.com'
+
+
+# MySQL configuration
+MYSQL_HOST = "10.10.14.210"
+MYSQL_PORT = 22066
+MYSQL_USER = "root"
+MYSQL_PASSWORD = "DsideaL147258369"
+MYSQL_DB_NAME = "base_db"
\ No newline at end of file
diff --git a/dsRag/Config/__pycache__/Config.cpython-310.pyc b/dsRag/Config/__pycache__/Config.cpython-310.pyc
index 6451932c..a05d9cfa 100644
Binary files a/dsRag/Config/__pycache__/Config.cpython-310.pyc and b/dsRag/Config/__pycache__/Config.cpython-310.pyc differ
diff --git a/dsRag/Sql/t_ai_kb.sql b/dsRag/Sql/t_ai_kb.sql
new file mode 100644
index 00000000..d83cbfbe
--- /dev/null
+++ b/dsRag/Sql/t_ai_kb.sql
@@ -0,0 +1,61 @@
+/*
+ Navicat Premium Dump SQL
+
+ Source Server         : 10.10.14.210
+ Source Server Type    : MySQL
+ Source Server Version : 50742 (5.7.42-log)
+ Source Host           : 10.10.14.210:22066
+ Source Schema         : base_db
+
+ Target Server Type    : MySQL
+ Target Server Version : 50742 (5.7.42-log)
+ File Encoding         : 65001
+
+ Date: 24/06/2025 18:45:23
+*/
+
+SET NAMES utf8mb4;
+SET FOREIGN_KEY_CHECKS = 0;
+
+-- ----------------------------
+-- Table structure for t_ai_kb
+-- ----------------------------
+DROP TABLE IF EXISTS `t_ai_kb`;
+CREATE TABLE `t_ai_kb` (
+  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键',
+  `kb_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '知识库名称',
+  `short_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '英文简称',
+  `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
+  `is_delete` int(11) NOT NULL DEFAULT 0 COMMENT '是否删除',
+  PRIMARY KEY (`id`) USING BTREE,
+  UNIQUE INDEX `short_name`(`short_name`) USING BTREE,
+  INDEX `is_delete`(`is_delete`) USING BTREE
+) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci COMMENT = 'AI知识库' ROW_FORMAT = Dynamic;
+
+-- ----------------------------
+-- Records of t_ai_kb
+-- ----------------------------
+
+-- ----------------------------
+-- Table structure for t_ai_kb_files
+-- ----------------------------
+DROP TABLE IF EXISTS `t_ai_kb_files`;
+CREATE TABLE `t_ai_kb_files` (
+  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键',
+  `file_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '文件名称',
+  `ext_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '文件扩展名',
+  `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
+  `kb_id` int(11) NOT NULL COMMENT '隶属知识库ID',
+  `is_delete` int(11) NOT NULL DEFAULT 0 COMMENT '是否删除',
+  `state` int(11) NOT NULL DEFAULT 0 COMMENT '0:上传后未处理,1:上传后已处理,2:处理失败',
+  PRIMARY KEY (`id`) USING BTREE,
+  INDEX `kb_id`(`kb_id`) USING BTREE,
+  INDEX `is_delete`(`is_delete`, `state`) USING BTREE,
+  CONSTRAINT `t_ai_kb_files_ibfk_1` FOREIGN KEY (`kb_id`) REFERENCES `t_ai_kb` (`id`) ON DELETE CASCADE ON UPDATE RESTRICT
+) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci COMMENT = 'AI知识库上传的文件' ROW_FORMAT = Dynamic;
+
+-- ----------------------------
+-- Records of t_ai_kb_files
+-- ----------------------------
+
+SET FOREIGN_KEY_CHECKS = 1;
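The two tables are linked through kb_id with ON DELETE CASCADE, so removing a knowledge base also removes its file rows. As a quick sanity check of the schema, a minimal aiomysql sketch (it assumes the MYSQL_* values added to Config.py above; the join and the state = 0 filter are illustrative only, not part of this commit):

import asyncio
import aiomysql

from Config.Config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB_NAME


async def list_pending_files():
    # Connect with the credentials added to Config.py in this commit.
    conn = await aiomysql.connect(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
                                  password=MYSQL_PASSWORD, db=MYSQL_DB_NAME)
    try:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # state = 0 means "uploaded but not yet processed" per the column comment.
            await cur.execute(
                "SELECT f.id, f.file_name, f.ext_name, kb.kb_name "
                "FROM t_ai_kb_files AS f "
                "INNER JOIN t_ai_kb AS kb ON f.kb_id = kb.id "
                "WHERE f.is_delete = 0 AND f.state = 0"
            )
            return await cur.fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    print(asyncio.run(list_pending_files()))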
diff --git a/dsRag/Start.py b/dsRag/Start.py
new file mode 100644
index 00000000..1c452d94
--- /dev/null
+++ b/dsRag/Start.py
@@ -0,0 +1,137 @@
+"""
+pip install fastapi uvicorn aiomysql
+"""
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+from Util.MySQLUtil import *  # re-exports init_mysql_pool, DictCursor and typing.Optional used below
+
+"""
+API docs: http://localhost:8000/docs
+This implementation provides:
+
+- CRUD endpoints for knowledge bases (t_ai_kb)
+- CRUD endpoints for knowledge-base files (t_ai_kb_files)
+- Connection pooling via MySQLUtil.py
+- Auto-generated Swagger documentation
+"""
+
+app = FastAPI()
+
+# Knowledge base model
+class KbModel(BaseModel):
+    kb_name: str
+    short_name: str
+    is_delete: Optional[int] = 0
+
+# Knowledge base file model
+class KbFileModel(BaseModel):
+    file_name: str
+    ext_name: str
+    kb_id: int
+    is_delete: Optional[int] = 0
+    state: Optional[int] = 0
+
+@app.on_event("startup")
+async def startup_event():
+    app.state.mysql_pool = await init_mysql_pool()
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    app.state.mysql_pool.close()
+    await app.state.mysql_pool.wait_closed()
+
+# Knowledge base CRUD endpoints
+@app.post("/kb")
+async def create_kb(kb: KbModel):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute(
+                """INSERT INTO t_ai_kb (kb_name, short_name, is_delete)
+                VALUES (%s, %s, %s)""",
+                (kb.kb_name, kb.short_name, kb.is_delete)
+            )
+            await conn.commit()
+            return {"id": cur.lastrowid}
+
+@app.get("/kb/{kb_id}")
+async def read_kb(kb_id: int):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor(DictCursor) as cur:
+            await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
+            result = await cur.fetchone()
+            if not result:
+                raise HTTPException(status_code=404, detail="Knowledge base not found")
+            return result
+
+@app.put("/kb/{kb_id}")
+async def update_kb(kb_id: int, kb: KbModel):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute(
+                """UPDATE t_ai_kb
+                SET kb_name = %s, short_name = %s, is_delete = %s
+                WHERE id = %s""",
+                (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
+            )
+            await conn.commit()
+            return {"message": "Knowledge base updated"}
+
+@app.delete("/kb/{kb_id}")
+async def delete_kb(kb_id: int):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
+            await conn.commit()
+            return {"message": "Knowledge base deleted"}
+
+# Knowledge base file CRUD endpoints
+@app.post("/kb_files")
+async def create_kb_file(file: KbFileModel):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute(
+                """INSERT INTO t_ai_kb_files
+                (file_name, ext_name, kb_id, is_delete, state)
+                VALUES (%s, %s, %s, %s, %s)""",
+                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
+            )
+            await conn.commit()
+            return {"id": cur.lastrowid}
+
+@app.get("/kb_files/{file_id}")
+async def read_kb_file(file_id: int):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor(DictCursor) as cur:
+            await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
+            result = await cur.fetchone()
+            if not result:
+                raise HTTPException(status_code=404, detail="File not found")
+            return result
+
+@app.put("/kb_files/{file_id}")
+async def update_kb_file(file_id: int, file: KbFileModel):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute(
+                """UPDATE t_ai_kb_files
+                SET file_name = %s, ext_name = %s, kb_id = %s,
+                is_delete = %s, state = %s
+                WHERE id = %s""",
+                (file.file_name, file.ext_name, file.kb_id,
+                 file.is_delete, file.state, file_id)
+            )
+            await conn.commit()
+            return {"message": "File updated"}
+
+@app.delete("/kb_files/{file_id}")
+async def delete_kb_file(file_id: int):
+    async with app.state.mysql_pool.acquire() as conn:
+        async with conn.cursor() as cur:
+            await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
+            await conn.commit()
+            return {"message": "File deleted"}
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
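Once Start.py is running on port 8000, the endpoints can be exercised with a short client script; a hedged sketch using the requests package (the sample field values are made up):

import requests

BASE = "http://localhost:8000"

# Create a knowledge base, then attach a file record to it.
kb = requests.post(f"{BASE}/kb", json={"kb_name": "demo kb", "short_name": "demo_kb"}).json()
kb_id = kb["id"]

requests.post(f"{BASE}/kb_files", json={"file_name": "manual", "ext_name": "pdf", "kb_id": kb_id})

# Read it back and clean up.
print(requests.get(f"{BASE}/kb/{kb_id}").json())
requests.delete(f"{BASE}/kb/{kb_id}")  # cascades to t_ai_kb_files via the foreign key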
diff --git a/dsRag/Util/MySQLUtil.py b/dsRag/Util/MySQLUtil.py
new file mode 100644
index 00000000..0302e558
--- /dev/null
+++ b/dsRag/Util/MySQLUtil.py
@@ -0,0 +1,281 @@
+"""
+pip install aiomysql
+"""
+import logging
+from typing import Optional, Dict, List, Tuple
+
+from aiomysql import create_pool
+from Config.Config import *
+# Logging configuration
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+# MySQL connection settings
+MYSQL_CONFIG = {
+    "host": MYSQL_HOST,
+    "port": MYSQL_PORT,
+    "user": MYSQL_USER,
+    "password": MYSQL_PASSWORD,
+    "db": MYSQL_DB_NAME,
+    "minsize": 1,
+    "maxsize": 20,
+}
+
+
+# Initialise the MySQL connection pool
+async def init_mysql_pool():
+    return await create_pool(**MYSQL_CONFIG)
+
+
+# Save a chat record to MySQL
+async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1, output_type=1,
+                             input_image_type=0, image_width=0, image_height=0):
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cur:
+            await cur.execute(
+                "INSERT INTO t_chat_log (person_id, user_input, model_response, audio_url, duration, input_type, output_type, input_image_type, image_width, image_height, create_time) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())",
+                (person_id, prompt, result, audio_url, duration, input_type, output_type, input_image_type, image_width,
+                 image_height)
+            )
+            await conn.commit()
+
+
+# Truncate the chat-log table
+async def truncate_chat_log(mysql_pool):
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cur:
+            await cur.execute("TRUNCATE TABLE t_chat_log")
+            await conn.commit()
+        logger.info("Table t_chat_log has been truncated.")
+
+
+from aiomysql import DictCursor
+
+
+# Paginated chat-log query
+async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
+    """
+    Query chat records for a person_id, paginated in descending id order.
+    :param mysql_pool: MySQL connection pool
+    :param person_id: person (session) ID
+    :param page: current page (defaults to 1; callers can compute the last page from total_pages)
+    :param page_size: records per page
+    :return: paginated data
+    """
+    if not mysql_pool:
+        raise ValueError("MySQL connection pool is not initialised")
+
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor(DictCursor) as cur:  # DictCursor returns rows as dicts
+            # Total number of records
+            await cur.execute(
+                "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
+                (person_id,)
+            )
+            total = (await cur.fetchone())['COUNT(*)']
+
+            # Total number of pages
+            total_pages = (total + page_size - 1) // page_size
+
+            # Offset for the requested page
+            offset = (page - 1) * page_size
+
+            # Fetch the page, newest records first
+            await cur.execute(
+                "SELECT id, person_id, user_input, model_response, audio_url, duration, input_type, output_type, input_image_type, image_width, image_height, create_time "
+                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
+                (person_id, page_size, offset)
+            )
+            records = await cur.fetchall()
+
+            # Reverse the first page so the newest message appears last
+            if page == 1 and records:
+                records.reverse()
+
+            # Convert the rows to a list of dicts
+            result = [
+                {
+                    "id": record['id'],
+                    "person_id": record['person_id'],
+                    "user_input": record['user_input'],
+                    "model_response": record['model_response'],
+                    "audio_url": record['audio_url'],
+                    "duration": record['duration'],
+                    "input_type": record['input_type'],
+                    "output_type": record['output_type'],
+                    "image_width": record['image_width'],
+                    "image_height": record['image_height'],
+                    "input_image_type": record['input_image_type'],
+                    "create_time": record['create_time'].strftime("%Y-%m-%d %H:%M:%S")
+                }
+                for record in records
+            ]
+
+            return {
+                "data": result,  # page-1 data is reversed into ascending id order
+                "total": total,
+                "page": page,
+                "page_size": page_size,
+                "total_pages": total_pages
+            }
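# Usage sketch for the pool helpers above (not part of MySQLUtil.py itself). It assumes the
# t_chat_log table already exists; the person_id value is made up for illustration.
import asyncio

from Util.MySQLUtil import init_mysql_pool, save_chat_to_mysql, get_chat_log_by_session


async def demo():
    pool = await init_mysql_pool()
    try:
        await save_chat_to_mysql(pool, person_id="p001", prompt="hello",
                                 result="Hi there!", audio_url=None, duration=0)
        page = await get_chat_log_by_session(pool, person_id="p001", page=1, page_size=10)
        print(page["total"], page["total_pages"], [r["user_input"] for r in page["data"]])
    finally:
        pool.close()
        await pool.wait_closed()


if __name__ == "__main__":
    asyncio.run(demo())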
+
+
+# Get the id of the most recent record for a person
+async def get_last_chat_log_id(mysql_pool, person_id):
+    """
+    Get the id of the most recent record for the given person.
+    :param mysql_pool: MySQL connection pool
+    :param person_id: person (session) ID
+    :return: id of the last record, or None if not found
+    """
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cur:
+            await cur.execute(
+                "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
+                (person_id,)
+            )
+            result = await cur.fetchone()
+            return result[0] if result else None
+
+
+# Mark the latest record as risky
+async def update_risk(mysql_pool, person_id, risk_memo):
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cur:
+            # 1. Get the id of this person's latest record
+            last_id = await get_last_chat_log_id(mysql_pool, person_id)
+
+            if last_id:
+                # 2. Update risk_flag and risk_memo
+                await cur.execute(
+                    "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
+                    (risk_memo.replace('\n', '').replace("NO", ""), last_id)
+                )
+                await conn.commit()
+                logger.info(f"Updated risk_flag and risk_memo on the last record (id={last_id}) for person_id={person_id}.")
+            else:
+                logger.warning(f"No records found for person_id={person_id}.")
+
+
+# Look up a user account
+async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
+    """
+    Query user information by login name.
+    :param mysql_pool: MySQL connection pool
+    :param login_name: login name
+    :return: user information as a dict, or None
+    """
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cursor:
+            sql = "SELECT * FROM t_base_person WHERE login_name = %s"
+            await cursor.execute(sql, (login_name,))
+            row = await cursor.fetchone()
+            if not row:
+                return None
+
+            # Convert the tuple to a dict
+            columns = [column[0] for column in cursor.description]
+            return dict(zip(columns, row))
+
+
+# Summary data for the statistics/analysis page
+async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> Tuple[List[Dict], int]:
+    """
+    Get summary statistics over the chat log.
+    :param mysql_pool: MySQL connection pool
+    :param risk_flag: risk flag
+    :param offset: pagination offset
+    :param page_size: records per page
+    :return: list of rows and the total record count
+    """
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cursor:
+            # Rows matching the risk flag, grouped per person
+            sql = """
+                SELECT tbp.*, COUNT(*) AS cnt
+                FROM t_chat_log AS tcl
+                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
+                WHERE tcl.risk_flag = %s
+                GROUP BY tcl.person_id
+                ORDER BY COUNT(*) DESC
+                LIMIT %s OFFSET %s
+            """
+            await cursor.execute(sql, (risk_flag, page_size, offset))
+            rows = await cursor.fetchall()
+
+            # Column names
+            columns = [column[0] for column in cursor.description]
+
+            # Total record count
+            count_sql = """
+                SELECT COUNT(DISTINCT tcl.person_id)
+                FROM t_chat_log AS tcl
+                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
+                WHERE tcl.risk_flag = %s
+            """
+            await cursor.execute(count_sql, (risk_flag,))
+            total = (await cursor.fetchone())[0]
+
+            # Convert tuples to dicts
+            logs = [dict(zip(columns, row)) for row in rows] if rows else []
+
+            return logs, total
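# Usage sketch for the risk helpers (not part of MySQLUtil.py itself): update_risk marks the
# latest row for a person, and get_chat_logs_summary then ranks flagged persons. The person_id
# and memo text below are made up.
import asyncio

from Util.MySQLUtil import init_mysql_pool, update_risk, get_chat_logs_summary


async def flag_and_report(person_id: str, memo: str):
    pool = await init_mysql_pool()
    try:
        await update_risk(pool, person_id, memo)
        people, total = await get_chat_logs_summary(pool, risk_flag=1, offset=0, page_size=10)
        return people, total
    finally:
        pool.close()
        await pool.wait_closed()


if __name__ == "__main__":
    print(asyncio.run(flag_and_report("p001", "risky wording detected")))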
+
+
+async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> Tuple[List[Dict], int]:
+    """
+    Query chat records by risk flag for a given person.
+    :param mysql_pool: MySQL connection pool
+    :param risk_flag: risk flag
+    :param person_id: person ID
+    :param offset: pagination offset
+    :param page_size: records per page
+    :return: list of chat records and the total record count
+    """
+    async with mysql_pool.acquire() as conn:
+        await conn.ping()  # refresh the connection
+        async with conn.cursor() as cursor:
+            # Rows matching the risk flag and person
+            sql = """
+                SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
+                       tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
+                       tbp.person_id, tbp.login_name, tbp.person_name
+                FROM t_chat_log AS tcl
+                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
+                WHERE tcl.risk_flag = %s AND tcl.person_id = %s ORDER BY tcl.id DESC
+                LIMIT %s OFFSET %s
+            """
+            await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
+            rows = await cursor.fetchall()
+
+            # Capture column names before count_sql runs
+            columns = [column[0] for column in cursor.description]
+
+            # Total record count
+            count_sql = """
+                SELECT COUNT(*)
+                FROM t_chat_log AS tcl
+                INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
+                WHERE tcl.risk_flag = %s AND tcl.person_id = %s
+            """
+            await cursor.execute(count_sql, (risk_flag, person_id))
+            total = (await cursor.fetchone())[0]
+
+            # Convert tuples to dicts and format create_time
+            if rows:
+                logs = []
+                for row in rows:
+                    log = dict(zip(columns, row))
+                    # Format create_time
+                    if log["create_time"]:
+                        log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
+                    logs.append(log)
+                return logs, total
+            return [], 0
diff --git a/dsRag/Util/__pycache__/MySQLUtil.cpython-310.pyc b/dsRag/Util/__pycache__/MySQLUtil.cpython-310.pyc
new file mode 100644
index 00000000..c310e899
Binary files /dev/null and b/dsRag/Util/__pycache__/MySQLUtil.cpython-310.pyc differ
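If the HTTP layer in Start.py later needs the chat-log helpers as well, it can reuse the pool stored on app.state in the same way; a hedged sketch of such an endpoint, mirroring Start.py's startup hook (the /risk_logs route is hypothetical and not part of this commit):

from fastapi import FastAPI, HTTPException

from Util.MySQLUtil import init_mysql_pool, get_chat_logs_by_risk_flag

app = FastAPI()


@app.on_event("startup")
async def startup_event():
    app.state.mysql_pool = await init_mysql_pool()


@app.get("/risk_logs/{person_id}")
async def read_risk_logs(person_id: str, page: int = 1, page_size: int = 10):
    # Fetch one page of flagged chat records for this person.
    logs, total = await get_chat_logs_by_risk_flag(
        app.state.mysql_pool, risk_flag=1, person_id=person_id,
        offset=(page - 1) * page_size, page_size=page_size)
    if not logs:
        raise HTTPException(status_code=404, detail="No risk logs found")
    return {"data": logs, "total": total, "page": page, "page_size": page_size}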