整合 dsAiTeachingModel 接口
This commit is contained in:
@@ -61,3 +61,8 @@ ZHIPU_API_KEY = "78dc1dfe37e04f29bd4ca9a49858a969.gn7TIZTfzpY35nx9"
|
||||
|
||||
# GPTNB的API KEY
|
||||
GPTNB_API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662"
|
||||
|
||||
# JWT配置信息
|
||||
JWT_SECRET_KEY = "ZXZnZWVr5b+r5LmQ5L2g55qE5Ye66KGM"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 300000 # 访问令牌过期时间(分钟)
|
||||
|
46
dsLightRag/Routes/TeachingModel/api/DmController.py
Normal file
46
dsLightRag/Routes/TeachingModel/api/DmController.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# routes/DmController.py
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
from Util.Database import *
|
||||
from Routes.TeachingModel.auth.dependencies import get_current_user
|
||||
from Util.ParseRequest import get_request_num_param
|
||||
from Util.TranslateUtil import get_stage_map_by_id
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 学段、学科级联数据
|
||||
@router.get("/getStageSubjectList")
|
||||
async def get_stage_subject_list():
|
||||
# 先查询学段list
|
||||
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
|
||||
stage_list = await find_by_sql(select_stage_sql,())
|
||||
for stage in stage_list:
|
||||
# 再查询学科list
|
||||
select_subject_sql: str = "select subject_id, subject_name, icon from t_dm_subject where stage_id = " + str(stage["stage_id"]) + " order by sort_id;"
|
||||
subject_list = await find_by_sql(select_subject_sql,())
|
||||
stage["subject_list"] = subject_list
|
||||
|
||||
return {"success": True, "message": "成功!", "data": stage_list}
|
||||
|
||||
# 获取学段信息
|
||||
@router.get("/getStageList")
|
||||
async def get_stage_list():
|
||||
# 先查询学段list
|
||||
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
|
||||
stage_list = await find_by_sql(select_stage_sql,())
|
||||
return {"success": True, "message": "查询成功!", "data": {"stage_list": stage_list}}
|
||||
|
||||
|
||||
# 根据学段ID获取学科列表
|
||||
@router.get("/getSubjectList")
|
||||
async def get_subject_list(request: Request):
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
stage_name = await get_stage_map_by_id(stage_id)
|
||||
# 先查询学科list
|
||||
select_subject_sql: str = f"select subject_id, subject_name, icon, {stage_id} as stage_id, '{stage_name}' as stage_name from t_dm_subject where stage_id = {stage_id} order by sort_id;"
|
||||
print(select_subject_sql)
|
||||
subject_list = await find_by_sql(select_subject_sql,())
|
||||
return {"success": True, "message": "查询成功!", "data": {"subject_list": subject_list}}
|
178
dsLightRag/Routes/TeachingModel/api/DocumentController.py
Normal file
178
dsLightRag/Routes/TeachingModel/api/DocumentController.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# routes/DocumentController.py
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Request, Response, Depends, UploadFile, File
|
||||
|
||||
from Routes.TeachingModel.auth.dependencies import get_current_user
|
||||
from Util.PageUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from Util.TranslateUtil import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 创建上传文件的目录
|
||||
UPLOAD_DIR = "upload_file"
|
||||
if not os.path.exists(UPLOAD_DIR):
|
||||
os.makedirs(UPLOAD_DIR)
|
||||
|
||||
# 合法文件扩展名
|
||||
supported_suffix_types = ['doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx']
|
||||
|
||||
# 【Document-1】文档管理列表
|
||||
@router.get("/list")
|
||||
async def list(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
theme_id = await get_request_num_param(request, "theme_id", False, True, -1)
|
||||
document_suffix = await get_request_str_param(request, "document_suffix", False, True)
|
||||
document_name = await get_request_str_param(request, "document_name", False, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
print(person_id, stage_id, subject_id, document_suffix, document_name, page_number, page_size)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
|
||||
select_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_document_sql += " AND stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_document_sql += " AND subject_id = " + str(subject_id)
|
||||
if theme_id != -1:
|
||||
select_document_sql += " AND theme_id = " + str(theme_id)
|
||||
if document_suffix != "":
|
||||
select_document_sql += " AND document_suffix = '" + document_suffix + "'"
|
||||
if document_name != "":
|
||||
select_document_sql += " AND document_name like '%" + document_name + "%'"
|
||||
select_document_sql += " ORDER BY create_time DESC "
|
||||
|
||||
# 查询文档列表
|
||||
page = await get_page_data_by_sql(select_document_sql, page_number, page_size)
|
||||
person_ids = ""
|
||||
for item in page["list"]:
|
||||
person_ids += "'" + item["person_id"] + "',"
|
||||
if person_ids != "":
|
||||
person_ids = person_ids[:-1]
|
||||
else:
|
||||
person_ids = "''"
|
||||
person_map = await get_person_map(person_ids)
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
|
||||
for item in page["list"]:
|
||||
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", item["theme_id"])
|
||||
item["theme_info"] = theme_info
|
||||
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
|
||||
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
|
||||
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
|
||||
item["create_time"] = str(item["create_time"].strftime("%Y-%m-%d %H:%M"))
|
||||
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【Document-2】保存文档管理
|
||||
@router.post("/save")
|
||||
async def save(request: Request, file: UploadFile = File(...)):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", False, True, 0)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
# 先获取theme主题信息
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
# 获取文件名
|
||||
document_name = file.filename.split(".")[0]
|
||||
if len(document_name) > 100:
|
||||
return {"success": False, "message": "文件名过长,已超过100个字符!"}
|
||||
# 检查文件名在该主题下是否重复
|
||||
select_theme_document_sql: str = "SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and document_name = '" + document_name + "' and theme_id = " + str(theme_id)
|
||||
if id != 0:
|
||||
select_theme_document_sql += " AND id <> " + str(id)
|
||||
theme_document = await find_by_sql(select_theme_document_sql, ())
|
||||
if theme_document is not None:
|
||||
return {"success": False, "message": "该主题下文档名称重复!"}
|
||||
# 获取文件扩展名
|
||||
document_suffix = file.filename.split(".")[-1]
|
||||
# 检查文件扩展名
|
||||
if document_suffix not in supported_suffix_types:
|
||||
return {"success": False, "message": "不支持的文件类型!"}
|
||||
# 构造文件保存路径
|
||||
document_dir = UPLOAD_DIR + "/" + str(theme_object["short_name"]) + "_" + str(theme_object["id"]) + "/"
|
||||
if not os.path.exists(document_dir):
|
||||
os.makedirs(document_dir)
|
||||
document_path = os.path.join(document_dir, file.filename)
|
||||
# 保存文件
|
||||
try:
|
||||
with open(document_path, "wb") as buffer:
|
||||
buffer.write(await file.read())
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"文件保存失败!{e}"}
|
||||
|
||||
# 构造保存文档SQL语句
|
||||
param = {"stage_id": stage_id, "subject_id": subject_id, "document_name": document_name, "theme_id": theme_id, "document_path": document_path, "document_suffix": document_suffix, "person_id": person_id, "bureau_id": bureau_id}
|
||||
|
||||
# 保存数据
|
||||
if id == 0:
|
||||
param["train_flag"] = 0
|
||||
# 插入数据
|
||||
id = await insert("t_ai_teaching_model_document", param, False)
|
||||
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
|
||||
else:
|
||||
# 更新数据
|
||||
await update("t_ai_teaching_model_document", param, "id", id)
|
||||
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
|
||||
|
||||
# 【Document-3】获取文档信息
|
||||
@router.get("/get")
|
||||
async def get(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
# 查询数据
|
||||
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
|
||||
if document_object is None:
|
||||
return {"success": False, "message": "未查询到该文档信息!"}
|
||||
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", document_object["theme_id"])
|
||||
document_object["theme_info"] = theme_info
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
document_object["stage_name"] = stage_map.get(str(document_object["stage_id"]), "未知学段")
|
||||
document_object["subject_name"] = subject_map.get(str(document_object["subject_id"]), "未知学科")
|
||||
document_object["person_name"] = await find_person_name_by_id(document_object["person_id"])
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"document": document_object}}
|
||||
|
||||
|
||||
|
||||
# 功能:【Document-4】删除文档信息
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-23
|
||||
# 备注:
|
||||
@router.post("/delete")
|
||||
async def delete(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
|
||||
if document_object is None:
|
||||
return {"success": False, "message": "未查询到该文档信息!"}
|
||||
# 删除文件
|
||||
train_flag = document_object["train_flag"]
|
||||
if train_flag == 0:
|
||||
# 未训练的文档,直接删除文件
|
||||
await update_batch_property("t_ai_teaching_model_document",
|
||||
{"is_deleted": 1, "train_flag": 6}, {"is_deleted": 0, "id": id}, False)
|
||||
elif train_flag == 1:
|
||||
# 正在训练的文档,不能删除
|
||||
return {"success": False, "message": "该文档正在训练中,不能删除!"}
|
||||
elif train_flag == 2:
|
||||
# 训练完成的文档,删除文件
|
||||
await update_batch_property("t_ai_teaching_model_document",
|
||||
{"is_deleted": 1, "train_flag": 3}, {"is_deleted": 0, "id": id}, False)
|
||||
return {"success": True, "message": "删除成功!"}
|
159
dsLightRag/Routes/TeachingModel/api/LoginController.py
Normal file
159
dsLightRag/Routes/TeachingModel/api/LoginController.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# routes/LoginController.py
|
||||
import base64
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import jwt
|
||||
|
||||
from captcha.image import ImageCaptcha
|
||||
from fastapi import APIRouter, Request, Response, status, HTTPException
|
||||
from Util.CommonUtil import *
|
||||
from Util.CookieUtil import *
|
||||
from Util.Database import *
|
||||
from Util.JwtUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from Config.Config import *
|
||||
|
||||
# 创建一个路由实例
|
||||
router = APIRouter()
|
||||
|
||||
# 获取项目根目录
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 配置验证码
|
||||
image = ImageCaptcha(
|
||||
width=100, height=30, # 增加宽度和高度
|
||||
font_sizes=[26], # 增加字体大小
|
||||
fonts=[os.path.join(project_root, 'DejaVuSans-Bold.ttf')] # 设置自定义字体路径
|
||||
)
|
||||
|
||||
@router.get("/getCaptcha")
|
||||
def get_captcha():
|
||||
captcha_text = ''.join(random.choices(string.digits, k=4)) # 生成4个数字的验证码
|
||||
session_id = os.urandom(16).hex()
|
||||
|
||||
# 将验证码文本存储在session中
|
||||
session_data = {session_id: captcha_text}
|
||||
with open("./session/captcha_sessions.json", "a") as session_file:
|
||||
json.dump(session_data, session_file)
|
||||
session_file.write("\n")
|
||||
|
||||
# 生成验证码图片
|
||||
data = image.generate(captcha_text)
|
||||
captcha_image_base64 = base64.b64encode(data.read()).decode()
|
||||
|
||||
return {"image": captcha_image_base64, "session_id": session_id}
|
||||
|
||||
# 验证用户并生成JWT令牌的接口
|
||||
@router.post("/validateCaptcha")
|
||||
async def validate_captcha(session_id : str, captcha : str):
|
||||
try:
|
||||
with open("./session/captcha_sessions.json", "r") as session_file:
|
||||
sessions = {}
|
||||
for line in session_file:
|
||||
sessions.update(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
sessions = {}
|
||||
|
||||
correct_captcha_text = sessions.get(session_id)
|
||||
|
||||
if not correct_captcha_text or correct_captcha_text.lower() != captcha.lower():
|
||||
return {"success": False, "message": "验证码错误"}
|
||||
return {"success": True, "message": "验证码正确"}
|
||||
|
||||
# 获取cookie中的token的方法
|
||||
@router.get("/getToken")
|
||||
async def get_token(request: Request):
|
||||
token = CookieUtil.get_cookie(request, key="auth_token")
|
||||
print(token)
|
||||
if token:
|
||||
try:
|
||||
decoded_token = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])
|
||||
logging.info(f"Token 解码成功: {decoded_token}")
|
||||
return {"success": True, "message": "Token 验证成功", "token_data": decoded_token}
|
||||
except jwt.ExpiredSignatureError:
|
||||
logging.error("Token 过期")
|
||||
return {"success": False, "message": "Token 过期"}
|
||||
except jwt.InvalidTokenError:
|
||||
logging.error("无效的 Token")
|
||||
return {"success": False, "message": "无效的 Token"}
|
||||
else:
|
||||
logging.error("未找到 Token")
|
||||
return {"success": False, "message": "未找到 Token"}
|
||||
|
||||
|
||||
# 登出
|
||||
@router.get("/logout")
|
||||
async def logout(request: Request, response: Response):
|
||||
token = CookieUtil.get_cookie(request, key="auth_token")
|
||||
if token:
|
||||
CookieUtil.remove_cookie(response, key="auth_token", path="/" )
|
||||
logging.info(f"Token <UNK> cookie: {token}")
|
||||
return {"success": True, "message": "账号已登出!"}
|
||||
else:
|
||||
logging.error("<UNK> Token")
|
||||
return {"success": True, "message": "未找到有效Token,账号已登出!"}
|
||||
|
||||
|
||||
|
||||
# 验证用户并生成JWT令牌的接口
|
||||
@router.post("/login")
|
||||
async def login(request: Request, response: Response):
|
||||
|
||||
username = await get_request_str_param(request, "username", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
|
||||
if not username or not password:
|
||||
return {"success": False, "message": "用户名和密码不能为空"}
|
||||
|
||||
password = md5_encrypt(password)
|
||||
select_user_sql: str = "SELECT person_id, person_name, identity_id, login_name, xb, bureau_id, org_id, pwdmd5 FROM t_sys_loginperson WHERE login_name = '" + username + "' AND b_use = 1"
|
||||
userlist = await find_by_sql(select_user_sql,())
|
||||
user = userlist[0] if userlist else None
|
||||
logging.info(f"查询结果: {user}")
|
||||
if user and user['pwdmd5'] == password: # 验证的cas用户密码,md5加密的版本
|
||||
token = create_access_token({"user_id": user['person_id'], "identity_id": user['identity_id']})
|
||||
CookieUtil.set_cookie(
|
||||
res=response,
|
||||
key="auth_token",
|
||||
value=token,
|
||||
secure=False, # 在开发环境中,确保 secure=False
|
||||
httponly=False,
|
||||
max_age=3600, # 设置cookie的有效时间为1小时
|
||||
path="/" # 设置cookie的路径
|
||||
)
|
||||
logging.info(f"Token 已成功设置到 cookie: {token}")
|
||||
user.pop('pwdmd5', None) # 移除密码字段
|
||||
return {"success": True, "message": "登录成功", "data": {"token": token, "user_data": user}}
|
||||
else:
|
||||
return {"success": False, "message": "用户名或密码错误"}
|
||||
|
||||
|
||||
# 【Base-Login-3】通过手机号获取Person的ID
|
||||
@router.get("/getPersonIdByTelephone")
|
||||
async def get_person_id_by_telephone(request: Request):
|
||||
telephone = await get_request_str_param(request, "telephone", True, True)
|
||||
if not telephone:
|
||||
return {"success": False, "message": "手机号不能为空"}
|
||||
select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 "
|
||||
userlist = await find_by_sql(select_user_sql,())
|
||||
user = userlist[0] if userlist else None
|
||||
if user:
|
||||
return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}}
|
||||
else:
|
||||
return {"success": False, "message": "未查询到相关信息"}
|
||||
|
||||
|
||||
|
||||
# 【Base-Login-4】忘记密码重设,不登录的状态
|
||||
@router.post("/resetPassword")
|
||||
async def reset_password(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
if not person_id or not password:
|
||||
return {"success": False, "message": "用户ID和新密码不能为空"}
|
||||
password_md5 = md5_encrypt(password)
|
||||
update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'"
|
||||
await execute_sql(update_user_sql, ())
|
||||
return {"success": True, "message": "密码修改成功"}
|
292
dsLightRag/Routes/TeachingModel/api/TeachingModelController.py
Normal file
292
dsLightRag/Routes/TeachingModel/api/TeachingModelController.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# routes/TeachingModelController.py
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import urllib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.LightRagUtil import *
|
||||
from Util.PageUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from lightrag import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
rag_type: str = "file"
|
||||
# rag_type: str = "pg"
|
||||
|
||||
|
||||
# 【TeachingModel-1】获取主题列表
|
||||
@router.get("/getTrainedTheme")
|
||||
async def get_trained_theme(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
# 数据库查询
|
||||
select_trained_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 AND bureau_id = '{bureau_id}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
|
||||
print(select_trained_theme_sql)
|
||||
page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-2】获取热门主题列表
|
||||
@router.get("/getHotTheme")
|
||||
async def get_hot_theme(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 3)
|
||||
# 数据库查询
|
||||
select_hot_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY quote_count DESC"
|
||||
print(select_hot_theme_sql)
|
||||
page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-3】获取最新主题列表
|
||||
@router.get("/getNewTheme")
|
||||
async def get_new_theme(request: Request):
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 3)
|
||||
# 数据库查询
|
||||
select_new_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY create_time DESC"
|
||||
print(select_new_theme_sql)
|
||||
page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-4】获取问题列表
|
||||
@router.get("/getQuestion")
|
||||
async def get_question(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question_type = await get_request_num_param(request, "question_type", True, True, None)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
person_sql = ""
|
||||
if question_type == 2:
|
||||
person_sql = f"AND person_id = '{person_id}'"
|
||||
# 数据库查询
|
||||
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = {question_type} {person_sql} ORDER BY quote_count DESC, create_time DESC"
|
||||
print(select_question_sql)
|
||||
page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
|
||||
##########################################################################################
|
||||
# 功能:【TeachingModel-5】向RAG提问,返回RAG的回复
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-08-14
|
||||
# 备注:question_type=2时,question_id必填,其他情况question_id不填;
|
||||
# question_type(0:新增问题;1:常见问题;2:个人历史问题;):
|
||||
# 0【新增问题】:会保存成个人历史问题,并保存答案
|
||||
# 1【常见问题】:增加常见问题引用次数,会保存成个人历史问题,并保存答案
|
||||
# 2【个人历史问题】对应页面新增的历史问题回显后“重试”操作:会重新回答问题,并替换最新答案
|
||||
##########################################################################################
|
||||
@router.post("/sendQuestion")
|
||||
async def send_question(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
question_type = await get_request_num_param(request, "question_type", True, True, None)
|
||||
question_id = await get_request_num_param(request, "question_id", False, True,0)
|
||||
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
# 处理theme的调用次数
|
||||
update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
|
||||
await execute_sql(update_sql, ())
|
||||
|
||||
if question_type == 2 and question_id == 0:
|
||||
return {"success": False, "message": "[question_type]=2时,[question_id]不允许为空!"}
|
||||
if question_type != 2:
|
||||
param = {}
|
||||
param["stage_id"] = int(theme_object["stage_id"])
|
||||
param["subject_id"] = int(theme_object["subject_id"])
|
||||
param["theme_id"] = theme_id
|
||||
param["question"] = question
|
||||
param["question_type"] = 2
|
||||
param["quote_count"] = 1
|
||||
param["question_person_id"] = person_id
|
||||
param["person_id"] = person_id
|
||||
param["bureau_id"] = bureau_id
|
||||
question_id = await insert("t_ai_teaching_model_question", param)
|
||||
if question_type == 1:
|
||||
# 处理常见问题引用次数
|
||||
update_common_question_sql: str = f"update t_ai_teaching_model_question set quote_count = quote_count + 1 where question_type = 1 and theme_id = {theme_id} and question = '{question}' and is_deleted = 0"
|
||||
await execute_sql(update_common_question_sql, ())
|
||||
|
||||
# 向rag提问
|
||||
topic = theme_object["short_name"]
|
||||
# mode = "hybrid"
|
||||
prompt = "\n 1、不要输出参考资料 或者 References !"
|
||||
prompt = prompt + "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !"
|
||||
prompt = prompt + "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!"
|
||||
prompt = prompt + "\n 4、知识库中存在的问题,严格按知识库中的内容回答,不允许扩展!"
|
||||
prompt = prompt + "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到,请这样告诉我:“暂时无法回答这个问题呢,我专注于【" + theme_object["theme_name"] + "】这方面知识,若需这方面帮助,请告知我更准确的信息吧~”"
|
||||
prompt = prompt + "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
|
||||
working_path = "./Topic/" + topic
|
||||
# if rag_type == "file":
|
||||
async def generate_response_stream(query: str, mode: str, user_prompt: str):
|
||||
try:
|
||||
result = ""
|
||||
rag = await initialize_rag(working_path)
|
||||
resp = await rag.aquery(
|
||||
query=query,
|
||||
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
|
||||
async for chunk in resp:
|
||||
if not chunk:
|
||||
continue
|
||||
yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
||||
result += chunk
|
||||
print(chunk, end='', flush=True)
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
finally:
|
||||
# 保存答案
|
||||
update_question_sql = f"update t_ai_teaching_model_question set question_answer = '{result}', update_time = now() where id = {question_id}"
|
||||
await execute_sql(update_question_sql, ())
|
||||
# 清理资源
|
||||
await rag.finalize_storages()
|
||||
|
||||
return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||||
# elif rag_type == "pg":
|
||||
# workspace = theme_object["short_name"]
|
||||
# # 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。
|
||||
# working_dir = 'WorkingPath/' + workspace
|
||||
# if not os.path.exists(working_dir):
|
||||
# os.makedirs(working_dir)
|
||||
# async def generate_response_stream(query: str, mode: str, user_prompt: str):
|
||||
# try:
|
||||
# logger.info("workspace=" + workspace)
|
||||
# rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
|
||||
#
|
||||
# resp = await rag.aquery(
|
||||
# query=query,
|
||||
# param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
|
||||
# async for chunk in resp:
|
||||
# if not chunk:
|
||||
# continue
|
||||
# yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
||||
# print(chunk, end='', flush=True)
|
||||
# except Exception as e:
|
||||
# yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
# finally:
|
||||
# # 发送流结束标记
|
||||
# yield "data: [DONE]\n\n"
|
||||
# # 清理资源
|
||||
# await rag.finalize_storages()
|
||||
# return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||||
|
||||
|
||||
|
||||
@router.post("/saveWord")
|
||||
async def save_word(request: Request):
|
||||
# 获取参数
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
markdown_content = await get_request_str_param(request, "markdown_content", True, True)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
|
||||
filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
|
||||
output_file = None
|
||||
try:
|
||||
# markdown_content 替换换行符
|
||||
# markdown_content = markdown_content.replace("\n", "")
|
||||
|
||||
# 创建临时Markdown文件
|
||||
temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
|
||||
with open(temp_md, "w", encoding="utf-8") as f:
|
||||
f.write(markdown_content)
|
||||
# 使用pandoc转换
|
||||
output_file = os.path.join(tempfile.gettempdir(), filename)
|
||||
subprocess.run(['pandoc', '--reference-doc', 'static/template/templates.docx', '-s', temp_md, '-o', output_file, '--resource-path=static'], check=True)
|
||||
|
||||
# 读取生成的Word文件
|
||||
with open(output_file, "rb") as f:
|
||||
stream = BytesIO(f.read())
|
||||
|
||||
# 返回响应
|
||||
encoded_filename = urllib.parse.quote(filename)
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Connection": "close",
|
||||
"Content-Type": "application/vnd.ms-excel;charset=UTF-8",
|
||||
"Content-Disposition": f"attachment; filename={encoded_filename}"}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
finally:
|
||||
...
|
||||
# 清理临时文件
|
||||
try:
|
||||
if temp_md and os.path.exists(temp_md):
|
||||
os.remove(temp_md)
|
||||
if output_file and os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp files: {str(e)}")
|
||||
|
||||
|
||||
|
||||
########################################
|
||||
# 【TeachingModel-7】保存个人历史问题答案
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-08-14
|
||||
# 备注:
|
||||
########################################
|
||||
@router.post("/saveAnswer")
|
||||
async def save_answer(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
answer = await get_request_str_param(request, "answer", True, True)
|
||||
|
||||
# 验证是否存在个人历史问题
|
||||
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = 2 AND question = '{question}' AND person_id = '{person_id}' order by create_time desc limit 1"
|
||||
print(select_question_sql)
|
||||
select_question_result = await find_by_sql(select_question_sql, ())
|
||||
if select_question_result is None: # 个人历史问题不存在
|
||||
return {"success": False, "message": "个人历史问题不存在!"}
|
||||
|
||||
question_id = select_question_result[0].get("id")
|
||||
# 保存答案
|
||||
update_answer_sql: str = f"UPDATE t_ai_teaching_model_question SET question_answer = '{answer}', update_time = now() WHERE id = {question_id}"
|
||||
print(update_answer_sql)
|
||||
await execute_sql(update_answer_sql, ())
|
||||
# 结果返回
|
||||
return {"success": True, "message": "保存成功!"}
|
23
dsLightRag/Routes/TeachingModel/api/TestController.py
Normal file
23
dsLightRag/Routes/TeachingModel/api/TestController.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# routes/TestController.py
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from utils.ParseRequest import *
|
||||
|
||||
# 创建一个路由实例
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/parse", response_model=dict)
|
||||
async def parse_query(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
||||
@router.post("/parse", response_model=dict)
|
||||
async def parse_form(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
||||
|
||||
@router.post("/parse_json", response_model=dict)
|
||||
async def parse_json(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
175
dsLightRag/Routes/TeachingModel/api/ThemeController.py
Normal file
175
dsLightRag/Routes/TeachingModel/api/ThemeController.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# routes/ThemeController.py
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from Util.Database import *
|
||||
from Util.ParseRequest import *
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.PageUtil import *
|
||||
from Util.TranslateUtil import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 功能:【Theme-1】主题管理列表
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.get("/list")
|
||||
async def list(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True,1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
theme_name = await get_request_str_param(request, "theme_name", False, True)
|
||||
|
||||
print(stage_id, person_id, subject_id, page_number, page_size, theme_name)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
select_theme_sql: str = " SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_theme_sql += " and stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_theme_sql += " and subject_id = " + str(subject_id)
|
||||
if theme_name != "":
|
||||
select_theme_sql += " and theme_name like '%" + theme_name + "%'"
|
||||
select_theme_sql += " ORDER BY create_time DESC"
|
||||
|
||||
# 查询主题列表
|
||||
page = await get_page_data_by_sql(select_theme_sql, page_number, page_size)
|
||||
person_ids = ""
|
||||
for item in page["list"]:
|
||||
person_ids += "'" + item["person_id"] + "',"
|
||||
if person_ids != "":
|
||||
person_ids = person_ids[:-1]
|
||||
else:
|
||||
person_ids = "''"
|
||||
|
||||
person_map = await get_person_map(person_ids)
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
for item in page["list"]:
|
||||
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
|
||||
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
|
||||
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 功能:【Theme-2】保存主题管理
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.post("/save")
|
||||
async def save(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", False, True, 0)
|
||||
theme_name = await get_request_str_param(request, "theme_name", True, True)
|
||||
if len(theme_name) > 50:
|
||||
return {"success": False, "message": "主题名称不能超过50个字符!"}
|
||||
short_name = await get_request_str_param(request, "short_name", True, True)
|
||||
if len(short_name) > 50:
|
||||
return {"success": False, "message": "主题英文简称不能超过50个字符!"}
|
||||
theme_icon = await get_request_str_param(request, "theme_icon", False, True)
|
||||
if len(theme_name) > 200:
|
||||
return {"success": False, "message": "主题图标不能超过200个字符!"}
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
|
||||
# 校验参数
|
||||
check_theme_sql = "SELECT theme_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and theme_name = '" + theme_name + "'"
|
||||
if id != 0:
|
||||
check_theme_sql += " and id <> " + str(id)
|
||||
print(check_theme_sql)
|
||||
check_theme_result = await find_by_sql(check_theme_sql,())
|
||||
if check_theme_result:
|
||||
return {"success": False, "message": "该主题名称已存在!"}
|
||||
if short_name.length > 50:
|
||||
return {"success": False, "message": "主题英文简称不能超过50个字符!"}
|
||||
check_short_name_sql = "SELECT short_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and short_name = '" + short_name + "'"
|
||||
if id != 0:
|
||||
check_short_name_sql += " and id <> " + str(id)
|
||||
print(check_short_name_sql)
|
||||
check_short_name_result = await find_by_sql(check_short_name_sql,())
|
||||
if check_short_name_result:
|
||||
return {"success": False, "message": "该主题英文简称已存在!"}
|
||||
|
||||
# 组装参数
|
||||
param = {"theme_name": theme_name,"short_name": short_name,"theme_icon": theme_icon,"stage_id": stage_id,"subject_id": subject_id,"person_id": person_id,"bureau_id": bureau_id}
|
||||
|
||||
# 保存数据
|
||||
if id == 0:
|
||||
param["search_flag"] = 0
|
||||
param["train_flag"] = 0
|
||||
param["quote_count"] = 0
|
||||
# 插入数据
|
||||
id = await insert("t_ai_teaching_model_theme", param, False)
|
||||
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
|
||||
else:
|
||||
# 更新数据
|
||||
await update("t_ai_teaching_model_theme", param, "id", id, False)
|
||||
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
|
||||
|
||||
|
||||
# 功能:【Theme-3】获取主题信息
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.get("/get")
|
||||
async def get(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
theme_obj = await find_by_id("t_ai_teaching_model_theme", "id", id)
|
||||
if theme_obj is None:
|
||||
return {"success": False, "message": "未查询到该主题信息!"}
|
||||
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
theme_obj["stage_name"] = stage_map.get(str(theme_obj["stage_id"]), "未知学段")
|
||||
theme_obj["subject_name"] = subject_map.get(str(theme_obj["subject_id"]), "未知学科")
|
||||
theme_obj["person_name"] = await find_person_name_by_id(theme_obj["person_id"])
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"theme": theme_obj}}
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
async def delete(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
result = await delete_by_id("t_ai_teaching_model_theme", "id", id)
|
||||
if not result:
|
||||
return {"success": False, "message": "删除失败!"}
|
||||
return {"success": True, "message": "删除成功!"}
|
||||
|
||||
|
||||
|
||||
# 功能:【Theme-5】根据学段学科获取主题列表
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-31
|
||||
# 备注:
|
||||
@router.get("/getListByStageSubject")
|
||||
async def get_list_by_stage_subject(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
select_theme_sql: str = " select id as theme_id, theme_name from t_ai_teaching_model_theme where is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_theme_sql += " and stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_theme_sql += " and subject_id = " + str(subject_id)
|
||||
select_theme_result = await find_by_sql(select_theme_sql,())
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"theme_list": select_theme_result}}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
50
dsLightRag/Routes/TeachingModel/api/UserController.py
Normal file
50
dsLightRag/Routes/TeachingModel/api/UserController.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# routes/UserController.py
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Request, Response, Depends
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.CommonUtil import md5_encrypt
|
||||
from Util.Database import *
|
||||
from Util.ParseRequest import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 【Base-User-1】维护用户手机号
|
||||
@router.post("/modifyTelephone")
|
||||
async def modify_telephone(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
telephone = await get_request_str_param(request, "telephone", True, True)
|
||||
# 校验手机号码格式
|
||||
if not re.match(r"^1[3-9]\d{9}$", telephone):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误")
|
||||
# 校验手机号码是否已被注册
|
||||
select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'"
|
||||
userlist = await find_by_sql(select_telephone_sql, ())
|
||||
if userlist is not None:
|
||||
return {"success": False, "message": "手机号码已被注册"}
|
||||
else:
|
||||
update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'"
|
||||
await execute_sql(update_telephone_sql, ())
|
||||
return {"success": True, "message": "修改成功"}
|
||||
|
||||
|
||||
# 【Base-User-2】维护用户密码
|
||||
@router.post("/modifyPassword")
|
||||
async def modify_password(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
old_password = await get_request_str_param(request, "old_password", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
# 校验旧密码是否正确
|
||||
select_password_sql: str = "select pwdmd5 from t_sys_loginperson where person_id = '" + person_id + "' and b_use = 1"
|
||||
userlist = await find_by_sql(select_password_sql, ())
|
||||
if len(userlist) == 0:
|
||||
return {"success": False, "message": "用户不存在"}
|
||||
else:
|
||||
if userlist[0]["pwdmd5"] != md5_encrypt(old_password):
|
||||
return {"success": False, "message": "旧密码错误"}
|
||||
else:
|
||||
update_password_sql: str = "update t_sys_loginperson set original_pwd = '" + password + "',pwdmd5 = '" + md5_encrypt(password) + "' where person_id = '" + person_id + "'"
|
||||
await execute_sql(update_password_sql, ())
|
||||
return {"success": True, "message": "修改成功"}
|
||||
|
29
dsLightRag/Routes/TeachingModel/auth/dependencies.py
Normal file
29
dsLightRag/Routes/TeachingModel/auth/dependencies.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# dependencies.py
|
||||
|
||||
from fastapi import Request, HTTPException, status, Header
|
||||
from Util.JwtUtil import *
|
||||
|
||||
async def get_current_user(token: str = Header(None, description="Authorization token")):
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未提供令牌,请登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
token = token.split(" ")[1] if token.startswith("Bearer ") else token
|
||||
payload = verify_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无法验证凭据,请登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
user_id = payload.get("user_id")
|
||||
identity_id = payload.get("identity_id")
|
||||
if user_id is None or identity_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无法验证用户,请重新登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user_id
|
102
dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py
Normal file
102
dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
|
||||
from Util.Database import find_by_sql, find_by_id, execute_sql
|
||||
from Util.DocxUtil import get_docx_content_by_pandoc
|
||||
from Util.LightRagUtil import initialize_rag
|
||||
|
||||
# 更详细地控制日志输出
|
||||
logger = logging.getLogger('lightrag')
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
|
||||
# 后台任务,监控是否有新的未训练的文档进行训练
|
||||
async def train_document_task():
|
||||
print(datetime.datetime.now(), "线程5秒后开始运行【监控是否有需要处理的文档】")
|
||||
await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep
|
||||
# 这里放置你的线程逻辑
|
||||
while True:
|
||||
print("测试定时任务运行")
|
||||
handle_flag = False
|
||||
if handle_flag:
|
||||
# 这里可以放置你的线程要执行的代码
|
||||
logging.info("开始查询是否有待处理文档:")
|
||||
no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE train_flag in (0,3) ORDER BY create_time DESC"
|
||||
no_train_document_result = await find_by_sql(no_train_document_sql, ())
|
||||
logger.info(no_train_document_result)
|
||||
if not no_train_document_result:
|
||||
print(datetime.datetime.now(), "没有需要处理的文档")
|
||||
else:
|
||||
print(datetime.datetime.now(), "存在未训练的文档" + str(len(no_train_document_result))+"个")
|
||||
# 这里可以根据train_flag的值来判断是训练还是删除
|
||||
document = no_train_document_result[0]
|
||||
theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"])
|
||||
document_name = document["document_name"] + "." + document["document_suffix"]
|
||||
working_dir = "Topic/" + theme["short_name"]
|
||||
document_path = document["document_path"]
|
||||
if document["train_flag"] == 0:
|
||||
# 训练文档
|
||||
update_sql_document: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_sql_document, ())
|
||||
update_sql_theme: str = " UPDATE t_ai_teaching_model_theme SET train_flag = 1 WHERE id = " + str(theme["id"])
|
||||
await execute_sql(update_sql_theme, ())
|
||||
logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
||||
# 训练代码开始
|
||||
# content = get_docx_content_by_pandoc(document_path)
|
||||
try:
|
||||
# 注意:默认设置使用NetworkX
|
||||
rag = await initialize_rag(working_dir)
|
||||
# 获取docx文件的内容
|
||||
content = get_docx_content_by_pandoc(document_path)
|
||||
await rag.ainsert(content, ids=[document_name], file_paths=[document_name])
|
||||
logger.info(f"Inserted content from {document_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: {e}")
|
||||
finally:
|
||||
await rag.finalize_storages()
|
||||
# 训练结束,更新训练状态
|
||||
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_document_sql, ())
|
||||
elif document["train_flag"] == 3:
|
||||
update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_sql, ())
|
||||
logging.info(f"开始删除文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
||||
# 删除文档开始
|
||||
try:
|
||||
# 注意:默认设置使用NetworkX
|
||||
rag = await initialize_rag(working_dir)
|
||||
await rag.adelete_by_doc_id(doc_id = document_name)
|
||||
logger.info(f"Deleted content from {document_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: {e}")
|
||||
finally:
|
||||
await rag.finalize_storages()
|
||||
# 删除结束,更新训练状态
|
||||
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 5 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_document_sql, ())
|
||||
|
||||
# 整体更新主题状态
|
||||
select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2) group by train_flag"
|
||||
select_document_result = await find_by_sql(select_document_sql, ())
|
||||
train_document_count_map = {}
|
||||
for item in select_document_result:
|
||||
train_document_count_map[str(item["train_flag"])] = int(item["train_count"])
|
||||
train_document_count_1 = train_document_count_map.get("1", 0)
|
||||
train_document_count_2 = train_document_count_map.get("2", 0)
|
||||
if train_document_count_2 > 0:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 1, train_flag = 2 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
else:
|
||||
if train_document_count_1 > 0:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 1 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
else:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 0 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
# 添加适当的等待时间,避免频繁查询
|
||||
await asyncio.sleep(20) # 开发阶段每20秒一次
|
||||
# await asyncio.sleep(600) # 每十分钟查询一次
|
@@ -1,6 +1,9 @@
|
||||
import uvicorn
|
||||
import asyncio
|
||||
from fastapi import FastAPI
|
||||
from starlette.staticfiles import StaticFiles
|
||||
|
||||
from Routes.TeachingModel.tasks.BackgroundTasks import train_document_task
|
||||
from Util.PostgreSQLUtil import init_postgres_pool, close_postgres_pool
|
||||
|
||||
from Routes.Ggb import router as ggb_router
|
||||
@@ -9,6 +12,13 @@ from Routes.Oss import router as oss_router
|
||||
from Routes.Rag import router as rag_router
|
||||
from Routes.ZuoWen import router as zuowen_router
|
||||
from Routes.Llm import router as llm_router
|
||||
from Routes.TeachingModel.api.LoginController import router as login_router
|
||||
from Routes.TeachingModel.api.UserController import router as user_router
|
||||
from Routes.TeachingModel.api.DmController import router as dm_router
|
||||
from Routes.TeachingModel.api.ThemeController import router as theme_router
|
||||
from Routes.TeachingModel.api.DocumentController import router as document_router
|
||||
from Routes.TeachingModel.api.TeachingModelController import router as teaching_model_router
|
||||
|
||||
from Util.LightRagUtil import *
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@@ -25,6 +35,8 @@ async def lifespan(_: FastAPI):
|
||||
pool = await init_postgres_pool()
|
||||
app.state.pool = pool
|
||||
|
||||
asyncio.create_task(train_document_task())
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@@ -44,5 +56,19 @@ app.include_router(oss_router) # 阿里云OSS路由
|
||||
|
||||
app.include_router(llm_router) # 大模型路由
|
||||
|
||||
# Teaching Model 相关路由
|
||||
# 登录相关(不用登录)
|
||||
app.include_router(login_router, prefix="/api/login", tags=["login"])
|
||||
# 用户相关
|
||||
app.include_router(user_router, prefix="/api/user", tags=["user"])
|
||||
# 字典相关(Dm)
|
||||
app.include_router(dm_router, prefix="/api/dm", tags=["dm"])
|
||||
# 主题相关
|
||||
app.include_router(theme_router, prefix="/api/theme", tags=["theme"])
|
||||
# 文档相关
|
||||
app.include_router(document_router, prefix="/api/document", tags=["document"])
|
||||
# 问题相关(大模型应用)
|
||||
app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["teacher_model"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8100)
|
||||
|
17
dsLightRag/Util/CommonUtil.py
Normal file
17
dsLightRag/Util/CommonUtil.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def md5_encrypt(text):
|
||||
# 创建一个md5哈希对象
|
||||
md5_hash = hashlib.md5()
|
||||
# 更新哈希对象的数据,注意这里需要将字符串转换为字节类型
|
||||
md5_hash.update(text.encode('utf-8'))
|
||||
# 获取十六进制表示的哈希值
|
||||
encrypted_text = md5_hash.hexdigest()
|
||||
|
||||
return encrypted_text
|
81
dsLightRag/Util/CookieUtil.py
Normal file
81
dsLightRag/Util/CookieUtil.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from fastapi.responses import Response
|
||||
from fastapi.requests import Request
|
||||
import logging
|
||||
|
||||
|
||||
class CookieUtil:
|
||||
@staticmethod
|
||||
def set_cookie(res: Response, key: str, value: str, max_age: int = 3600, path: str = "/", secure: bool = False,
|
||||
httponly: bool = True):
|
||||
"""
|
||||
设置cookie
|
||||
|
||||
:param res: FastAPI的Response对象
|
||||
:param key: cookie的键
|
||||
:param value: cookie的值
|
||||
:param max_age: cookie的有效时间(秒),默认1小时(3600秒)
|
||||
:param path: cookie的路径,默认为根路径("/")
|
||||
:param secure: 是否仅通过HTTPS传输,默认False
|
||||
:param httponly: 是否仅通过HTTP请求访问,默认True
|
||||
"""
|
||||
res.set_cookie(
|
||||
key=key,
|
||||
value=value,
|
||||
httponly=httponly,
|
||||
secure=secure,
|
||||
max_age=max_age,
|
||||
path=path
|
||||
)
|
||||
logging.info(f"设置 cookie: {key}={value}")
|
||||
|
||||
@staticmethod
|
||||
def get_cookie(req: Request, key: str) -> str:
|
||||
"""
|
||||
从请求中获取cookie的值
|
||||
|
||||
:param req: FastAPI的Request对象
|
||||
:param key: cookie的键
|
||||
:return: cookie的值,如果未找到则返回None
|
||||
"""
|
||||
token_value = req.cookies.get(key)
|
||||
logging.info(f"从 cookie 中获取的 {key}: {token_value}")
|
||||
return token_value
|
||||
|
||||
@staticmethod
|
||||
def remove_cookie(res: Response, key: str, path: str = "/"):
|
||||
"""
|
||||
移除cookie
|
||||
|
||||
:param res: FastAPI的Response对象
|
||||
:param key: cookie的键
|
||||
:param path: cookie的路径,默认为根路径("/")
|
||||
"""
|
||||
res.delete_cookie(key, path=path)
|
||||
logging.info(f"移除 cookie: {key}")
|
||||
|
||||
|
||||
# 示例使用
|
||||
if __name__ == "__main__":
|
||||
# 创建一个示例响应对象
|
||||
response = Response(content="示例响应")
|
||||
|
||||
# 设置一个cookie
|
||||
CookieUtil.set_cookie(response, key="auth_token", value="your_token_value")
|
||||
|
||||
# 创建一个示例请求对象(通常从FastAPI请求中获取)
|
||||
request = Request(scope={
|
||||
"type": "http",
|
||||
"headers": [
|
||||
(b"cookie", b"auth_token=your_token_value")
|
||||
]
|
||||
})
|
||||
|
||||
# 获取cookie
|
||||
token = CookieUtil.get_cookie(request, key="auth_token")
|
||||
print(token)
|
||||
|
||||
# 移除cookie
|
||||
CookieUtil.remove_cookie(response, key="auth_token")
|
||||
|
||||
# 打印示例响应头
|
||||
print(response.headers)
|
278
dsLightRag/Util/Database.py
Normal file
278
dsLightRag/Util/Database.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Database.py
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
from Util.PostgreSQLUtil import init_postgres_pool
|
||||
|
||||
|
||||
# 根据sql语句查询数据
|
||||
async def find_by_sql(sql: str, params: tuple):
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
result = await conn.fetch(sql, *params)
|
||||
# 将 asyncpg.Record 转换为字典
|
||||
result_dict = [dict(record) for record in result]
|
||||
if result_dict:
|
||||
return result_dict
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
||||
|
||||
# 插入数据
|
||||
async def insert(tableName, param, onlyForParam=False):
|
||||
current_time = datetime.datetime.now()
|
||||
columns = []
|
||||
values = []
|
||||
placeholders = []
|
||||
|
||||
for key, value in param.items():
|
||||
if value is not None:
|
||||
if isinstance(value, (int, float)):
|
||||
columns.append(key)
|
||||
values.append(value)
|
||||
placeholders.append(f"${len(values)}")
|
||||
elif isinstance(value, str):
|
||||
columns.append(key)
|
||||
values.append(value)
|
||||
placeholders.append(f"${len(values)}")
|
||||
else:
|
||||
columns.append(key)
|
||||
values.append(None)
|
||||
placeholders.append("NULL")
|
||||
|
||||
if not onlyForParam:
|
||||
if 'is_deleted' not in param:
|
||||
columns.append("is_deleted")
|
||||
values.append(0)
|
||||
placeholders.append(f"${len(values)}")
|
||||
|
||||
if 'create_time' not in param:
|
||||
columns.append("create_time")
|
||||
values.append(current_time)
|
||||
placeholders.append(f"${len(values)}")
|
||||
|
||||
if 'update_time' not in param:
|
||||
columns.append("update_time")
|
||||
values.append(current_time)
|
||||
placeholders.append(f"${len(values)}")
|
||||
|
||||
# 构造 SQL 语句
|
||||
column_names = ", ".join(columns)
|
||||
placeholder_names = ", ".join(placeholders)
|
||||
sql = f"INSERT INTO {tableName} ({column_names}) VALUES ({placeholder_names}) RETURNING id"
|
||||
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
result = await conn.fetchrow(sql, *values)
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logging.error("插入数据失败: 未返回ID")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
logging.error(f"执行的SQL语句: {sql}")
|
||||
logging.error(f"参数: {values}")
|
||||
# raise Exception(f"为表[{tableName}]插入数据失败: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
||||
|
||||
|
||||
# 更新数据
|
||||
async def update(table_name, param, property_name, property_value, only_for_param=False):
|
||||
current_time = datetime.datetime.now()
|
||||
set_clauses = []
|
||||
values = []
|
||||
|
||||
# 处理要更新的参数
|
||||
for key, value in param.items():
|
||||
if value is not None:
|
||||
if isinstance(value, (int, float)):
|
||||
set_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
elif isinstance(value, str):
|
||||
set_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
else:
|
||||
set_clauses.append(f"{key} = NULL")
|
||||
values.append(None)
|
||||
|
||||
if not only_for_param:
|
||||
if 'update_time' not in param:
|
||||
set_clauses.append(f"update_time = ${len(values) + 1}")
|
||||
values.append(current_time)
|
||||
|
||||
# 构造 SQL 语句
|
||||
set_clause = ", ".join(set_clauses)
|
||||
sql = f"UPDATE {table_name} SET {set_clause} WHERE {property_name} = ${len(values) + 1} RETURNING id"
|
||||
print(sql)
|
||||
|
||||
# 添加条件参数
|
||||
values.append(property_value)
|
||||
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
result = await conn.fetchrow(sql, *values)
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
logging.error("更新数据失败: 未返回ID")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
logging.error(f"执行的SQL语句: {sql}")
|
||||
logging.error(f"参数: {values}")
|
||||
# raise Exception(f"为表[{table_name}]更新数据失败: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
||||
|
||||
|
||||
|
||||
# 获取Bean
|
||||
# 通过主键查询
|
||||
async def find_by_id(table_name, property_name, property_value):
|
||||
if table_name and property_name and property_value is not None:
|
||||
# 构造 SQL 语句
|
||||
sql = f"SELECT * FROM {table_name} WHERE is_deleted = 0 AND {property_name} = $1"
|
||||
logging.debug(sql)
|
||||
|
||||
# 执行查询
|
||||
result = await find_by_sql(sql, (property_value,))
|
||||
if not result:
|
||||
logging.info(f"未查询到[{property_name}]为{property_value}的有效数据!")
|
||||
return None
|
||||
# 返回第一条数据
|
||||
return result[0]
|
||||
else:
|
||||
logging.error("参数不全")
|
||||
return None
|
||||
|
||||
# 通过主键删除
|
||||
# 逻辑删除
|
||||
async def delete_by_id(table_name, property_name, property_value):
|
||||
if table_name and property_name and property_value is not None:
|
||||
sql = f"UPDATE {table_name} SET is_deleted = 1, update_time = now() WHERE {property_name} = $1 and is_deleted = 0"
|
||||
logging.debug(sql)
|
||||
# 执行删除
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
result = await conn.execute(sql, property_value)
|
||||
if result:
|
||||
return True
|
||||
else:
|
||||
logging.error("删除失败: 未找到数据")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
logging.error(f"执行的SQL语句: {sql}")
|
||||
logging.error(f"参数: {property_value}")
|
||||
# raise Exception(f"为表[{table_name}]删除数据失败: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
||||
else:
|
||||
logging.error("参数不全")
|
||||
return False
|
||||
|
||||
|
||||
# 执行一个SQL语句
|
||||
async def execute_sql(sql, params):
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
await conn.fetch(sql, *params)
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
logging.error(f"执行的SQL语句: {sql}")
|
||||
# raise Exception(f"执行SQL失败: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
||||
|
||||
|
||||
async def find_person_name_by_id(person_id):
|
||||
sql = f"SELECT person_name FROM t_sys_loginperson WHERE person_id = $1 and b_use = 1"
|
||||
logging.debug(sql)
|
||||
person_list = await find_by_sql(sql, (person_id,))
|
||||
if person_list:
|
||||
return person_list[0]['person_name']
|
||||
else:
|
||||
return None
|
||||
|
||||
async def find_bureau_name_by_id(org_id):
|
||||
sql = f"SELECT org_name FROM t_base_organization WHERE org_id = $1 and b_use = 1"
|
||||
logging.debug(sql)
|
||||
org_list = await find_by_sql(sql, (org_id,))
|
||||
if org_list:
|
||||
return org_list[0]['org_name']
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def update_batch_property(table_name, update_param, property_param, only_for_param=False):
|
||||
current_time = datetime.datetime.now()
|
||||
set_clauses = []
|
||||
values = []
|
||||
|
||||
# 处理要更新的参数
|
||||
for key, value in update_param.items():
|
||||
if value is not None:
|
||||
if isinstance(value, (int, float)):
|
||||
set_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
elif isinstance(value, str):
|
||||
set_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
else:
|
||||
set_clauses.append(f"{key} = NULL")
|
||||
values.append(None)
|
||||
|
||||
if not only_for_param:
|
||||
if 'update_time' not in update_param:
|
||||
set_clauses.append(f"update_time = ${len(values) + 1}")
|
||||
values.append(current_time)
|
||||
|
||||
# 构造 WHERE 子句
|
||||
property_clauses = []
|
||||
for key, value in property_param.items():
|
||||
if value is not None:
|
||||
if isinstance(value, (int, float)):
|
||||
property_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
elif isinstance(value, str):
|
||||
property_clauses.append(f"{key} = ${len(values) + 1}")
|
||||
values.append(value)
|
||||
else:
|
||||
property_clauses.append(f"{key} IS NULL")
|
||||
values.append(None)
|
||||
|
||||
# 构造 SQL 语句
|
||||
set_clause = ", ".join(set_clauses)
|
||||
property_clause = " AND ".join(property_clauses)
|
||||
sql = f"UPDATE {table_name} SET {set_clause} WHERE {property_clause}"
|
||||
logging.debug(sql)
|
||||
|
||||
pg_pool = await init_postgres_pool()
|
||||
try:
|
||||
async with pg_pool.acquire() as conn:
|
||||
result = await conn.execute(sql, *values)
|
||||
affected_rows = conn.rowcount
|
||||
return affected_rows
|
||||
except Exception as e:
|
||||
logging.error(f"数据库查询错误: {e}")
|
||||
logging.error(f"执行的SQL语句: {sql}")
|
||||
logging.error(f"参数: {values}")
|
||||
# raise Exception(f"为表[{table_name}]批量更新数据失败: {e}")
|
||||
finally:
|
||||
if pg_pool is not None:
|
||||
await pg_pool.close()
|
@@ -6,6 +6,7 @@ import uuid
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
from networkx.algorithms.bipartite.centrality import betweenness_centrality
|
||||
|
||||
# 或者如果你想更详细地控制日志输出
|
||||
logger = logging.getLogger('DocxUtil')
|
||||
@@ -114,8 +115,20 @@ def get_docx_content_by_pandoc(docx_file):
|
||||
pos = line.find(")")
|
||||
q = line[:pos + 1]
|
||||
q = q.replace("./static", ".")
|
||||
# q = q[4:-1]
|
||||
# q='<img src="'+q+'" alt="我是图片">'
|
||||
# Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/
|
||||
left_idx = line.find("(")
|
||||
static_idx = line.find("static")
|
||||
if left_idx == -1 or static_idx == -1 or left_idx > static_idx:
|
||||
print("路径中不包含(+~+static的已知格式")
|
||||
else:
|
||||
between_content = q[left_idx+1:static_idx].strip()
|
||||
if between_content:
|
||||
q = q[:left_idx+1] + '\\' + q[static_idx:]
|
||||
else:
|
||||
q = q[:static_idx] + '\\' + q[static_idx:]
|
||||
print(f"q3:{q}")
|
||||
#q = q[4:-1]
|
||||
#q='<img src="'+q+'" alt="我是图片">'
|
||||
img_idx += 1
|
||||
content += q + "\n"
|
||||
else:
|
||||
@@ -126,4 +139,4 @@ def get_docx_content_by_pandoc(docx_file):
|
||||
f.write(content)
|
||||
# 删除临时文件 output_file
|
||||
# os.remove(temp_markdown)
|
||||
return content.replace("\n\n", "\n").replace("\\", "")
|
||||
return content.replace("\n\n", "\n").replace("\\", "/")
|
||||
|
21
dsLightRag/Util/JwtUtil.py
Normal file
21
dsLightRag/Util/JwtUtil.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# jwt_utils.py
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from jose import JWTError, jwt
|
||||
from Config.Config import *
|
||||
|
||||
|
||||
def create_access_token(data: dict):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
68
dsLightRag/Util/PageUtil.py
Normal file
68
dsLightRag/Util/PageUtil.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import math
|
||||
from Util.Database import *
|
||||
|
||||
|
||||
# 查询数据条数
|
||||
async def get_total_data_count(total_data_sql):
|
||||
total_data_count = 0
|
||||
total_data_count_sql = "select count(*) as num from (" + total_data_sql + ") as temp_table"
|
||||
result = await find_by_sql(total_data_count_sql,())
|
||||
row = result[0] if result else None
|
||||
if row:
|
||||
total_data_count = row.get("num")
|
||||
return total_data_count
|
||||
|
||||
|
||||
def get_page_by_total_row(total_data_count, page_number, page_size):
|
||||
total_page = (page_size != 0) and math.floor((total_data_count + page_size - 1) / page_size) or 0
|
||||
if page_number <= 0:
|
||||
page_number = 1
|
||||
if 0 < total_page < page_number:
|
||||
page_number = total_page
|
||||
offset = page_size * page_number - page_size
|
||||
limit = page_size
|
||||
return total_data_count, total_page, offset, limit
|
||||
|
||||
|
||||
async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int):
|
||||
total_row: int = 0
|
||||
total_page: int = 0
|
||||
total_data_sql = total_data_sql.replace(";", "")
|
||||
total_data_sql = total_data_sql.replace(" FROM ", " from ")
|
||||
|
||||
# 查询总数
|
||||
total_data_count = await get_total_data_count(total_data_sql)
|
||||
if total_data_count == 0:
|
||||
return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}
|
||||
else:
|
||||
total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, page_number, page_size)
|
||||
|
||||
# 构造执行分页查询的sql语句
|
||||
page_data_sql = total_data_sql + " LIMIT %d offset %d " % (limit, offset)
|
||||
print(page_data_sql)
|
||||
# 执行分页查询
|
||||
page_data = await find_by_sql(page_data_sql, ())
|
||||
if page_data:
|
||||
return {"page_number": page_number, "page_size": page_size, "total_row": total_row, "total_page": total_page, "list": page_data}
|
||||
else:
|
||||
return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}
|
||||
|
||||
|
||||
# 翻译person_name, bureau_id
|
||||
async def translate_person_bureau_name(page):
|
||||
for item in page["list"]:
|
||||
if item["person_id"] is not None:
|
||||
person_id = str(item["person_id"])
|
||||
person_name = await find_person_name_by_id(person_id)
|
||||
if person_name:
|
||||
item["person_name"] = person_name
|
||||
else:
|
||||
item["person_name"] = ""
|
||||
if item["bureau_id"] is not None:
|
||||
bureau_id = str(item["bureau_id"])
|
||||
bureau_name = await find_bureau_name_by_id(bureau_id)
|
||||
if bureau_name:
|
||||
item["bureau_name"] = bureau_name
|
||||
else:
|
||||
item["bureau_name"] = ""
|
||||
return page
|
102
dsLightRag/Util/ParseRequest.py
Normal file
102
dsLightRag/Util/ParseRequest.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import datetime
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
async def parse_request_data(request: Request):
|
||||
data = {
|
||||
"headers": dict(request.headers),
|
||||
"params": {},
|
||||
"cookies": dict(request.cookies),
|
||||
"time": datetime.datetime.utcnow(),
|
||||
"ip": request.client.host
|
||||
}
|
||||
|
||||
request_method = request.method
|
||||
|
||||
if request_method == "GET":
|
||||
query_params = request.query_params
|
||||
for key, value in query_params.items():
|
||||
parse_args({key: value}, data)
|
||||
|
||||
elif request_method == "POST":
|
||||
content_type = request.headers.get("content-type", "").lower()
|
||||
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
|
||||
form_data = await request.form()
|
||||
for key, value in form_data.items():
|
||||
parse_args({key: value}, data)
|
||||
elif "application/json" in content_type:
|
||||
json_data = await request.json()
|
||||
for key, value in json_data.items():
|
||||
parse_args({key: value}, data)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported content type")
|
||||
|
||||
return data['params']
|
||||
|
||||
def parse_args(args, data):
|
||||
if args:
|
||||
for key, value in args.items():
|
||||
data['params'][key] = value
|
||||
|
||||
|
||||
# 获取请求参数中的字符串参数
|
||||
# param_name --> 参数名
|
||||
# nonempty --> 是否必填
|
||||
# trim --> 是否去除两端空格
|
||||
# 返回参数值
|
||||
async def get_request_str_param(request: Request, param_name: str, nonempty: bool, trim: bool):
|
||||
request_data = await parse_request_data(request)
|
||||
if request_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="请求数据格式不正确",
|
||||
)
|
||||
value = str(request_data.get(param_name)) if request_data.get(param_name) is not None else ""
|
||||
if trim and value != "":
|
||||
value = value.strip()
|
||||
if nonempty and value == "":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]不允许为空!",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
# 获取请求参数中的数字参数
|
||||
# param_name --> 参数名
|
||||
# nonempty --> 是否必填
|
||||
# trim --> 是否去除两端空格
|
||||
# 返回参数值
|
||||
async def get_request_num_param(request: Request, param_name: str, nonempty: bool, trim: bool, default_value):
|
||||
value = await get_request_str_param(request, param_name, nonempty, trim)
|
||||
if nonempty:
|
||||
return await str2num(param_name, value)
|
||||
else:
|
||||
if value == "":
|
||||
return default_value
|
||||
return await str2num(param_name, value)
|
||||
|
||||
|
||||
# 字符串转数字, 判断字符串是否包含小数点,转float or 转int
|
||||
# param_name --> 参数名
|
||||
# value --> 字符串值
|
||||
# 返回数字值
|
||||
# 若字符串值不是数字,则抛出HTTPException
|
||||
async def str2num(param_name: str, value: str):
|
||||
if value.find(".") != -1:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]必须为数字!",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="[" + param_name + "]必须为数字!",
|
||||
)
|
37
dsLightRag/Util/TranslateUtil.py
Normal file
37
dsLightRag/Util/TranslateUtil.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from Util.Database import find_by_sql
|
||||
|
||||
|
||||
async def get_stage_map():
|
||||
select_stage_sql: str = "select * from t_dm_stage where b_use = 1"
|
||||
select_stage_result = await find_by_sql(select_stage_sql, ())
|
||||
stage_map = {}
|
||||
for stage in select_stage_result:
|
||||
stage_map[str(stage["stage_id"])] = stage["stage_name"]
|
||||
return stage_map
|
||||
|
||||
async def get_stage_map_by_id(stage_id: int):
|
||||
select_stage_sql: str = f"select stage_id, stage_name from t_dm_stage where b_use = 1 and stage_id = {stage_id}"
|
||||
select_stage_result = await find_by_sql(select_stage_sql, ())
|
||||
if select_stage_result is not None:
|
||||
return select_stage_result[0]["stage_name"]
|
||||
else:
|
||||
return "未知学段"
|
||||
|
||||
async def get_subject_map():
|
||||
select_subject_sql: str = "select * from t_dm_subject"
|
||||
select_subject_result = await find_by_sql(select_subject_sql, ())
|
||||
subject_map = {}
|
||||
for subject in select_subject_result:
|
||||
subject_map[str(subject["subject_id"])] = subject["subject_name"]
|
||||
return subject_map
|
||||
|
||||
async def get_person_map(person_ids: str):
|
||||
person_id_list = person_ids.split(",")
|
||||
person_ids = ",".join(person_id_list)
|
||||
select_person_sql: str = f"select person_id, person_name from t_sys_loginperson where person_id in ({person_ids}) and b_use = 1"
|
||||
select_person_result = await find_by_sql(select_person_sql, ())
|
||||
person_map = {}
|
||||
if select_person_result is not None:
|
||||
for person in select_person_result:
|
||||
person_map[str(person["person_id"])] = person["person_name"]
|
||||
return person_map
|
BIN
dsLightRag/upload_file/CESHI_19/ChangChun.docx
Normal file
BIN
dsLightRag/upload_file/CESHI_19/ChangChun.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/CESHI_19/Chemistry.docx
Normal file
BIN
dsLightRag/upload_file/CESHI_19/Chemistry.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/Chemistry_15/Chemistry.docx
Normal file
BIN
dsLightRag/upload_file/Chemistry_15/Chemistry.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighChemistry_30/高中化学必修一教案_5312205.docx
Normal file
BIN
dsLightRag/upload_file/HighChemistry_30/高中化学必修一教案_5312205.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighChinese_28/高中语文必修一教案_1449263.docx
Normal file
BIN
dsLightRag/upload_file/HighChinese_28/高中语文必修一教案_1449263.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一全册教案_3221547.docx
Normal file
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一全册教案_3221547.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一常见题型归类_3027455.docx
Normal file
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一常见题型归类_3027455.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一教学计划_1448630.docx
Normal file
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一教学计划_1448630.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一讲义_1693010.docx
Normal file
BIN
dsLightRag/upload_file/HighMath_24/高中数学必修一讲义_1693010.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/HighPhysics_29/高中物理必修一教案_1591280.docx
Normal file
BIN
dsLightRag/upload_file/HighPhysics_29/高中物理必修一教案_1591280.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/JiHe_16/JiHe.docx
Normal file
BIN
dsLightRag/upload_file/JiHe_16/JiHe.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/Math_13/Math.docx
Normal file
BIN
dsLightRag/upload_file/Math_13/Math.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/MiddleHistory_23/初中历史知识梳理.docx
Normal file
BIN
dsLightRag/upload_file/MiddleHistory_23/初中历史知识梳理.docx
Normal file
Binary file not shown.
Binary file not shown.
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_1.docx
Normal file
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_1.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_2.docx
Normal file
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_2.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_3.docx
Normal file
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_3.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_4.docx
Normal file
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_4.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_5.docx
Normal file
BIN
dsLightRag/upload_file/ShiJi_17/ShiJi_5.docx
Normal file
Binary file not shown.
BIN
dsLightRag/upload_file/SuShi_14/SuShi.docx
Normal file
BIN
dsLightRag/upload_file/SuShi_14/SuShi.docx
Normal file
Binary file not shown.
Reference in New Issue
Block a user