main
HuangHai 4 weeks ago
parent bc89a6cd6f
commit 93f3ab1ed9

@ -13,3 +13,11 @@ WORD2VEC_MODEL_PATH = r"D:\Tencent_AILab_ChineseEmbedding\Tencent_AILab_ChineseE
# DeepSeek API access.
# NOTE(review): a live-looking API key and a DB password are committed in
# plain text below — rotate them and keep real secrets out of version control.
# Each setting can now be overridden via an environment variable of the same
# name; the fallback values are unchanged, so existing deployments behave
# exactly as before.
import os

DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", 'sk-44ae895eeb614aa1a9c6460579e322f1')
DEEPSEEK_URL = os.getenv("DEEPSEEK_URL", 'https://api.deepseek.com')
# MySQL connection settings (defaults point at the dev server).
MYSQL_HOST = os.getenv("MYSQL_HOST", "10.10.14.210")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "22066"))  # kept as int for aiomysql
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "DsideaL147258369")
MYSQL_DB_NAME = os.getenv("MYSQL_DB_NAME", "base_db")

@ -0,0 +1,61 @@
/*
Navicat Premium Dump SQL
Source Server : 10.10.14.210
Source Server Type : MySQL
Source Server Version : 50742 (5.7.42-log)
Source Host : 10.10.14.210:22066
Source Schema : base_db
Target Server Type : MySQL
Target Server Version : 50742 (5.7.42-log)
File Encoding : 65001
Date: 24/06/2025 18:45:23
*/
-- Use 4-byte UTF-8 for this session so emoji / rare CJK characters round-trip.
SET NAMES utf8mb4;
-- Disable FK checks so tables can be dropped/recreated in any order.
SET FOREIGN_KEY_CHECKS = 0;

-- ----------------------------
-- Table structure for t_ai_kb
-- ----------------------------
-- Knowledge-base master table: one row per AI knowledge base.
DROP TABLE IF EXISTS `t_ai_kb`;
CREATE TABLE `t_ai_kb` (
  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键',
  `kb_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '知识库名称',
  `short_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '英文简称',
  `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `is_delete` int(11) NOT NULL DEFAULT 0 COMMENT '是否删除',
  -- short_name must be globally unique; is_delete is indexed for soft-delete filters.
  PRIMARY KEY (`id`) USING BTREE,
  UNIQUE INDEX `short_name`(`short_name`) USING BTREE,
  INDEX `is_delete`(`is_delete`) USING BTREE
) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci COMMENT = 'AI知识库' ROW_FORMAT = Dynamic;
-- ----------------------------
-- Records of t_ai_kb
-- ----------------------------
-- (no seed data in this dump)

-- ----------------------------
-- Table structure for t_ai_kb_files
-- ----------------------------
-- Files uploaded into a knowledge base. Rows are removed automatically when
-- the parent t_ai_kb row is deleted (ON DELETE CASCADE on kb_id).
DROP TABLE IF EXISTS `t_ai_kb_files`;
CREATE TABLE `t_ai_kb_files` (
  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键',
  `file_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '文件名称',
  `ext_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL COMMENT '文件扩展名',
  `create_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `kb_id` int(11) NOT NULL COMMENT '隶属知识库ID',
  `is_delete` int(11) NOT NULL DEFAULT 0 COMMENT '是否删除',
  `state` int(11) NOT NULL DEFAULT 0 COMMENT '0:上传后未处理1上传后已处理,2:处理失败',
  PRIMARY KEY (`id`) USING BTREE,
  -- Composite (is_delete, state) index serves the "pending files" lookup.
  INDEX `kb_id`(`kb_id`) USING BTREE,
  INDEX `is_delete`(`is_delete`, `state`) USING BTREE,
  CONSTRAINT `t_ai_kb_files_ibfk_1` FOREIGN KEY (`kb_id`) REFERENCES `t_ai_kb` (`id`) ON DELETE CASCADE ON UPDATE RESTRICT
) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8mb4 COLLATE = utf8mb4_general_ci COMMENT = 'AI知识库上传的文件' ROW_FORMAT = Dynamic;

-- ----------------------------
-- Records of t_ai_kb_files
-- ----------------------------
-- Re-enable FK enforcement now that both tables exist.
SET FOREIGN_KEY_CHECKS = 1;

@ -0,0 +1,137 @@
"""
pip install fastapi uvicorn aiomysql
"""
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from Util.MySQLUtil import *
"""
API文档访问 http://localhost:8000/docs
该实现包含以下功能
- 知识库(t_ai_kb)的增删改查接口
- 知识库文件(t_ai_kb_files)的增删改查接口
- 使用MySQLUtil.py中的连接池管理
- 自动生成的Swagger文档
"""
# FastAPI application instance; interactive Swagger docs served at /docs.
app = FastAPI()
# Request/response schema for a knowledge base (table t_ai_kb).
class KbModel(BaseModel):
    kb_name: str     # display name of the knowledge base
    short_name: str  # English short name; UNIQUE index in the DB schema
    is_delete: Optional[int] = 0  # soft-delete flag: 0 = active, non-zero = deleted
# Request/response schema for an uploaded KB file (table t_ai_kb_files).
class KbFileModel(BaseModel):
    file_name: str  # file name
    ext_name: str   # file extension
    kb_id: int      # owning knowledge base id (FK to t_ai_kb.id)
    is_delete: Optional[int] = 0  # soft-delete flag: 0 = active, non-zero = deleted
    state: Optional[int] = 0      # 0: uploaded/unprocessed, 1: processed, 2: processing failed
@app.on_event("startup")
async def startup_event():
    # Create the shared aiomysql connection pool once at application start.
    # NOTE(review): on_event is deprecated in newer FastAPI in favour of
    # lifespan handlers — consider migrating when upgrading.
    app.state.mysql_pool = await init_mysql_pool()
@app.on_event("shutdown")
async def shutdown_event():
    # Close the pool and wait for all in-flight connections to be released.
    app.state.mysql_pool.close()
    await app.state.mysql_pool.wait_closed()
# Knowledge-base CRUD endpoints
@app.post("/kb")
async def create_kb(kb: KbModel):
    """Insert a knowledge base row and return its auto-generated id."""
    values = (kb.kb_name, kb.short_name, kb.is_delete)
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(
            """INSERT INTO t_ai_kb (kb_name, short_name, is_delete)
            VALUES (%s, %s, %s)""",
            values
        )
        await conn.commit()
        new_id = cur.lastrowid
    return {"id": new_id}
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Return one knowledge base row by primary key; 404 when absent."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor(DictCursor) as cur:
        await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
        row = await cur.fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return row
@app.put("/kb/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Overwrite a knowledge base row.

    Fix over the original: the row's existence is verified first and a 404 is
    raised when it does not exist (the original reported success for any id).
    An explicit SELECT is used rather than cur.rowcount because MySQL reports
    0 affected rows when the new values equal the old ones.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT id FROM t_ai_kb WHERE id = %s", (kb_id,))
            if await cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="Knowledge base not found")
            await cur.execute(
                """UPDATE t_ai_kb
                SET kb_name = %s, short_name = %s, is_delete = %s
                WHERE id = %s""",
                (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
            )
            await conn.commit()
    return {"message": "Knowledge base updated"}
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Hard-delete a knowledge base row (file rows cascade via the FK)."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
        await conn.commit()
    return {"message": "Knowledge base deleted"}
# Knowledge-base file CRUD endpoints
@app.post("/kb_files")
async def create_kb_file(file: KbFileModel):
    """Insert a KB file row and return its auto-generated id."""
    values = (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(
            """INSERT INTO t_ai_kb_files
            (file_name, ext_name, kb_id, is_delete, state)
            VALUES (%s, %s, %s, %s, %s)""",
            values
        )
        await conn.commit()
        new_id = cur.lastrowid
    return {"id": new_id}
@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Return one KB file row by primary key; 404 when absent."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor(DictCursor) as cur:
        await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
        row = await cur.fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="File not found")
    return row
@app.put("/kb_files/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Overwrite a KB file row.

    Fix over the original: the row's existence is verified first and a 404 is
    raised when it does not exist (the original reported success for any id).
    An explicit SELECT is used rather than cur.rowcount because MySQL reports
    0 affected rows when the new values equal the old ones.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT id FROM t_ai_kb_files WHERE id = %s", (file_id,))
            if await cur.fetchone() is None:
                raise HTTPException(status_code=404, detail="File not found")
            await cur.execute(
                """UPDATE t_ai_kb_files
                SET file_name = %s, ext_name = %s, kb_id = %s,
                is_delete = %s, state = %s
                WHERE id = %s""",
                (file.file_name, file.ext_name, file.kb_id,
                 file.is_delete, file.state, file_id)
            )
            await conn.commit()
    return {"message": "File updated"}
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Hard-delete one KB file row by primary key."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
        await conn.commit()
    return {"message": "File deleted"}
if __name__ == "__main__":
    # Dev entry point; interactive API docs at http://localhost:8000/docs
    uvicorn.run(app, host="0.0.0.0", port=8000)

@ -0,0 +1,281 @@
"""
pip install aiomysql
"""
import logging
from typing import Optional, Dict, List
from aiomysql import create_pool
from Config.Config import *
# Logging setup: INFO level, timestamped single-line records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# aiomysql connection-pool settings, populated from Config.Config constants.
MYSQL_CONFIG = {
    "host": MYSQL_HOST,
    "port": MYSQL_PORT,
    "user": MYSQL_USER,
    "password": MYSQL_PASSWORD,
    "db": MYSQL_DB_NAME,
    "minsize": 1,   # keep at least one idle connection warm
    "maxsize": 20,  # upper bound on concurrent connections
}
# Initialize the MySQL connection pool.
async def init_mysql_pool():
    """Create and return an aiomysql pool built from MYSQL_CONFIG."""
    return await create_pool(**MYSQL_CONFIG)
# Persist one chat exchange to MySQL.
async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, duration, input_type=1, output_type=1,
                             input_image_type=0, image_width=0, image_height=0):
    """Insert one chat-log row into t_chat_log (create_time is set to NOW())."""
    insert_sql = (
        "INSERT INTO t_chat_log (person_id, user_input, model_response,audio_url,duration,input_type,output_type,input_image_type,image_width,image_height,create_time) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,NOW())"
    )
    params = (person_id, prompt, result, audio_url, duration, input_type,
              output_type, input_image_type, image_width, image_height)
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cur:
            await cur.execute(insert_sql, params)
            await conn.commit()
# Empty the chat-log table.
async def truncate_chat_log(mysql_pool):
    """Irreversibly empty t_chat_log via TRUNCATE TABLE."""
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cur:
            await cur.execute("TRUNCATE TABLE t_chat_log")
            await conn.commit()
    logger.info("表 t_chat_log 已清空。")
from aiomysql import DictCursor
# Paged query of the chat history for one person.
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """
    Return one page of chat-log rows for *person_id*, ordered by id DESC.

    :param mysql_pool: aiomysql connection pool
    :param person_id: person/session identifier
    :param page: 1-based page number. NOTE(review): the original docstring
        claimed this is recomputed to the last page, but the code uses the
        value as given — confirm intended behaviour.
    :param page_size: rows per page
    :return: dict with keys data, total, page, page_size, total_pages
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor(DictCursor) as cur:  # rows come back as dicts
            # Total row count for this person.
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
                (person_id,)
            )
            # Dict key must match the unaliased COUNT(*) column name exactly.
            total = (await cur.fetchone())['COUNT(*)']
            # Ceiling division for the number of pages.
            total_pages = (total + page_size - 1) // page_size
            # Row offset for the requested page.
            offset = (page - 1) * page_size
            # Fetch the page, newest rows first.
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type,image_width,image_height, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, offset)
            )
            records = await cur.fetchall()
            # Only page 1 is reversed so the newest message renders last in the UI.
            if page==1 and records:
                records.reverse()
            # Shape rows for the API response; create_time is stringified.
            result = [
                {
                    "id": record['id'],
                    "person_id": record['person_id'],
                    "user_input": record['user_input'],
                    "model_response": record['model_response'],
                    "audio_url": record['audio_url'],
                    "duration": record['duration'],
                    "input_type": record['input_type'],
                    "output_type": record['output_type'],
                    "image_width": record['image_width'],
                    "image_height": record['image_height'],
                    "input_image_type": record['input_image_type'],
                    "create_time": record['create_time'].strftime("%Y-%m-%d %H:%M:%S")
                }
                for record in records
            ]
            return {
                "data": result,  # ascending by id on page 1 only (see reverse above)
                "total": total,
                "page": page,
                "page_size": page_size,
                "total_pages": total_pages
            }
# Newest chat-log id for one person.
async def get_last_chat_log_id(mysql_pool, person_id):
    """
    Return the id of the newest t_chat_log row for *person_id*.

    :param mysql_pool: aiomysql connection pool
    :param person_id: person/session identifier
    :return: the id, or None when the person has no rows
    """
    query = "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1"
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cur:
            await cur.execute(query, (person_id,))
            row = await cur.fetchone()
    return None if not row else row[0]
# Mark the latest record for a person as risky.
async def update_risk(mysql_pool, person_id, risk_memo):
    """
    Flag the newest t_chat_log row for *person_id* with risk_flag=1 and a memo.

    Fix over the original: the last-row id is now queried on the SAME pooled
    connection instead of calling get_last_chat_log_id(), which re-acquired a
    second connection while this one was still held — a potential deadlock
    when the pool is exhausted.

    :param mysql_pool: aiomysql connection pool
    :param person_id: person/session identifier
    :param risk_memo: memo text; newlines and the literal "NO" are stripped
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cur:
            # 1. Fetch this person's newest record id (same connection).
            await cur.execute(
                "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1",
                (person_id,)
            )
            row = await cur.fetchone()
            last_id = row[0] if row else None
            if last_id:
                # 2. Update risk_flag and risk_memo on that row.
                await cur.execute(
                    "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                    (risk_memo.replace('\n', '').replace("NO", ""), last_id)
                )
                await conn.commit()
                logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
            else:
                logger.warning(f"未找到 person_id={person_id} 的记录。")
# Look up user info by login name.
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """
    Fetch one row from t_base_person by login name.

    :param mysql_pool: aiomysql connection pool
    :param login_name: login name to match
    :return: the row as a dict, or None when no user matches
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cur:
            await cur.execute("SELECT * FROM t_base_person WHERE login_name = %s", (login_name,))
            record = await cur.fetchone()
            if record is None:
                return None
            # Pair column names from the cursor metadata with the tuple values.
            col_names = [desc[0] for desc in cur.description]
            return dict(zip(col_names, record))
# Data for the statistics/analysis page.
async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
    """
    Aggregate chat logs per person for a given risk flag.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: risk flag value to filter on
    :param offset: row offset for paging
    :param page_size: rows per page
    :return: (per-person rows including a cnt column, total distinct persons)
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cursor:
            # Per-person aggregation, busiest persons first.
            # NOTE(review): selecting tbp.* while grouping only by person_id
            # relies on MySQL running without ONLY_FULL_GROUP_BY — confirm
            # the server's sql_mode.
            sql = """
            SELECT tbp.*, COUNT(*) AS cnt
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s
            GROUP BY tcl.person_id
            ORDER BY COUNT(*) DESC
            LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names NOW — the count query below overwrites
            # cursor.description.
            columns = [column[0] for column in cursor.description]
            # Total number of distinct persons matching the flag.
            count_sql = """
            SELECT COUNT(DISTINCT tcl.person_id)
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s
            """
            await cursor.execute(count_sql, (risk_flag,))
            total = (await cursor.fetchone())[0]
            # Tuples -> dicts keyed by column name.
            logs = [dict(zip(columns, row)) for row in rows] if rows else []
            return logs, total
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> (
        List[Dict], int):
    """
    Fetch chat-log rows (joined with person info) for one person and risk flag.

    :param mysql_pool: aiomysql connection pool
    :param risk_flag: risk flag value to filter on
    :param person_id: person identifier to filter on
    :param offset: row offset for paging
    :param page_size: rows per page
    :return: (list of row dicts with create_time formatted as a string,
              total matching row count)
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()  # revalidate the pooled connection
        async with conn.cursor() as cursor:
            # One page of matching rows, newest first.
            sql = """
            SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
            tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
            tbp.person_id, tbp.login_name, tbp.person_name
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY TCL.ID DESC
            LIMIT %s OFFSET %s
            """
            await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
            rows = await cursor.fetchall()
            # Capture column names BEFORE count_sql runs — it overwrites
            # cursor.description.
            columns = [column[0] for column in cursor.description]
            # Total matching rows (paging ignored).
            count_sql = """
            SELECT COUNT(*)
            FROM t_chat_log AS tcl
            INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
            WHERE tcl.risk_flag = %s and tcl.person_id=%s
            """
            await cursor.execute(count_sql, (risk_flag, person_id))
            total = (await cursor.fetchone())[0]
            # Tuples -> dicts; format create_time for JSON output.
            if rows:
                logs = []
                for row in rows:
                    log = dict(zip(columns, row))
                    # Stringify the datetime so the result is JSON-serializable.
                    if log["create_time"]:
                        log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
                    logs.append(log)
                return logs, total
            return [], 0
Loading…
Cancel
Save