整合 dsAiTeachingModel 接口

This commit is contained in:
chengminglong
2025-08-18 10:20:16 +08:00
parent 7b149f0f51
commit b2d5069d79
39 changed files with 1705 additions and 3 deletions

View File

@@ -61,3 +61,8 @@ ZHIPU_API_KEY = "78dc1dfe37e04f29bd4ca9a49858a969.gn7TIZTfzpY35nx9"
# GPTNB的API KEY
GPTNB_API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662"
# JWT配置信息
JWT_SECRET_KEY = "ZXZnZWVr5b+r5LmQ5L2g55qE5Ye66KGM"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 300000 # 访问令牌过期时间(分钟)

View File

@@ -0,0 +1,46 @@
# routes/DmController.py
from fastapi import APIRouter, Depends
from starlette.requests import Request
from Util.Database import *
from Routes.TeachingModel.auth.dependencies import get_current_user
from Util.ParseRequest import get_request_num_param
from Util.TranslateUtil import get_stage_map_by_id
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
# 学段、学科级联数据
@router.get("/getStageSubjectList")
async def get_stage_subject_list():
# 先查询学段list
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
stage_list = await find_by_sql(select_stage_sql,())
for stage in stage_list:
# 再查询学科list
select_subject_sql: str = "select subject_id, subject_name, icon from t_dm_subject where stage_id = " + str(stage["stage_id"]) + " order by sort_id;"
subject_list = await find_by_sql(select_subject_sql,())
stage["subject_list"] = subject_list
return {"success": True, "message": "成功!", "data": stage_list}
# 获取学段信息
@router.get("/getStageList")
async def get_stage_list():
# 先查询学段list
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
stage_list = await find_by_sql(select_stage_sql,())
return {"success": True, "message": "查询成功!", "data": {"stage_list": stage_list}}
# 根据学段ID获取学科列表
@router.get("/getSubjectList")
async def get_subject_list(request: Request):
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
stage_name = await get_stage_map_by_id(stage_id)
# 先查询学科list
select_subject_sql: str = f"select subject_id, subject_name, icon, {stage_id} as stage_id, '{stage_name}' as stage_name from t_dm_subject where stage_id = {stage_id} order by sort_id;"
print(select_subject_sql)
subject_list = await find_by_sql(select_subject_sql,())
return {"success": True, "message": "查询成功!", "data": {"subject_list": subject_list}}

View File

@@ -0,0 +1,178 @@
# routes/DocumentController.py
import os
from fastapi import APIRouter, Request, Response, Depends, UploadFile, File
from Routes.TeachingModel.auth.dependencies import get_current_user
from Util.PageUtil import *
from Util.ParseRequest import *
from Util.TranslateUtil import *
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
# 创建上传文件的目录
UPLOAD_DIR = "upload_file"
if not os.path.exists(UPLOAD_DIR):
os.makedirs(UPLOAD_DIR)
# 合法文件扩展名
supported_suffix_types = ['doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx']
# 【Document-1】文档管理列表
@router.get("/list")
async def list(request: Request):
# 获取参数
person_id = await get_request_str_param(request, "person_id", True, True)
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
theme_id = await get_request_num_param(request, "theme_id", False, True, -1)
document_suffix = await get_request_str_param(request, "document_suffix", False, True)
document_name = await get_request_str_param(request, "document_name", False, True)
page_number = await get_request_num_param(request, "page_number", False, True, 1)
page_size = await get_request_num_param(request, "page_size", False, True, 10)
print(person_id, stage_id, subject_id, document_suffix, document_name, page_number, page_size)
# 拼接查询SQL语句
select_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and person_id = '" + person_id + "'"
if stage_id != -1:
select_document_sql += " AND stage_id = " + str(stage_id)
if subject_id != -1:
select_document_sql += " AND subject_id = " + str(subject_id)
if theme_id != -1:
select_document_sql += " AND theme_id = " + str(theme_id)
if document_suffix != "":
select_document_sql += " AND document_suffix = '" + document_suffix + "'"
if document_name != "":
select_document_sql += " AND document_name like '%" + document_name + "%'"
select_document_sql += " ORDER BY create_time DESC "
# 查询文档列表
page = await get_page_data_by_sql(select_document_sql, page_number, page_size)
person_ids = ""
for item in page["list"]:
person_ids += "'" + item["person_id"] + "',"
if person_ids != "":
person_ids = person_ids[:-1]
else:
person_ids = "''"
person_map = await get_person_map(person_ids)
stage_map = await get_stage_map()
subject_map = await get_subject_map()
for item in page["list"]:
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", item["theme_id"])
item["theme_info"] = theme_info
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
item["create_time"] = str(item["create_time"].strftime("%Y-%m-%d %H:%M"))
return {"success": True, "message": "查询成功!", "data": page}
# 【Document-2】保存文档管理
@router.post("/save")
async def save(request: Request, file: UploadFile = File(...)):
# 获取参数
id = await get_request_num_param(request, "id", False, True, 0)
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
person_id = await get_request_str_param(request, "person_id", True, True)
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
# 先获取theme主题信息
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
if theme_object is None:
return {"success": False, "message": "主题不存在!"}
# 获取文件名
document_name = file.filename.split(".")[0]
if len(document_name) > 100:
return {"success": False, "message": "文件名过长已超过100个字符"}
# 检查文件名在该主题下是否重复
select_theme_document_sql: str = "SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and document_name = '" + document_name + "' and theme_id = " + str(theme_id)
if id != 0:
select_theme_document_sql += " AND id <> " + str(id)
theme_document = await find_by_sql(select_theme_document_sql, ())
if theme_document is not None:
return {"success": False, "message": "该主题下文档名称重复!"}
# 获取文件扩展名
document_suffix = file.filename.split(".")[-1]
# 检查文件扩展名
if document_suffix not in supported_suffix_types:
return {"success": False, "message": "不支持的文件类型!"}
# 构造文件保存路径
document_dir = UPLOAD_DIR + "/" + str(theme_object["short_name"]) + "_" + str(theme_object["id"]) + "/"
if not os.path.exists(document_dir):
os.makedirs(document_dir)
document_path = os.path.join(document_dir, file.filename)
# 保存文件
try:
with open(document_path, "wb") as buffer:
buffer.write(await file.read())
except Exception as e:
return {"success": False, "message": f"文件保存失败!{e}"}
# 构造保存文档SQL语句
param = {"stage_id": stage_id, "subject_id": subject_id, "document_name": document_name, "theme_id": theme_id, "document_path": document_path, "document_suffix": document_suffix, "person_id": person_id, "bureau_id": bureau_id}
# 保存数据
if id == 0:
param["train_flag"] = 0
# 插入数据
id = await insert("t_ai_teaching_model_document", param, False)
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
else:
# 更新数据
await update("t_ai_teaching_model_document", param, "id", id)
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
# 【Document-3】获取文档信息
@router.get("/get")
async def get(request: Request):
# 获取参数
id = await get_request_num_param(request, "id", True, True, None)
# 查询数据
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
if document_object is None:
return {"success": False, "message": "未查询到该文档信息!"}
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", document_object["theme_id"])
document_object["theme_info"] = theme_info
stage_map = await get_stage_map()
subject_map = await get_subject_map()
document_object["stage_name"] = stage_map.get(str(document_object["stage_id"]), "未知学段")
document_object["subject_name"] = subject_map.get(str(document_object["subject_id"]), "未知学科")
document_object["person_name"] = await find_person_name_by_id(document_object["person_id"])
return {"success": True, "message": "查询成功!", "data": {"document": document_object}}
# 功能【Document-4】删除文档信息
# 作者Kalman.CHENG ☆
# 时间2025-07-23
# 备注:
@router.post("/delete")
async def delete(request: Request):
# 获取参数
id = await get_request_num_param(request, "id", True, True, None)
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
if document_object is None:
return {"success": False, "message": "未查询到该文档信息!"}
# 删除文件
train_flag = document_object["train_flag"]
if train_flag == 0:
# 未训练的文档,直接删除文件
await update_batch_property("t_ai_teaching_model_document",
{"is_deleted": 1, "train_flag": 6}, {"is_deleted": 0, "id": id}, False)
elif train_flag == 1:
# 正在训练的文档,不能删除
return {"success": False, "message": "该文档正在训练中,不能删除!"}
elif train_flag == 2:
# 训练完成的文档,删除文件
await update_batch_property("t_ai_teaching_model_document",
{"is_deleted": 1, "train_flag": 3}, {"is_deleted": 0, "id": id}, False)
return {"success": True, "message": "删除成功!"}

View File

@@ -0,0 +1,159 @@
# routes/LoginController.py
import base64
import os
import json
import random
import string
import jwt
from captcha.image import ImageCaptcha
from fastapi import APIRouter, Request, Response, status, HTTPException
from Util.CommonUtil import *
from Util.CookieUtil import *
from Util.Database import *
from Util.JwtUtil import *
from Util.ParseRequest import *
from Config.Config import *
# 创建一个路由实例
router = APIRouter()
# 获取项目根目录
project_root = os.path.dirname(os.path.abspath(__file__))
# 配置验证码
image = ImageCaptcha(
width=100, height=30, # 增加宽度和高度
font_sizes=[26], # 增加字体大小
fonts=[os.path.join(project_root, 'DejaVuSans-Bold.ttf')] # 设置自定义字体路径
)
@router.get("/getCaptcha")
def get_captcha():
captcha_text = ''.join(random.choices(string.digits, k=4)) # 生成4个数字的验证码
session_id = os.urandom(16).hex()
# 将验证码文本存储在session中
session_data = {session_id: captcha_text}
with open("./session/captcha_sessions.json", "a") as session_file:
json.dump(session_data, session_file)
session_file.write("\n")
# 生成验证码图片
data = image.generate(captcha_text)
captcha_image_base64 = base64.b64encode(data.read()).decode()
return {"image": captcha_image_base64, "session_id": session_id}
# 验证用户并生成JWT令牌的接口
@router.post("/validateCaptcha")
async def validate_captcha(session_id : str, captcha : str):
try:
with open("./session/captcha_sessions.json", "r") as session_file:
sessions = {}
for line in session_file:
sessions.update(json.loads(line))
except json.JSONDecodeError:
sessions = {}
correct_captcha_text = sessions.get(session_id)
if not correct_captcha_text or correct_captcha_text.lower() != captcha.lower():
return {"success": False, "message": "验证码错误"}
return {"success": True, "message": "验证码正确"}
# 获取cookie中的token的方法
@router.get("/getToken")
async def get_token(request: Request):
token = CookieUtil.get_cookie(request, key="auth_token")
print(token)
if token:
try:
decoded_token = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])
logging.info(f"Token 解码成功: {decoded_token}")
return {"success": True, "message": "Token 验证成功", "token_data": decoded_token}
except jwt.ExpiredSignatureError:
logging.error("Token 过期")
return {"success": False, "message": "Token 过期"}
except jwt.InvalidTokenError:
logging.error("无效的 Token")
return {"success": False, "message": "无效的 Token"}
else:
logging.error("未找到 Token")
return {"success": False, "message": "未找到 Token"}
# 登出
@router.get("/logout")
async def logout(request: Request, response: Response):
token = CookieUtil.get_cookie(request, key="auth_token")
if token:
CookieUtil.remove_cookie(response, key="auth_token", path="/" )
logging.info(f"Token <UNK> cookie: {token}")
return {"success": True, "message": "账号已登出!"}
else:
logging.error("<UNK> Token")
return {"success": True, "message": "未找到有效Token,账号已登出!"}
# 验证用户并生成JWT令牌的接口
@router.post("/login")
async def login(request: Request, response: Response):
username = await get_request_str_param(request, "username", True, True)
password = await get_request_str_param(request, "password", True, True)
if not username or not password:
return {"success": False, "message": "用户名和密码不能为空"}
password = md5_encrypt(password)
select_user_sql: str = "SELECT person_id, person_name, identity_id, login_name, xb, bureau_id, org_id, pwdmd5 FROM t_sys_loginperson WHERE login_name = '" + username + "' AND b_use = 1"
userlist = await find_by_sql(select_user_sql,())
user = userlist[0] if userlist else None
logging.info(f"查询结果: {user}")
if user and user['pwdmd5'] == password: # 验证的cas用户密码md5加密的版本
token = create_access_token({"user_id": user['person_id'], "identity_id": user['identity_id']})
CookieUtil.set_cookie(
res=response,
key="auth_token",
value=token,
secure=False, # 在开发环境中,确保 secure=False
httponly=False,
max_age=3600, # 设置cookie的有效时间为1小时
path="/" # 设置cookie的路径
)
logging.info(f"Token 已成功设置到 cookie: {token}")
user.pop('pwdmd5', None) # 移除密码字段
return {"success": True, "message": "登录成功", "data": {"token": token, "user_data": user}}
else:
return {"success": False, "message": "用户名或密码错误"}
# 【Base-Login-3】通过手机号获取Person的ID
@router.get("/getPersonIdByTelephone")
async def get_person_id_by_telephone(request: Request):
telephone = await get_request_str_param(request, "telephone", True, True)
if not telephone:
return {"success": False, "message": "手机号不能为空"}
select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 "
userlist = await find_by_sql(select_user_sql,())
user = userlist[0] if userlist else None
if user:
return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}}
else:
return {"success": False, "message": "未查询到相关信息"}
# 【Base-Login-4】忘记密码重设不登录的状态
@router.post("/resetPassword")
async def reset_password(request: Request):
person_id = await get_request_str_param(request, "person_id", True, True)
password = await get_request_str_param(request, "password", True, True)
if not person_id or not password:
return {"success": False, "message": "用户ID和新密码不能为空"}
password_md5 = md5_encrypt(password)
update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'"
await execute_sql(update_user_sql, ())
return {"success": True, "message": "密码修改成功"}

View File

@@ -0,0 +1,292 @@
# routes/TeachingModelController.py
import json
import subprocess
import tempfile
import time
import urllib
import uuid
from io import BytesIO
from fastapi import APIRouter, Depends
from sse_starlette import EventSourceResponse
from starlette.responses import StreamingResponse
from Routes.TeachingModel.auth.dependencies import *
from Util.LightRagUtil import *
from Util.PageUtil import *
from Util.ParseRequest import *
from lightrag import *
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
rag_type: str = "file"
# rag_type: str = "pg"
# 【TeachingModel-1】获取主题列表
@router.get("/getTrainedTheme")
async def get_trained_theme(request: Request):
# 获取参数
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
page_number = await get_request_num_param(request, "page_number", False, True, 1)
page_size = await get_request_num_param(request, "page_size", False, True, 10)
# 数据库查询
select_trained_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 AND bureau_id = '{bureau_id}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
print(select_trained_theme_sql)
page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
page = await translate_person_bureau_name(page)
# 结果返回
return {"success": True, "message": "查询成功!", "data": page}
# 【TeachingModel-2】获取热门主题列表
@router.get("/getHotTheme")
async def get_hot_theme(request: Request):
# 获取参数
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
page_number = await get_request_num_param(request, "page_number", False, True, 1)
page_size = await get_request_num_param(request, "page_size", False, True, 3)
# 数据库查询
select_hot_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY quote_count DESC"
print(select_hot_theme_sql)
page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
page = await translate_person_bureau_name(page)
# 结果返回
return {"success": True, "message": "查询成功!", "data": page}
# 【TeachingModel-3】获取最新主题列表
@router.get("/getNewTheme")
async def get_new_theme(request: Request):
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
page_number = await get_request_num_param(request, "page_number", False, True, 1)
page_size = await get_request_num_param(request, "page_size", False, True, 3)
# 数据库查询
select_new_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY create_time DESC"
print(select_new_theme_sql)
page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
page = await translate_person_bureau_name(page)
# 结果返回
return {"success": True, "message": "查询成功!", "data": page}
# 【TeachingModel-4】获取问题列表
@router.get("/getQuestion")
async def get_question(request: Request):
# 获取参数
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
person_id = await get_request_str_param(request, "person_id", True, True)
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
question_type = await get_request_num_param(request, "question_type", True, True, None)
page_number = await get_request_num_param(request, "page_number", False, True, 1)
page_size = await get_request_num_param(request, "page_size", False, True, 10)
person_sql = ""
if question_type == 2:
person_sql = f"AND person_id = '{person_id}'"
# 数据库查询
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = {question_type} {person_sql} ORDER BY quote_count DESC, create_time DESC"
print(select_question_sql)
page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
return {"success": True, "message": "查询成功!", "data": page}
##########################################################################################
# 功能【TeachingModel-5】向RAG提问返回RAG的回复
# 作者Kalman.CHENG ☆
# 时间2025-08-14
# 备注question_type=2时question_id必填其他情况question_id不填
# question_type(0新增问题1常见问题2个人历史问题):
# 0【新增问题】会保存成个人历史问题并保存答案
# 1【常见问题】增加常见问题引用次数会保存成个人历史问题并保存答案
# 2【个人历史问题】对应页面新增的历史问题回显后“重试”操作会重新回答问题并替换最新答案
##########################################################################################
@router.post("/sendQuestion")
async def send_question(request: Request):
# 获取参数
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
person_id = await get_request_str_param(request, "person_id", True, True)
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
question = await get_request_str_param(request, "question", True, True)
question_type = await get_request_num_param(request, "question_type", True, True, None)
question_id = await get_request_num_param(request, "question_id", False, True,0)
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
if theme_object is None:
return {"success": False, "message": "主题不存在!"}
# 处理theme的调用次数
update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
await execute_sql(update_sql, ())
if question_type == 2 and question_id == 0:
return {"success": False, "message": "[question_type]=2时[question_id]不允许为空!"}
if question_type != 2:
param = {}
param["stage_id"] = int(theme_object["stage_id"])
param["subject_id"] = int(theme_object["subject_id"])
param["theme_id"] = theme_id
param["question"] = question
param["question_type"] = 2
param["quote_count"] = 1
param["question_person_id"] = person_id
param["person_id"] = person_id
param["bureau_id"] = bureau_id
question_id = await insert("t_ai_teaching_model_question", param)
if question_type == 1:
# 处理常见问题引用次数
update_common_question_sql: str = f"update t_ai_teaching_model_question set quote_count = quote_count + 1 where question_type = 1 and theme_id = {theme_id} and question = '{question}' and is_deleted = 0"
await execute_sql(update_common_question_sql, ())
# 向rag提问
topic = theme_object["short_name"]
# mode = "hybrid"
prompt = "\n 1、不要输出参考资料 或者 References "
prompt = prompt + "\n 2、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
prompt = prompt + "\n 3、如果资料中提供了图片的一定要严格按照原文提供图片输出绝对不能省略或不输出"
prompt = prompt + "\n 4、知识库中存在的问题严格按知识库中的内容回答不允许扩展"
prompt = prompt + "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到请这样告诉我“暂时无法回答这个问题呢我专注于【" + theme_object["theme_name"] + "】这方面知识,若需这方面帮助,请告知我更准确的信息吧~”"
prompt = prompt + "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
working_path = "./Topic/" + topic
# if rag_type == "file":
async def generate_response_stream(query: str, mode: str, user_prompt: str):
try:
result = ""
rag = await initialize_rag(working_path)
resp = await rag.aquery(
query=query,
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
async for chunk in resp:
if not chunk:
continue
yield f"data: {json.dumps({'reply': chunk})}\n\n"
result += chunk
print(chunk, end='', flush=True)
except Exception as e:
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
# 保存答案
update_question_sql = f"update t_ai_teaching_model_question set question_answer = '{result}', update_time = now() where id = {question_id}"
await execute_sql(update_question_sql, ())
# 清理资源
await rag.finalize_storages()
return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
# elif rag_type == "pg":
# workspace = theme_object["short_name"]
# # 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
# working_dir = 'WorkingPath/' + workspace
# if not os.path.exists(working_dir):
# os.makedirs(working_dir)
# async def generate_response_stream(query: str, mode: str, user_prompt: str):
# try:
# logger.info("workspace=" + workspace)
# rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
#
# resp = await rag.aquery(
# query=query,
# param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
# async for chunk in resp:
# if not chunk:
# continue
# yield f"data: {json.dumps({'reply': chunk})}\n\n"
# print(chunk, end='', flush=True)
# except Exception as e:
# yield f"data: {json.dumps({'error': str(e)})}\n\n"
# finally:
# # 发送流结束标记
# yield "data: [DONE]\n\n"
# # 清理资源
# await rag.finalize_storages()
# return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
@router.post("/saveWord")
async def save_word(request: Request):
# 获取参数
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
markdown_content = await get_request_str_param(request, "markdown_content", True, True)
question = await get_request_str_param(request, "question", True, True)
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
if theme_object is None:
return {"success": False, "message": "主题不存在!"}
filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
output_file = None
try:
# markdown_content 替换换行符
# markdown_content = markdown_content.replace("\n", "")
# 创建临时Markdown文件
temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
with open(temp_md, "w", encoding="utf-8") as f:
f.write(markdown_content)
# 使用pandoc转换
output_file = os.path.join(tempfile.gettempdir(), filename)
subprocess.run(['pandoc', '--reference-doc', 'static/template/templates.docx', '-s', temp_md, '-o', output_file, '--resource-path=static'], check=True)
# 读取生成的Word文件
with open(output_file, "rb") as f:
stream = BytesIO(f.read())
# 返回响应
encoded_filename = urllib.parse.quote(filename)
return StreamingResponse(
stream,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Connection": "close",
"Content-Type": "application/vnd.ms-excel;charset=UTF-8",
"Content-Disposition": f"attachment; filename={encoded_filename}"}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
...
# 清理临时文件
try:
if temp_md and os.path.exists(temp_md):
os.remove(temp_md)
if output_file and os.path.exists(output_file):
os.remove(output_file)
except Exception as e:
logger.warning(f"Failed to clean up temp files: {str(e)}")
########################################
# 【TeachingModel-7】保存个人历史问题答案
# 作者Kalman.CHENG ☆
# 时间2025-08-14
# 备注:
########################################
@router.post("/saveAnswer")
async def save_answer(request: Request):
# 获取参数
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
person_id = await get_request_str_param(request, "person_id", True, True)
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
question = await get_request_str_param(request, "question", True, True)
answer = await get_request_str_param(request, "answer", True, True)
# 验证是否存在个人历史问题
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = 2 AND question = '{question}' AND person_id = '{person_id}' order by create_time desc limit 1"
print(select_question_sql)
select_question_result = await find_by_sql(select_question_sql, ())
if select_question_result is None: # 个人历史问题不存在
return {"success": False, "message": "个人历史问题不存在!"}
question_id = select_question_result[0].get("id")
# 保存答案
update_answer_sql: str = f"UPDATE t_ai_teaching_model_question SET question_answer = '{answer}', update_time = now() WHERE id = {question_id}"
print(update_answer_sql)
await execute_sql(update_answer_sql, ())
# 结果返回
return {"success": True, "message": "保存成功!"}

View File

@@ -0,0 +1,23 @@
# routes/TestController.py
from fastapi import APIRouter, Request
from utils.ParseRequest import *
# 创建一个路由实例
router = APIRouter()
@router.get("/parse", response_model=dict)
async def parse_query(request: Request):
request_data = await parse_request_data(request)
return request_data
@router.post("/parse", response_model=dict)
async def parse_form(request: Request):
request_data = await parse_request_data(request)
return request_data
@router.post("/parse_json", response_model=dict)
async def parse_json(request: Request):
request_data = await parse_request_data(request)
return request_data

View File

@@ -0,0 +1,175 @@
# routes/ThemeController.py
from fastapi import APIRouter, Depends
from Util.Database import *
from Util.ParseRequest import *
from Routes.TeachingModel.auth.dependencies import *
from Util.PageUtil import *
from Util.TranslateUtil import *
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
# 功能【Theme-1】主题管理列表
# 作者Kalman.CHENG ☆
# 时间2025-07-14
# 备注:
@router.get("/list")
async def list(request: Request):
# 获取参数
person_id = await get_request_str_param(request, "person_id", True, True)
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
page_number = await get_request_num_param(request, "page_number", False, True,1)
page_size = await get_request_num_param(request, "page_size", False, True, 10)
theme_name = await get_request_str_param(request, "theme_name", False, True)
print(stage_id, person_id, subject_id, page_number, page_size, theme_name)
# 拼接查询SQL语句
select_theme_sql: str = " SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and person_id = '" + person_id + "'"
if stage_id != -1:
select_theme_sql += " and stage_id = " + str(stage_id)
if subject_id != -1:
select_theme_sql += " and subject_id = " + str(subject_id)
if theme_name != "":
select_theme_sql += " and theme_name like '%" + theme_name + "%'"
select_theme_sql += " ORDER BY create_time DESC"
# 查询主题列表
page = await get_page_data_by_sql(select_theme_sql, page_number, page_size)
person_ids = ""
for item in page["list"]:
person_ids += "'" + item["person_id"] + "',"
if person_ids != "":
person_ids = person_ids[:-1]
else:
person_ids = "''"
person_map = await get_person_map(person_ids)
stage_map = await get_stage_map()
subject_map = await get_subject_map()
for item in page["list"]:
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
return {"success": True, "message": "查询成功!", "data": page}
# 功能【Theme-2】保存主题管理
# 作者Kalman.CHENG ☆
# 时间2025-07-14
# 备注:
@router.post("/save")
async def save(request: Request):
# 获取参数
id = await get_request_num_param(request, "id", False, True, 0)
theme_name = await get_request_str_param(request, "theme_name", True, True)
if len(theme_name) > 50:
return {"success": False, "message": "主题名称不能超过50个字符"}
short_name = await get_request_str_param(request, "short_name", True, True)
if len(short_name) > 50:
return {"success": False, "message": "主题英文简称不能超过50个字符"}
theme_icon = await get_request_str_param(request, "theme_icon", False, True)
if len(theme_name) > 200:
return {"success": False, "message": "主题图标不能超过200个字符"}
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
person_id = await get_request_str_param(request, "person_id", True, True)
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
# 校验参数
check_theme_sql = "SELECT theme_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and theme_name = '" + theme_name + "'"
if id != 0:
check_theme_sql += " and id <> " + str(id)
print(check_theme_sql)
check_theme_result = await find_by_sql(check_theme_sql,())
if check_theme_result:
return {"success": False, "message": "该主题名称已存在!"}
if short_name.length > 50:
return {"success": False, "message": "主题英文简称不能超过50个字符"}
check_short_name_sql = "SELECT short_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and short_name = '" + short_name + "'"
if id != 0:
check_short_name_sql += " and id <> " + str(id)
print(check_short_name_sql)
check_short_name_result = await find_by_sql(check_short_name_sql,())
if check_short_name_result:
return {"success": False, "message": "该主题英文简称已存在!"}
# 组装参数
param = {"theme_name": theme_name,"short_name": short_name,"theme_icon": theme_icon,"stage_id": stage_id,"subject_id": subject_id,"person_id": person_id,"bureau_id": bureau_id}
# 保存数据
if id == 0:
param["search_flag"] = 0
param["train_flag"] = 0
param["quote_count"] = 0
# 插入数据
id = await insert("t_ai_teaching_model_theme", param, False)
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
else:
# 更新数据
await update("t_ai_teaching_model_theme", param, "id", id, False)
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
# 功能【Theme-3】获取主题信息
# 作者Kalman.CHENG ☆
# 时间2025-07-14
# 备注:
@router.get("/get")
async def get(request: Request):
# 获取参数
id = await get_request_num_param(request, "id", True, True, None)
theme_obj = await find_by_id("t_ai_teaching_model_theme", "id", id)
if theme_obj is None:
return {"success": False, "message": "未查询到该主题信息!"}
stage_map = await get_stage_map()
subject_map = await get_subject_map()
theme_obj["stage_name"] = stage_map.get(str(theme_obj["stage_id"]), "未知学段")
theme_obj["subject_name"] = subject_map.get(str(theme_obj["subject_id"]), "未知学科")
theme_obj["person_name"] = await find_person_name_by_id(theme_obj["person_id"])
return {"success": True, "message": "查询成功!", "data": {"theme": theme_obj}}
@router.post("/delete")
async def delete(request: Request):
# 获取参数
id = await get_request_num_param(request, "id", True, True, None)
result = await delete_by_id("t_ai_teaching_model_theme", "id", id)
if not result:
return {"success": False, "message": "删除失败!"}
return {"success": True, "message": "删除成功!"}
# 功能【Theme-5】根据学段学科获取主题列表
# 作者Kalman.CHENG ☆
# 时间2025-07-31
# 备注:
@router.get("/getListByStageSubject")
async def get_list_by_stage_subject(request: Request):
# 获取参数
person_id = await get_request_str_param(request, "person_id", True, True)
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
# 拼接查询SQL语句
select_theme_sql: str = " select id as theme_id, theme_name from t_ai_teaching_model_theme where is_deleted = 0 and person_id = '" + person_id + "'"
if stage_id != -1:
select_theme_sql += " and stage_id = " + str(stage_id)
if subject_id != -1:
select_theme_sql += " and subject_id = " + str(subject_id)
select_theme_result = await find_by_sql(select_theme_sql,())
return {"success": True, "message": "查询成功!", "data": {"theme_list": select_theme_result}}

View File

@@ -0,0 +1,50 @@
# routes/UserController.py
import re
from fastapi import APIRouter, Request, Response, Depends
from Routes.TeachingModel.auth.dependencies import *
from Util.CommonUtil import md5_encrypt
from Util.Database import *
from Util.ParseRequest import *
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
# 【Base-User-1】维护用户手机号
@router.post("/modifyTelephone")
async def modify_telephone(request: Request):
person_id = await get_request_str_param(request, "person_id", True, True)
telephone = await get_request_str_param(request, "telephone", True, True)
# 校验手机号码格式
if not re.match(r"^1[3-9]\d{9}$", telephone):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误")
# 校验手机号码是否已被注册
select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'"
userlist = await find_by_sql(select_telephone_sql, ())
if userlist is not None:
return {"success": False, "message": "手机号码已被注册"}
else:
update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'"
await execute_sql(update_telephone_sql, ())
return {"success": True, "message": "修改成功"}
# 【Base-User-2】维护用户密码
@router.post("/modifyPassword")
async def modify_password(request: Request):
person_id = await get_request_str_param(request, "person_id", True, True)
old_password = await get_request_str_param(request, "old_password", True, True)
password = await get_request_str_param(request, "password", True, True)
# 校验旧密码是否正确
select_password_sql: str = "select pwdmd5 from t_sys_loginperson where person_id = '" + person_id + "' and b_use = 1"
userlist = await find_by_sql(select_password_sql, ())
if len(userlist) == 0:
return {"success": False, "message": "用户不存在"}
else:
if userlist[0]["pwdmd5"] != md5_encrypt(old_password):
return {"success": False, "message": "旧密码错误"}
else:
update_password_sql: str = "update t_sys_loginperson set original_pwd = '" + password + "',pwdmd5 = '" + md5_encrypt(password) + "' where person_id = '" + person_id + "'"
await execute_sql(update_password_sql, ())
return {"success": True, "message": "修改成功"}

View File

@@ -0,0 +1,29 @@
# dependencies.py
from fastapi import Request, HTTPException, status, Header
from Util.JwtUtil import *
async def get_current_user(token: str = Header(None, description="Authorization token")):
if token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供令牌,请登录后使用!",
headers={"WWW-Authenticate": "Bearer"},
)
token = token.split(" ")[1] if token.startswith("Bearer ") else token
payload = verify_token(token)
if payload is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据,请登录后使用!",
headers={"WWW-Authenticate": "Bearer"},
)
user_id = payload.get("user_id")
identity_id = payload.get("identity_id")
if user_id is None or identity_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证用户,请重新登录后使用!",
headers={"WWW-Authenticate": "Bearer"},
)
return user_id

View File

@@ -0,0 +1,102 @@
import asyncio
import datetime
import logging
import os
from Util.Database import find_by_sql, find_by_id, execute_sql
from Util.DocxUtil import get_docx_content_by_pandoc
from Util.LightRagUtil import initialize_rag
# 更详细地控制日志输出
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 后台任务,监控是否有新的未训练的文档进行训练
async def train_document_task():
print(datetime.datetime.now(), "线程5秒后开始运行【监控是否有需要处理的文档】")
await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep
# 这里放置你的线程逻辑
while True:
print("测试定时任务运行")
handle_flag = False
if handle_flag:
# 这里可以放置你的线程要执行的代码
logging.info("开始查询是否有待处理文档:")
no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE train_flag in (0,3) ORDER BY create_time DESC"
no_train_document_result = await find_by_sql(no_train_document_sql, ())
logger.info(no_train_document_result)
if not no_train_document_result:
print(datetime.datetime.now(), "没有需要处理的文档")
else:
print(datetime.datetime.now(), "存在未训练的文档" + str(len(no_train_document_result))+"")
# 这里可以根据train_flag的值来判断是训练还是删除
document = no_train_document_result[0]
theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"])
document_name = document["document_name"] + "." + document["document_suffix"]
working_dir = "Topic/" + theme["short_name"]
document_path = document["document_path"]
if document["train_flag"] == 0:
# 训练文档
update_sql_document: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"])
await execute_sql(update_sql_document, ())
update_sql_theme: str = " UPDATE t_ai_teaching_model_theme SET train_flag = 1 WHERE id = " + str(theme["id"])
await execute_sql(update_sql_theme, ())
logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
# 训练代码开始
# content = get_docx_content_by_pandoc(document_path)
try:
# 注意默认设置使用NetworkX
rag = await initialize_rag(working_dir)
# 获取docx文件的内容
content = get_docx_content_by_pandoc(document_path)
await rag.ainsert(content, ids=[document_name], file_paths=[document_name])
logger.info(f"Inserted content from {document_name}")
except Exception as e:
logger.error(f"An error occurred: {e}")
finally:
await rag.finalize_storages()
# 训练结束,更新训练状态
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"])
await execute_sql(update_document_sql, ())
elif document["train_flag"] == 3:
update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"])
await execute_sql(update_sql, ())
logging.info(f"开始删除文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
# 删除文档开始
try:
# 注意默认设置使用NetworkX
rag = await initialize_rag(working_dir)
await rag.adelete_by_doc_id(doc_id = document_name)
logger.info(f"Deleted content from {document_name}")
except Exception as e:
logger.error(f"An error occurred: {e}")
finally:
await rag.finalize_storages()
# 删除结束,更新训练状态
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 5 WHERE id = " + str(document["id"])
await execute_sql(update_document_sql, ())
# 整体更新主题状态
select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2) group by train_flag"
select_document_result = await find_by_sql(select_document_sql, ())
train_document_count_map = {}
for item in select_document_result:
train_document_count_map[str(item["train_flag"])] = int(item["train_count"])
train_document_count_1 = train_document_count_map.get("1", 0)
train_document_count_2 = train_document_count_map.get("2", 0)
if train_document_count_2 > 0:
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 1, train_flag = 2 WHERE id = {theme['id']}"
await execute_sql(update_theme_sql, ())
else:
if train_document_count_1 > 0:
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 1 WHERE id = {theme['id']}"
await execute_sql(update_theme_sql, ())
else:
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 0 WHERE id = {theme['id']}"
await execute_sql(update_theme_sql, ())
# 添加适当的等待时间,避免频繁查询
await asyncio.sleep(20) # 开发阶段每20秒一次
# await asyncio.sleep(600) # 每十分钟查询一次

View File

@@ -1,6 +1,9 @@
import uvicorn
import asyncio
from fastapi import FastAPI
from starlette.staticfiles import StaticFiles
from Routes.TeachingModel.tasks.BackgroundTasks import train_document_task
from Util.PostgreSQLUtil import init_postgres_pool, close_postgres_pool
from Routes.Ggb import router as ggb_router
@@ -9,6 +12,13 @@ from Routes.Oss import router as oss_router
from Routes.Rag import router as rag_router
from Routes.ZuoWen import router as zuowen_router
from Routes.Llm import router as llm_router
from Routes.TeachingModel.api.LoginController import router as login_router
from Routes.TeachingModel.api.UserController import router as user_router
from Routes.TeachingModel.api.DmController import router as dm_router
from Routes.TeachingModel.api.ThemeController import router as theme_router
from Routes.TeachingModel.api.DocumentController import router as document_router
from Routes.TeachingModel.api.TeachingModelController import router as teaching_model_router
from Util.LightRagUtil import *
from contextlib import asynccontextmanager
@@ -25,6 +35,8 @@ async def lifespan(_: FastAPI):
pool = await init_postgres_pool()
app.state.pool = pool
asyncio.create_task(train_document_task())
try:
yield
finally:
@@ -44,5 +56,19 @@ app.include_router(oss_router) # 阿里云OSS路由
app.include_router(llm_router) # 大模型路由
# Teaching Model 相关路由
# 登录相关(不用登录)
app.include_router(login_router, prefix="/api/login", tags=["login"])
# 用户相关
app.include_router(user_router, prefix="/api/user", tags=["user"])
# 字典相关(Dm)
app.include_router(dm_router, prefix="/api/dm", tags=["dm"])
# 主题相关
app.include_router(theme_router, prefix="/api/theme", tags=["theme"])
# 文档相关
app.include_router(document_router, prefix="/api/document", tags=["document"])
# 问题相关(大模型应用)
app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["teacher_model"])
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8100)

View File

@@ -0,0 +1,17 @@
import hashlib
import logging
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def md5_encrypt(text):
# 创建一个md5哈希对象
md5_hash = hashlib.md5()
# 更新哈希对象的数据,注意这里需要将字符串转换为字节类型
md5_hash.update(text.encode('utf-8'))
# 获取十六进制表示的哈希值
encrypted_text = md5_hash.hexdigest()
return encrypted_text

View File

@@ -0,0 +1,81 @@
from fastapi.responses import Response
from fastapi.requests import Request
import logging
class CookieUtil:
@staticmethod
def set_cookie(res: Response, key: str, value: str, max_age: int = 3600, path: str = "/", secure: bool = False,
httponly: bool = True):
"""
设置cookie
:param res: FastAPI的Response对象
:param key: cookie的键
:param value: cookie的值
:param max_age: cookie的有效时间默认1小时3600秒
:param path: cookie的路径默认为根路径("/")
:param secure: 是否仅通过HTTPS传输默认False
:param httponly: 是否仅通过HTTP请求访问默认True
"""
res.set_cookie(
key=key,
value=value,
httponly=httponly,
secure=secure,
max_age=max_age,
path=path
)
logging.info(f"设置 cookie: {key}={value}")
@staticmethod
def get_cookie(req: Request, key: str) -> str:
"""
从请求中获取cookie的值
:param req: FastAPI的Request对象
:param key: cookie的键
:return: cookie的值如果未找到则返回None
"""
token_value = req.cookies.get(key)
logging.info(f"从 cookie 中获取的 {key}: {token_value}")
return token_value
@staticmethod
def remove_cookie(res: Response, key: str, path: str = "/"):
"""
移除cookie
:param res: FastAPI的Response对象
:param key: cookie的键
:param path: cookie的路径默认为根路径("/")
"""
res.delete_cookie(key, path=path)
logging.info(f"移除 cookie: {key}")
# 示例使用
if __name__ == "__main__":
# 创建一个示例响应对象
response = Response(content="示例响应")
# 设置一个cookie
CookieUtil.set_cookie(response, key="auth_token", value="your_token_value")
# 创建一个示例请求对象通常从FastAPI请求中获取
request = Request(scope={
"type": "http",
"headers": [
(b"cookie", b"auth_token=your_token_value")
]
})
# 获取cookie
token = CookieUtil.get_cookie(request, key="auth_token")
print(token)
# 移除cookie
CookieUtil.remove_cookie(response, key="auth_token")
# 打印示例响应头
print(response.headers)

278
dsLightRag/Util/Database.py Normal file
View File

@@ -0,0 +1,278 @@
# Database.py
import datetime
import logging
from Util.PostgreSQLUtil import init_postgres_pool
# 根据sql语句查询数据
async def find_by_sql(sql: str, params: tuple):
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
result = await conn.fetch(sql, *params)
# 将 asyncpg.Record 转换为字典
result_dict = [dict(record) for record in result]
if result_dict:
return result_dict
else:
return None
except Exception as e:
logging.error(f"数据库查询错误: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()
# 插入数据
async def insert(tableName, param, onlyForParam=False):
current_time = datetime.datetime.now()
columns = []
values = []
placeholders = []
for key, value in param.items():
if value is not None:
if isinstance(value, (int, float)):
columns.append(key)
values.append(value)
placeholders.append(f"${len(values)}")
elif isinstance(value, str):
columns.append(key)
values.append(value)
placeholders.append(f"${len(values)}")
else:
columns.append(key)
values.append(None)
placeholders.append("NULL")
if not onlyForParam:
if 'is_deleted' not in param:
columns.append("is_deleted")
values.append(0)
placeholders.append(f"${len(values)}")
if 'create_time' not in param:
columns.append("create_time")
values.append(current_time)
placeholders.append(f"${len(values)}")
if 'update_time' not in param:
columns.append("update_time")
values.append(current_time)
placeholders.append(f"${len(values)}")
# 构造 SQL 语句
column_names = ", ".join(columns)
placeholder_names = ", ".join(placeholders)
sql = f"INSERT INTO {tableName} ({column_names}) VALUES ({placeholder_names}) RETURNING id"
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
result = await conn.fetchrow(sql, *values)
if result:
return result['id']
else:
logging.error("插入数据失败: 未返回ID")
return None
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
logging.error(f"参数: {values}")
# raise Exception(f"为表[{tableName}]插入数据失败: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()
# 更新数据
async def update(table_name, param, property_name, property_value, only_for_param=False):
current_time = datetime.datetime.now()
set_clauses = []
values = []
# 处理要更新的参数
for key, value in param.items():
if value is not None:
if isinstance(value, (int, float)):
set_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
elif isinstance(value, str):
set_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
else:
set_clauses.append(f"{key} = NULL")
values.append(None)
if not only_for_param:
if 'update_time' not in param:
set_clauses.append(f"update_time = ${len(values) + 1}")
values.append(current_time)
# 构造 SQL 语句
set_clause = ", ".join(set_clauses)
sql = f"UPDATE {table_name} SET {set_clause} WHERE {property_name} = ${len(values) + 1} RETURNING id"
print(sql)
# 添加条件参数
values.append(property_value)
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
result = await conn.fetchrow(sql, *values)
if result:
return result['id']
else:
logging.error("更新数据失败: 未返回ID")
return None
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
logging.error(f"参数: {values}")
# raise Exception(f"为表[{table_name}]更新数据失败: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()
# 获取Bean
# 通过主键查询
async def find_by_id(table_name, property_name, property_value):
if table_name and property_name and property_value is not None:
# 构造 SQL 语句
sql = f"SELECT * FROM {table_name} WHERE is_deleted = 0 AND {property_name} = $1"
logging.debug(sql)
# 执行查询
result = await find_by_sql(sql, (property_value,))
if not result:
logging.info(f"未查询到[{property_name}]为{property_value}的有效数据!")
return None
# 返回第一条数据
return result[0]
else:
logging.error("参数不全")
return None
# 通过主键删除
# 逻辑删除
async def delete_by_id(table_name, property_name, property_value):
if table_name and property_name and property_value is not None:
sql = f"UPDATE {table_name} SET is_deleted = 1, update_time = now() WHERE {property_name} = $1 and is_deleted = 0"
logging.debug(sql)
# 执行删除
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
result = await conn.execute(sql, property_value)
if result:
return True
else:
logging.error("删除失败: 未找到数据")
return False
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
logging.error(f"参数: {property_value}")
# raise Exception(f"为表[{table_name}]删除数据失败: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()
else:
logging.error("参数不全")
return False
# 执行一个SQL语句
async def execute_sql(sql, params):
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
await conn.fetch(sql, *params)
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
# raise Exception(f"执行SQL失败: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()
async def find_person_name_by_id(person_id):
sql = f"SELECT person_name FROM t_sys_loginperson WHERE person_id = $1 and b_use = 1"
logging.debug(sql)
person_list = await find_by_sql(sql, (person_id,))
if person_list:
return person_list[0]['person_name']
else:
return None
async def find_bureau_name_by_id(org_id):
sql = f"SELECT org_name FROM t_base_organization WHERE org_id = $1 and b_use = 1"
logging.debug(sql)
org_list = await find_by_sql(sql, (org_id,))
if org_list:
return org_list[0]['org_name']
else:
return None
async def update_batch_property(table_name, update_param, property_param, only_for_param=False):
current_time = datetime.datetime.now()
set_clauses = []
values = []
# 处理要更新的参数
for key, value in update_param.items():
if value is not None:
if isinstance(value, (int, float)):
set_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
elif isinstance(value, str):
set_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
else:
set_clauses.append(f"{key} = NULL")
values.append(None)
if not only_for_param:
if 'update_time' not in update_param:
set_clauses.append(f"update_time = ${len(values) + 1}")
values.append(current_time)
# 构造 WHERE 子句
property_clauses = []
for key, value in property_param.items():
if value is not None:
if isinstance(value, (int, float)):
property_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
elif isinstance(value, str):
property_clauses.append(f"{key} = ${len(values) + 1}")
values.append(value)
else:
property_clauses.append(f"{key} IS NULL")
values.append(None)
# 构造 SQL 语句
set_clause = ", ".join(set_clauses)
property_clause = " AND ".join(property_clauses)
sql = f"UPDATE {table_name} SET {set_clause} WHERE {property_clause}"
logging.debug(sql)
pg_pool = await init_postgres_pool()
try:
async with pg_pool.acquire() as conn:
result = await conn.execute(sql, *values)
affected_rows = conn.rowcount
return affected_rows
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
logging.error(f"参数: {values}")
# raise Exception(f"为表[{table_name}]批量更新数据失败: {e}")
finally:
if pg_pool is not None:
await pg_pool.close()

View File

@@ -6,6 +6,7 @@ import uuid
from PIL import Image
import os
from networkx.algorithms.bipartite.centrality import betweenness_centrality
# 或者如果你想更详细地控制日志输出
logger = logging.getLogger('DocxUtil')
@@ -114,8 +115,20 @@ def get_docx_content_by_pandoc(docx_file):
pos = line.find(")")
q = line[:pos + 1]
q = q.replace("./static", ".")
# q = q[4:-1]
# q='<img src="'+q+'" alt="我是图片">'
# Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/
left_idx = line.find("(")
static_idx = line.find("static")
if left_idx == -1 or static_idx == -1 or left_idx > static_idx:
print("路径中不包含(+~+static的已知格式")
else:
between_content = q[left_idx+1:static_idx].strip()
if between_content:
q = q[:left_idx+1] + '\\' + q[static_idx:]
else:
q = q[:static_idx] + '\\' + q[static_idx:]
print(f"q3{q}")
#q = q[4:-1]
#q='<img src="'+q+'" alt="我是图片">'
img_idx += 1
content += q + "\n"
else:
@@ -126,4 +139,4 @@ def get_docx_content_by_pandoc(docx_file):
f.write(content)
# 删除临时文件 output_file
# os.remove(temp_markdown)
return content.replace("\n\n", "\n").replace("\\", "")
return content.replace("\n\n", "\n").replace("\\", "/")

View File

@@ -0,0 +1,21 @@
# jwt_utils.py
from datetime import datetime, timedelta
from jose import JWTError, jwt
from Config.Config import *
def create_access_token(data: dict):
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def verify_token(token: str):
try:
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
return None

View File

@@ -0,0 +1,68 @@
import math
from Util.Database import *
# 查询数据条数
async def get_total_data_count(total_data_sql):
total_data_count = 0
total_data_count_sql = "select count(*) as num from (" + total_data_sql + ") as temp_table"
result = await find_by_sql(total_data_count_sql,())
row = result[0] if result else None
if row:
total_data_count = row.get("num")
return total_data_count
def get_page_by_total_row(total_data_count, page_number, page_size):
total_page = (page_size != 0) and math.floor((total_data_count + page_size - 1) / page_size) or 0
if page_number <= 0:
page_number = 1
if 0 < total_page < page_number:
page_number = total_page
offset = page_size * page_number - page_size
limit = page_size
return total_data_count, total_page, offset, limit
async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int):
total_row: int = 0
total_page: int = 0
total_data_sql = total_data_sql.replace(";", "")
total_data_sql = total_data_sql.replace(" FROM ", " from ")
# 查询总数
total_data_count = await get_total_data_count(total_data_sql)
if total_data_count == 0:
return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}
else:
total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, page_number, page_size)
# 构造执行分页查询的sql语句
page_data_sql = total_data_sql + " LIMIT %d offset %d " % (limit, offset)
print(page_data_sql)
# 执行分页查询
page_data = await find_by_sql(page_data_sql, ())
if page_data:
return {"page_number": page_number, "page_size": page_size, "total_row": total_row, "total_page": total_page, "list": page_data}
else:
return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}
# 翻译person_name, bureau_id
async def translate_person_bureau_name(page):
for item in page["list"]:
if item["person_id"] is not None:
person_id = str(item["person_id"])
person_name = await find_person_name_by_id(person_id)
if person_name:
item["person_name"] = person_name
else:
item["person_name"] = ""
if item["bureau_id"] is not None:
bureau_id = str(item["bureau_id"])
bureau_name = await find_bureau_name_by_id(bureau_id)
if bureau_name:
item["bureau_name"] = bureau_name
else:
item["bureau_name"] = ""
return page

View File

@@ -0,0 +1,102 @@
import datetime
from fastapi import HTTPException, Request, status
async def parse_request_data(request: Request):
data = {
"headers": dict(request.headers),
"params": {},
"cookies": dict(request.cookies),
"time": datetime.datetime.utcnow(),
"ip": request.client.host
}
request_method = request.method
if request_method == "GET":
query_params = request.query_params
for key, value in query_params.items():
parse_args({key: value}, data)
elif request_method == "POST":
content_type = request.headers.get("content-type", "").lower()
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
form_data = await request.form()
for key, value in form_data.items():
parse_args({key: value}, data)
elif "application/json" in content_type:
json_data = await request.json()
for key, value in json_data.items():
parse_args({key: value}, data)
else:
raise HTTPException(status_code=400, detail="Unsupported content type")
return data['params']
def parse_args(args, data):
if args:
for key, value in args.items():
data['params'][key] = value
# 获取请求参数中的字符串参数
# param_name --> 参数名
# nonempty --> 是否必填
# trim --> 是否去除两端空格
# 返回参数值
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
request_data = await parse_request_data(request)
if request_data is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请求数据格式不正确",
)
value = str(request_data.get(param_name)) if request_data.get(param_name) is not None else ""
if trim and value != "":
value = value.strip()
if nonempty and value == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="[" + param_name + "]不允许为空!",
)
return value
# 获取请求参数中的数字参数
# param_name --> 参数名
# nonempty --> 是否必填
# trim --> 是否去除两端空格
# 返回参数值
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
value = await get_request_str_param(request, param_name, nonempty, trim)
if nonempty:
return await str2num(param_name, value)
else:
if value == "":
return default_value
return await str2num(param_name, value)
# 字符串转数字, 判断字符串是否包含小数点转float or 转int
# param_name --> 参数名
# value --> 字符串值
# 返回数字值
# 若字符串值不是数字则抛出HTTPException
async def str2num(param_name: str, value: str):
if value.find(".") != -1:
try:
return float(value)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="[" + param_name + "]必须为数字!",
)
else:
try:
return int(value)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="[" + param_name + "]必须为数字!",
)

View File

@@ -0,0 +1,37 @@
from Util.Database import find_by_sql
async def get_stage_map():
select_stage_sql: str = "select * from t_dm_stage where b_use = 1"
select_stage_result = await find_by_sql(select_stage_sql, ())
stage_map = {}
for stage in select_stage_result:
stage_map[str(stage["stage_id"])] = stage["stage_name"]
return stage_map
async def get_stage_map_by_id(stage_id: int):
select_stage_sql: str = f"select stage_id, stage_name from t_dm_stage where b_use = 1 and stage_id = {stage_id}"
select_stage_result = await find_by_sql(select_stage_sql, ())
if select_stage_result is not None:
return select_stage_result[0]["stage_name"]
else:
return "未知学段"
async def get_subject_map():
select_subject_sql: str = "select * from t_dm_subject"
select_subject_result = await find_by_sql(select_subject_sql, ())
subject_map = {}
for subject in select_subject_result:
subject_map[str(subject["subject_id"])] = subject["subject_name"]
return subject_map
async def get_person_map(person_ids: str):
person_id_list = person_ids.split(",")
person_ids = ",".join(person_id_list)
select_person_sql: str = f"select person_id, person_name from t_sys_loginperson where person_id in ({person_ids}) and b_use = 1"
select_person_result = await find_by_sql(select_person_sql, ())
person_map = {}
if select_person_result is not None:
for person in select_person_result:
person_map[str(person["person_id"])] = person["person_name"]
return person_map

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.