整合 dsAiTeachingModel 接口
This commit is contained in:
46
dsLightRag/Routes/TeachingModel/api/DmController.py
Normal file
46
dsLightRag/Routes/TeachingModel/api/DmController.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# routes/DmController.py
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
from Util.Database import *
|
||||
from Routes.TeachingModel.auth.dependencies import get_current_user
|
||||
from Util.ParseRequest import get_request_num_param
|
||||
from Util.TranslateUtil import get_stage_map_by_id
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 学段、学科级联数据
|
||||
@router.get("/getStageSubjectList")
|
||||
async def get_stage_subject_list():
|
||||
# 先查询学段list
|
||||
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
|
||||
stage_list = await find_by_sql(select_stage_sql,())
|
||||
for stage in stage_list:
|
||||
# 再查询学科list
|
||||
select_subject_sql: str = "select subject_id, subject_name, icon from t_dm_subject where stage_id = " + str(stage["stage_id"]) + " order by sort_id;"
|
||||
subject_list = await find_by_sql(select_subject_sql,())
|
||||
stage["subject_list"] = subject_list
|
||||
|
||||
return {"success": True, "message": "成功!", "data": stage_list}
|
||||
|
||||
# 获取学段信息
|
||||
@router.get("/getStageList")
|
||||
async def get_stage_list():
|
||||
# 先查询学段list
|
||||
select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;"
|
||||
stage_list = await find_by_sql(select_stage_sql,())
|
||||
return {"success": True, "message": "查询成功!", "data": {"stage_list": stage_list}}
|
||||
|
||||
|
||||
# 根据学段ID获取学科列表
|
||||
@router.get("/getSubjectList")
|
||||
async def get_subject_list(request: Request):
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
stage_name = await get_stage_map_by_id(stage_id)
|
||||
# 先查询学科list
|
||||
select_subject_sql: str = f"select subject_id, subject_name, icon, {stage_id} as stage_id, '{stage_name}' as stage_name from t_dm_subject where stage_id = {stage_id} order by sort_id;"
|
||||
print(select_subject_sql)
|
||||
subject_list = await find_by_sql(select_subject_sql,())
|
||||
return {"success": True, "message": "查询成功!", "data": {"subject_list": subject_list}}
|
178
dsLightRag/Routes/TeachingModel/api/DocumentController.py
Normal file
178
dsLightRag/Routes/TeachingModel/api/DocumentController.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# routes/DocumentController.py
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Request, Response, Depends, UploadFile, File
|
||||
|
||||
from Routes.TeachingModel.auth.dependencies import get_current_user
|
||||
from Util.PageUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from Util.TranslateUtil import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 创建上传文件的目录
|
||||
UPLOAD_DIR = "upload_file"
|
||||
if not os.path.exists(UPLOAD_DIR):
|
||||
os.makedirs(UPLOAD_DIR)
|
||||
|
||||
# 合法文件扩展名
|
||||
supported_suffix_types = ['doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx']
|
||||
|
||||
# 【Document-1】文档管理列表
|
||||
@router.get("/list")
|
||||
async def list(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
theme_id = await get_request_num_param(request, "theme_id", False, True, -1)
|
||||
document_suffix = await get_request_str_param(request, "document_suffix", False, True)
|
||||
document_name = await get_request_str_param(request, "document_name", False, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
print(person_id, stage_id, subject_id, document_suffix, document_name, page_number, page_size)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
|
||||
select_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_document_sql += " AND stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_document_sql += " AND subject_id = " + str(subject_id)
|
||||
if theme_id != -1:
|
||||
select_document_sql += " AND theme_id = " + str(theme_id)
|
||||
if document_suffix != "":
|
||||
select_document_sql += " AND document_suffix = '" + document_suffix + "'"
|
||||
if document_name != "":
|
||||
select_document_sql += " AND document_name like '%" + document_name + "%'"
|
||||
select_document_sql += " ORDER BY create_time DESC "
|
||||
|
||||
# 查询文档列表
|
||||
page = await get_page_data_by_sql(select_document_sql, page_number, page_size)
|
||||
person_ids = ""
|
||||
for item in page["list"]:
|
||||
person_ids += "'" + item["person_id"] + "',"
|
||||
if person_ids != "":
|
||||
person_ids = person_ids[:-1]
|
||||
else:
|
||||
person_ids = "''"
|
||||
person_map = await get_person_map(person_ids)
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
|
||||
for item in page["list"]:
|
||||
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", item["theme_id"])
|
||||
item["theme_info"] = theme_info
|
||||
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
|
||||
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
|
||||
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
|
||||
item["create_time"] = str(item["create_time"].strftime("%Y-%m-%d %H:%M"))
|
||||
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【Document-2】保存文档管理
|
||||
@router.post("/save")
|
||||
async def save(request: Request, file: UploadFile = File(...)):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", False, True, 0)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
# 先获取theme主题信息
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
# 获取文件名
|
||||
document_name = file.filename.split(".")[0]
|
||||
if len(document_name) > 100:
|
||||
return {"success": False, "message": "文件名过长,已超过100个字符!"}
|
||||
# 检查文件名在该主题下是否重复
|
||||
select_theme_document_sql: str = "SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and document_name = '" + document_name + "' and theme_id = " + str(theme_id)
|
||||
if id != 0:
|
||||
select_theme_document_sql += " AND id <> " + str(id)
|
||||
theme_document = await find_by_sql(select_theme_document_sql, ())
|
||||
if theme_document is not None:
|
||||
return {"success": False, "message": "该主题下文档名称重复!"}
|
||||
# 获取文件扩展名
|
||||
document_suffix = file.filename.split(".")[-1]
|
||||
# 检查文件扩展名
|
||||
if document_suffix not in supported_suffix_types:
|
||||
return {"success": False, "message": "不支持的文件类型!"}
|
||||
# 构造文件保存路径
|
||||
document_dir = UPLOAD_DIR + "/" + str(theme_object["short_name"]) + "_" + str(theme_object["id"]) + "/"
|
||||
if not os.path.exists(document_dir):
|
||||
os.makedirs(document_dir)
|
||||
document_path = os.path.join(document_dir, file.filename)
|
||||
# 保存文件
|
||||
try:
|
||||
with open(document_path, "wb") as buffer:
|
||||
buffer.write(await file.read())
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"文件保存失败!{e}"}
|
||||
|
||||
# 构造保存文档SQL语句
|
||||
param = {"stage_id": stage_id, "subject_id": subject_id, "document_name": document_name, "theme_id": theme_id, "document_path": document_path, "document_suffix": document_suffix, "person_id": person_id, "bureau_id": bureau_id}
|
||||
|
||||
# 保存数据
|
||||
if id == 0:
|
||||
param["train_flag"] = 0
|
||||
# 插入数据
|
||||
id = await insert("t_ai_teaching_model_document", param, False)
|
||||
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
|
||||
else:
|
||||
# 更新数据
|
||||
await update("t_ai_teaching_model_document", param, "id", id)
|
||||
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
|
||||
|
||||
# 【Document-3】获取文档信息
|
||||
@router.get("/get")
|
||||
async def get(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
# 查询数据
|
||||
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
|
||||
if document_object is None:
|
||||
return {"success": False, "message": "未查询到该文档信息!"}
|
||||
theme_info = await find_by_id("t_ai_teaching_model_theme", "id", document_object["theme_id"])
|
||||
document_object["theme_info"] = theme_info
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
document_object["stage_name"] = stage_map.get(str(document_object["stage_id"]), "未知学段")
|
||||
document_object["subject_name"] = subject_map.get(str(document_object["subject_id"]), "未知学科")
|
||||
document_object["person_name"] = await find_person_name_by_id(document_object["person_id"])
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"document": document_object}}
|
||||
|
||||
|
||||
|
||||
# 功能:【Document-4】删除文档信息
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-23
|
||||
# 备注:
|
||||
@router.post("/delete")
|
||||
async def delete(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
document_object = await find_by_id("t_ai_teaching_model_document", "id", id)
|
||||
if document_object is None:
|
||||
return {"success": False, "message": "未查询到该文档信息!"}
|
||||
# 删除文件
|
||||
train_flag = document_object["train_flag"]
|
||||
if train_flag == 0:
|
||||
# 未训练的文档,直接删除文件
|
||||
await update_batch_property("t_ai_teaching_model_document",
|
||||
{"is_deleted": 1, "train_flag": 6}, {"is_deleted": 0, "id": id}, False)
|
||||
elif train_flag == 1:
|
||||
# 正在训练的文档,不能删除
|
||||
return {"success": False, "message": "该文档正在训练中,不能删除!"}
|
||||
elif train_flag == 2:
|
||||
# 训练完成的文档,删除文件
|
||||
await update_batch_property("t_ai_teaching_model_document",
|
||||
{"is_deleted": 1, "train_flag": 3}, {"is_deleted": 0, "id": id}, False)
|
||||
return {"success": True, "message": "删除成功!"}
|
159
dsLightRag/Routes/TeachingModel/api/LoginController.py
Normal file
159
dsLightRag/Routes/TeachingModel/api/LoginController.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# routes/LoginController.py
|
||||
import base64
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import jwt
|
||||
|
||||
from captcha.image import ImageCaptcha
|
||||
from fastapi import APIRouter, Request, Response, status, HTTPException
|
||||
from Util.CommonUtil import *
|
||||
from Util.CookieUtil import *
|
||||
from Util.Database import *
|
||||
from Util.JwtUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from Config.Config import *
|
||||
|
||||
# 创建一个路由实例
|
||||
router = APIRouter()
|
||||
|
||||
# 获取项目根目录
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 配置验证码
|
||||
image = ImageCaptcha(
|
||||
width=100, height=30, # 增加宽度和高度
|
||||
font_sizes=[26], # 增加字体大小
|
||||
fonts=[os.path.join(project_root, 'DejaVuSans-Bold.ttf')] # 设置自定义字体路径
|
||||
)
|
||||
|
||||
@router.get("/getCaptcha")
|
||||
def get_captcha():
|
||||
captcha_text = ''.join(random.choices(string.digits, k=4)) # 生成4个数字的验证码
|
||||
session_id = os.urandom(16).hex()
|
||||
|
||||
# 将验证码文本存储在session中
|
||||
session_data = {session_id: captcha_text}
|
||||
with open("./session/captcha_sessions.json", "a") as session_file:
|
||||
json.dump(session_data, session_file)
|
||||
session_file.write("\n")
|
||||
|
||||
# 生成验证码图片
|
||||
data = image.generate(captcha_text)
|
||||
captcha_image_base64 = base64.b64encode(data.read()).decode()
|
||||
|
||||
return {"image": captcha_image_base64, "session_id": session_id}
|
||||
|
||||
# 验证用户并生成JWT令牌的接口
|
||||
@router.post("/validateCaptcha")
|
||||
async def validate_captcha(session_id : str, captcha : str):
|
||||
try:
|
||||
with open("./session/captcha_sessions.json", "r") as session_file:
|
||||
sessions = {}
|
||||
for line in session_file:
|
||||
sessions.update(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
sessions = {}
|
||||
|
||||
correct_captcha_text = sessions.get(session_id)
|
||||
|
||||
if not correct_captcha_text or correct_captcha_text.lower() != captcha.lower():
|
||||
return {"success": False, "message": "验证码错误"}
|
||||
return {"success": True, "message": "验证码正确"}
|
||||
|
||||
# 获取cookie中的token的方法
|
||||
@router.get("/getToken")
|
||||
async def get_token(request: Request):
|
||||
token = CookieUtil.get_cookie(request, key="auth_token")
|
||||
print(token)
|
||||
if token:
|
||||
try:
|
||||
decoded_token = jwt.decode(token, JWT_SECRET_KEY, algorithms=['HS256'])
|
||||
logging.info(f"Token 解码成功: {decoded_token}")
|
||||
return {"success": True, "message": "Token 验证成功", "token_data": decoded_token}
|
||||
except jwt.ExpiredSignatureError:
|
||||
logging.error("Token 过期")
|
||||
return {"success": False, "message": "Token 过期"}
|
||||
except jwt.InvalidTokenError:
|
||||
logging.error("无效的 Token")
|
||||
return {"success": False, "message": "无效的 Token"}
|
||||
else:
|
||||
logging.error("未找到 Token")
|
||||
return {"success": False, "message": "未找到 Token"}
|
||||
|
||||
|
||||
# 登出
|
||||
@router.get("/logout")
|
||||
async def logout(request: Request, response: Response):
|
||||
token = CookieUtil.get_cookie(request, key="auth_token")
|
||||
if token:
|
||||
CookieUtil.remove_cookie(response, key="auth_token", path="/" )
|
||||
logging.info(f"Token <UNK> cookie: {token}")
|
||||
return {"success": True, "message": "账号已登出!"}
|
||||
else:
|
||||
logging.error("<UNK> Token")
|
||||
return {"success": True, "message": "未找到有效Token,账号已登出!"}
|
||||
|
||||
|
||||
|
||||
# 验证用户并生成JWT令牌的接口
|
||||
@router.post("/login")
|
||||
async def login(request: Request, response: Response):
|
||||
|
||||
username = await get_request_str_param(request, "username", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
|
||||
if not username or not password:
|
||||
return {"success": False, "message": "用户名和密码不能为空"}
|
||||
|
||||
password = md5_encrypt(password)
|
||||
select_user_sql: str = "SELECT person_id, person_name, identity_id, login_name, xb, bureau_id, org_id, pwdmd5 FROM t_sys_loginperson WHERE login_name = '" + username + "' AND b_use = 1"
|
||||
userlist = await find_by_sql(select_user_sql,())
|
||||
user = userlist[0] if userlist else None
|
||||
logging.info(f"查询结果: {user}")
|
||||
if user and user['pwdmd5'] == password: # 验证的cas用户密码,md5加密的版本
|
||||
token = create_access_token({"user_id": user['person_id'], "identity_id": user['identity_id']})
|
||||
CookieUtil.set_cookie(
|
||||
res=response,
|
||||
key="auth_token",
|
||||
value=token,
|
||||
secure=False, # 在开发环境中,确保 secure=False
|
||||
httponly=False,
|
||||
max_age=3600, # 设置cookie的有效时间为1小时
|
||||
path="/" # 设置cookie的路径
|
||||
)
|
||||
logging.info(f"Token 已成功设置到 cookie: {token}")
|
||||
user.pop('pwdmd5', None) # 移除密码字段
|
||||
return {"success": True, "message": "登录成功", "data": {"token": token, "user_data": user}}
|
||||
else:
|
||||
return {"success": False, "message": "用户名或密码错误"}
|
||||
|
||||
|
||||
# 【Base-Login-3】通过手机号获取Person的ID
|
||||
@router.get("/getPersonIdByTelephone")
|
||||
async def get_person_id_by_telephone(request: Request):
|
||||
telephone = await get_request_str_param(request, "telephone", True, True)
|
||||
if not telephone:
|
||||
return {"success": False, "message": "手机号不能为空"}
|
||||
select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 "
|
||||
userlist = await find_by_sql(select_user_sql,())
|
||||
user = userlist[0] if userlist else None
|
||||
if user:
|
||||
return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}}
|
||||
else:
|
||||
return {"success": False, "message": "未查询到相关信息"}
|
||||
|
||||
|
||||
|
||||
# 【Base-Login-4】忘记密码重设,不登录的状态
|
||||
@router.post("/resetPassword")
|
||||
async def reset_password(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
if not person_id or not password:
|
||||
return {"success": False, "message": "用户ID和新密码不能为空"}
|
||||
password_md5 = md5_encrypt(password)
|
||||
update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'"
|
||||
await execute_sql(update_user_sql, ())
|
||||
return {"success": True, "message": "密码修改成功"}
|
292
dsLightRag/Routes/TeachingModel/api/TeachingModelController.py
Normal file
292
dsLightRag/Routes/TeachingModel/api/TeachingModelController.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# routes/TeachingModelController.py
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import urllib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.LightRagUtil import *
|
||||
from Util.PageUtil import *
|
||||
from Util.ParseRequest import *
|
||||
from lightrag import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
rag_type: str = "file"
|
||||
# rag_type: str = "pg"
|
||||
|
||||
|
||||
# 【TeachingModel-1】获取主题列表
|
||||
@router.get("/getTrainedTheme")
|
||||
async def get_trained_theme(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
# 数据库查询
|
||||
select_trained_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 AND bureau_id = '{bureau_id}' AND stage_id = {stage_id} AND subject_id = {subject_id}"
|
||||
print(select_trained_theme_sql)
|
||||
page = await get_page_data_by_sql(select_trained_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-2】获取热门主题列表
|
||||
@router.get("/getHotTheme")
|
||||
async def get_hot_theme(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 3)
|
||||
# 数据库查询
|
||||
select_hot_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY quote_count DESC"
|
||||
print(select_hot_theme_sql)
|
||||
page = await get_page_data_by_sql(select_hot_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-3】获取最新主题列表
|
||||
@router.get("/getNewTheme")
|
||||
async def get_new_theme(request: Request):
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 3)
|
||||
# 数据库查询
|
||||
select_new_theme_sql: str = f"SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and search_flag = 1 and bureau_id = '{bureau_id}' ORDER BY create_time DESC"
|
||||
print(select_new_theme_sql)
|
||||
page = await get_page_data_by_sql(select_new_theme_sql, page_number, page_size)
|
||||
page = await translate_person_bureau_name(page)
|
||||
# 结果返回
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 【TeachingModel-4】获取问题列表
|
||||
@router.get("/getQuestion")
|
||||
async def get_question(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question_type = await get_request_num_param(request, "question_type", True, True, None)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True, 1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
|
||||
person_sql = ""
|
||||
if question_type == 2:
|
||||
person_sql = f"AND person_id = '{person_id}'"
|
||||
# 数据库查询
|
||||
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = {question_type} {person_sql} ORDER BY quote_count DESC, create_time DESC"
|
||||
print(select_question_sql)
|
||||
page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
|
||||
##########################################################################################
|
||||
# 功能:【TeachingModel-5】向RAG提问,返回RAG的回复
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-08-14
|
||||
# 备注:question_type=2时,question_id必填,其他情况question_id不填;
|
||||
# question_type(0:新增问题;1:常见问题;2:个人历史问题;):
|
||||
# 0【新增问题】:会保存成个人历史问题,并保存答案
|
||||
# 1【常见问题】:增加常见问题引用次数,会保存成个人历史问题,并保存答案
|
||||
# 2【个人历史问题】对应页面新增的历史问题回显后“重试”操作:会重新回答问题,并替换最新答案
|
||||
##########################################################################################
|
||||
@router.post("/sendQuestion")
|
||||
async def send_question(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
question_type = await get_request_num_param(request, "question_type", True, True, None)
|
||||
question_id = await get_request_num_param(request, "question_id", False, True,0)
|
||||
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
# 处理theme的调用次数
|
||||
update_sql: str = f"UPDATE t_ai_teaching_model_theme SET quote_count = quote_count + 1, update_time = now() WHERE id = {theme_id}"
|
||||
await execute_sql(update_sql, ())
|
||||
|
||||
if question_type == 2 and question_id == 0:
|
||||
return {"success": False, "message": "[question_type]=2时,[question_id]不允许为空!"}
|
||||
if question_type != 2:
|
||||
param = {}
|
||||
param["stage_id"] = int(theme_object["stage_id"])
|
||||
param["subject_id"] = int(theme_object["subject_id"])
|
||||
param["theme_id"] = theme_id
|
||||
param["question"] = question
|
||||
param["question_type"] = 2
|
||||
param["quote_count"] = 1
|
||||
param["question_person_id"] = person_id
|
||||
param["person_id"] = person_id
|
||||
param["bureau_id"] = bureau_id
|
||||
question_id = await insert("t_ai_teaching_model_question", param)
|
||||
if question_type == 1:
|
||||
# 处理常见问题引用次数
|
||||
update_common_question_sql: str = f"update t_ai_teaching_model_question set quote_count = quote_count + 1 where question_type = 1 and theme_id = {theme_id} and question = '{question}' and is_deleted = 0"
|
||||
await execute_sql(update_common_question_sql, ())
|
||||
|
||||
# 向rag提问
|
||||
topic = theme_object["short_name"]
|
||||
# mode = "hybrid"
|
||||
prompt = "\n 1、不要输出参考资料 或者 References !"
|
||||
prompt = prompt + "\n 2、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !"
|
||||
prompt = prompt + "\n 3、如果资料中提供了图片的,一定要严格按照原文提供图片输出,绝对不能省略或不输出!"
|
||||
prompt = prompt + "\n 4、知识库中存在的问题,严格按知识库中的内容回答,不允许扩展!"
|
||||
prompt = prompt + "\n 5、如果问题与提供的知识库内容不符,则明确告诉未在知识库范围内提到,请这样告诉我:“暂时无法回答这个问题呢,我专注于【" + theme_object["theme_name"] + "】这方面知识,若需这方面帮助,请告知我更准确的信息吧~”"
|
||||
prompt = prompt + "\n 6、发现输出内容中包含Latex公式的,一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现!"
|
||||
working_path = "./Topic/" + topic
|
||||
# if rag_type == "file":
|
||||
async def generate_response_stream(query: str, mode: str, user_prompt: str):
|
||||
try:
|
||||
result = ""
|
||||
rag = await initialize_rag(working_path)
|
||||
resp = await rag.aquery(
|
||||
query=query,
|
||||
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
|
||||
async for chunk in resp:
|
||||
if not chunk:
|
||||
continue
|
||||
yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
||||
result += chunk
|
||||
print(chunk, end='', flush=True)
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
finally:
|
||||
# 保存答案
|
||||
update_question_sql = f"update t_ai_teaching_model_question set question_answer = '{result}', update_time = now() where id = {question_id}"
|
||||
await execute_sql(update_question_sql, ())
|
||||
# 清理资源
|
||||
await rag.finalize_storages()
|
||||
|
||||
return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||||
# elif rag_type == "pg":
|
||||
# workspace = theme_object["short_name"]
|
||||
# # 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。
|
||||
# working_dir = 'WorkingPath/' + workspace
|
||||
# if not os.path.exists(working_dir):
|
||||
# os.makedirs(working_dir)
|
||||
# async def generate_response_stream(query: str, mode: str, user_prompt: str):
|
||||
# try:
|
||||
# logger.info("workspace=" + workspace)
|
||||
# rag = await initialize_pg_rag(WORKING_DIR=working_dir, workspace=workspace)
|
||||
#
|
||||
# resp = await rag.aquery(
|
||||
# query=query,
|
||||
# param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
|
||||
# async for chunk in resp:
|
||||
# if not chunk:
|
||||
# continue
|
||||
# yield f"data: {json.dumps({'reply': chunk})}\n\n"
|
||||
# print(chunk, end='', flush=True)
|
||||
# except Exception as e:
|
||||
# yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
# finally:
|
||||
# # 发送流结束标记
|
||||
# yield "data: [DONE]\n\n"
|
||||
# # 清理资源
|
||||
# await rag.finalize_storages()
|
||||
# return EventSourceResponse(generate_response_stream(query=question, mode="hybrid", user_prompt=prompt))
|
||||
|
||||
|
||||
|
||||
@router.post("/saveWord")
|
||||
async def save_word(request: Request):
|
||||
# 获取参数
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
markdown_content = await get_request_str_param(request, "markdown_content", True, True)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
|
||||
if theme_object is None:
|
||||
return {"success": False, "message": "主题不存在!"}
|
||||
|
||||
filename = "【理想大模型】" + str(theme_object["theme_name"]) + "(" + str(question) + ")" + str(time.time()) + ".docx"
|
||||
output_file = None
|
||||
try:
|
||||
# markdown_content 替换换行符
|
||||
# markdown_content = markdown_content.replace("\n", "")
|
||||
|
||||
# 创建临时Markdown文件
|
||||
temp_md = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".md")
|
||||
with open(temp_md, "w", encoding="utf-8") as f:
|
||||
f.write(markdown_content)
|
||||
# 使用pandoc转换
|
||||
output_file = os.path.join(tempfile.gettempdir(), filename)
|
||||
subprocess.run(['pandoc', '--reference-doc', 'static/template/templates.docx', '-s', temp_md, '-o', output_file, '--resource-path=static'], check=True)
|
||||
|
||||
# 读取生成的Word文件
|
||||
with open(output_file, "rb") as f:
|
||||
stream = BytesIO(f.read())
|
||||
|
||||
# 返回响应
|
||||
encoded_filename = urllib.parse.quote(filename)
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Connection": "close",
|
||||
"Content-Type": "application/vnd.ms-excel;charset=UTF-8",
|
||||
"Content-Disposition": f"attachment; filename={encoded_filename}"}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
finally:
|
||||
...
|
||||
# 清理临时文件
|
||||
try:
|
||||
if temp_md and os.path.exists(temp_md):
|
||||
os.remove(temp_md)
|
||||
if output_file and os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp files: {str(e)}")
|
||||
|
||||
|
||||
|
||||
########################################
|
||||
# 【TeachingModel-7】保存个人历史问题答案
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-08-14
|
||||
# 备注:
|
||||
########################################
|
||||
@router.post("/saveAnswer")
|
||||
async def save_answer(request: Request):
|
||||
# 获取参数
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
theme_id = await get_request_num_param(request, "theme_id", True, True, None)
|
||||
question = await get_request_str_param(request, "question", True, True)
|
||||
answer = await get_request_str_param(request, "answer", True, True)
|
||||
|
||||
# 验证是否存在个人历史问题
|
||||
select_question_sql: str = f"SELECT * FROM t_ai_teaching_model_question WHERE is_deleted = 0 and bureau_id = '{bureau_id}' AND theme_id = {theme_id} AND question_type = 2 AND question = '{question}' AND person_id = '{person_id}' order by create_time desc limit 1"
|
||||
print(select_question_sql)
|
||||
select_question_result = await find_by_sql(select_question_sql, ())
|
||||
if select_question_result is None: # 个人历史问题不存在
|
||||
return {"success": False, "message": "个人历史问题不存在!"}
|
||||
|
||||
question_id = select_question_result[0].get("id")
|
||||
# 保存答案
|
||||
update_answer_sql: str = f"UPDATE t_ai_teaching_model_question SET question_answer = '{answer}', update_time = now() WHERE id = {question_id}"
|
||||
print(update_answer_sql)
|
||||
await execute_sql(update_answer_sql, ())
|
||||
# 结果返回
|
||||
return {"success": True, "message": "保存成功!"}
|
23
dsLightRag/Routes/TeachingModel/api/TestController.py
Normal file
23
dsLightRag/Routes/TeachingModel/api/TestController.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# routes/TestController.py
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from utils.ParseRequest import *
|
||||
|
||||
# 创建一个路由实例
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/parse", response_model=dict)
|
||||
async def parse_query(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
||||
@router.post("/parse", response_model=dict)
|
||||
async def parse_form(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
||||
|
||||
@router.post("/parse_json", response_model=dict)
|
||||
async def parse_json(request: Request):
|
||||
request_data = await parse_request_data(request)
|
||||
return request_data
|
175
dsLightRag/Routes/TeachingModel/api/ThemeController.py
Normal file
175
dsLightRag/Routes/TeachingModel/api/ThemeController.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# routes/ThemeController.py
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from Util.Database import *
|
||||
from Util.ParseRequest import *
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.PageUtil import *
|
||||
from Util.TranslateUtil import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 功能:【Theme-1】主题管理列表
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.get("/list")
|
||||
async def list(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
page_number = await get_request_num_param(request, "page_number", False, True,1)
|
||||
page_size = await get_request_num_param(request, "page_size", False, True, 10)
|
||||
theme_name = await get_request_str_param(request, "theme_name", False, True)
|
||||
|
||||
print(stage_id, person_id, subject_id, page_number, page_size, theme_name)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
select_theme_sql: str = " SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_theme_sql += " and stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_theme_sql += " and subject_id = " + str(subject_id)
|
||||
if theme_name != "":
|
||||
select_theme_sql += " and theme_name like '%" + theme_name + "%'"
|
||||
select_theme_sql += " ORDER BY create_time DESC"
|
||||
|
||||
# 查询主题列表
|
||||
page = await get_page_data_by_sql(select_theme_sql, page_number, page_size)
|
||||
person_ids = ""
|
||||
for item in page["list"]:
|
||||
person_ids += "'" + item["person_id"] + "',"
|
||||
if person_ids != "":
|
||||
person_ids = person_ids[:-1]
|
||||
else:
|
||||
person_ids = "''"
|
||||
|
||||
person_map = await get_person_map(person_ids)
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
for item in page["list"]:
|
||||
item["stage_name"] = stage_map.get(str(item["stage_id"]), "未知学段")
|
||||
item["subject_name"] = subject_map.get(str(item["subject_id"]), "未知学科")
|
||||
item["person_name"] = person_map.get(str(item["person_id"]), "未知姓名")
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": page}
|
||||
|
||||
|
||||
# 功能:【Theme-2】保存主题管理
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.post("/save")
|
||||
async def save(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", False, True, 0)
|
||||
theme_name = await get_request_str_param(request, "theme_name", True, True)
|
||||
if len(theme_name) > 50:
|
||||
return {"success": False, "message": "主题名称不能超过50个字符!"}
|
||||
short_name = await get_request_str_param(request, "short_name", True, True)
|
||||
if len(short_name) > 50:
|
||||
return {"success": False, "message": "主题英文简称不能超过50个字符!"}
|
||||
theme_icon = await get_request_str_param(request, "theme_icon", False, True)
|
||||
if len(theme_name) > 200:
|
||||
return {"success": False, "message": "主题图标不能超过200个字符!"}
|
||||
stage_id = await get_request_num_param(request, "stage_id", True, True, None)
|
||||
subject_id = await get_request_num_param(request, "subject_id", True, True, None)
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
bureau_id = await get_request_str_param(request, "bureau_id", True, True)
|
||||
|
||||
# 校验参数
|
||||
check_theme_sql = "SELECT theme_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and theme_name = '" + theme_name + "'"
|
||||
if id != 0:
|
||||
check_theme_sql += " and id <> " + str(id)
|
||||
print(check_theme_sql)
|
||||
check_theme_result = await find_by_sql(check_theme_sql,())
|
||||
if check_theme_result:
|
||||
return {"success": False, "message": "该主题名称已存在!"}
|
||||
if short_name.length > 50:
|
||||
return {"success": False, "message": "主题英文简称不能超过50个字符!"}
|
||||
check_short_name_sql = "SELECT short_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and short_name = '" + short_name + "'"
|
||||
if id != 0:
|
||||
check_short_name_sql += " and id <> " + str(id)
|
||||
print(check_short_name_sql)
|
||||
check_short_name_result = await find_by_sql(check_short_name_sql,())
|
||||
if check_short_name_result:
|
||||
return {"success": False, "message": "该主题英文简称已存在!"}
|
||||
|
||||
# 组装参数
|
||||
param = {"theme_name": theme_name,"short_name": short_name,"theme_icon": theme_icon,"stage_id": stage_id,"subject_id": subject_id,"person_id": person_id,"bureau_id": bureau_id}
|
||||
|
||||
# 保存数据
|
||||
if id == 0:
|
||||
param["search_flag"] = 0
|
||||
param["train_flag"] = 0
|
||||
param["quote_count"] = 0
|
||||
# 插入数据
|
||||
id = await insert("t_ai_teaching_model_theme", param, False)
|
||||
return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}}
|
||||
else:
|
||||
# 更新数据
|
||||
await update("t_ai_teaching_model_theme", param, "id", id, False)
|
||||
return {"success": True, "message": "更新成功!", "data": {"update_id" : id}}
|
||||
|
||||
|
||||
# 功能:【Theme-3】获取主题信息
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-14
|
||||
# 备注:
|
||||
@router.get("/get")
|
||||
async def get(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
theme_obj = await find_by_id("t_ai_teaching_model_theme", "id", id)
|
||||
if theme_obj is None:
|
||||
return {"success": False, "message": "未查询到该主题信息!"}
|
||||
|
||||
stage_map = await get_stage_map()
|
||||
subject_map = await get_subject_map()
|
||||
theme_obj["stage_name"] = stage_map.get(str(theme_obj["stage_id"]), "未知学段")
|
||||
theme_obj["subject_name"] = subject_map.get(str(theme_obj["subject_id"]), "未知学科")
|
||||
theme_obj["person_name"] = await find_person_name_by_id(theme_obj["person_id"])
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"theme": theme_obj}}
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
async def delete(request: Request):
|
||||
# 获取参数
|
||||
id = await get_request_num_param(request, "id", True, True, None)
|
||||
result = await delete_by_id("t_ai_teaching_model_theme", "id", id)
|
||||
if not result:
|
||||
return {"success": False, "message": "删除失败!"}
|
||||
return {"success": True, "message": "删除成功!"}
|
||||
|
||||
|
||||
|
||||
# 功能:【Theme-5】根据学段学科获取主题列表
|
||||
# 作者:Kalman.CHENG ☆
|
||||
# 时间:2025-07-31
|
||||
# 备注:
|
||||
@router.get("/getListByStageSubject")
|
||||
async def get_list_by_stage_subject(request: Request):
|
||||
# 获取参数
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
stage_id = await get_request_num_param(request, "stage_id", False, True, -1)
|
||||
subject_id = await get_request_num_param(request, "subject_id", False, True, -1)
|
||||
|
||||
# 拼接查询SQL语句
|
||||
select_theme_sql: str = " select id as theme_id, theme_name from t_ai_teaching_model_theme where is_deleted = 0 and person_id = '" + person_id + "'"
|
||||
if stage_id != -1:
|
||||
select_theme_sql += " and stage_id = " + str(stage_id)
|
||||
if subject_id != -1:
|
||||
select_theme_sql += " and subject_id = " + str(subject_id)
|
||||
select_theme_result = await find_by_sql(select_theme_sql,())
|
||||
|
||||
return {"success": True, "message": "查询成功!", "data": {"theme_list": select_theme_result}}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
50
dsLightRag/Routes/TeachingModel/api/UserController.py
Normal file
50
dsLightRag/Routes/TeachingModel/api/UserController.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# routes/UserController.py
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Request, Response, Depends
|
||||
from Routes.TeachingModel.auth.dependencies import *
|
||||
from Util.CommonUtil import md5_encrypt
|
||||
from Util.Database import *
|
||||
from Util.ParseRequest import *
|
||||
|
||||
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
|
||||
router = APIRouter(dependencies=[Depends(get_current_user)])
|
||||
|
||||
# 【Base-User-1】维护用户手机号
|
||||
@router.post("/modifyTelephone")
|
||||
async def modify_telephone(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
telephone = await get_request_str_param(request, "telephone", True, True)
|
||||
# 校验手机号码格式
|
||||
if not re.match(r"^1[3-9]\d{9}$", telephone):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误")
|
||||
# 校验手机号码是否已被注册
|
||||
select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'"
|
||||
userlist = await find_by_sql(select_telephone_sql, ())
|
||||
if userlist is not None:
|
||||
return {"success": False, "message": "手机号码已被注册"}
|
||||
else:
|
||||
update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'"
|
||||
await execute_sql(update_telephone_sql, ())
|
||||
return {"success": True, "message": "修改成功"}
|
||||
|
||||
|
||||
# 【Base-User-2】维护用户密码
|
||||
@router.post("/modifyPassword")
|
||||
async def modify_password(request: Request):
|
||||
person_id = await get_request_str_param(request, "person_id", True, True)
|
||||
old_password = await get_request_str_param(request, "old_password", True, True)
|
||||
password = await get_request_str_param(request, "password", True, True)
|
||||
# 校验旧密码是否正确
|
||||
select_password_sql: str = "select pwdmd5 from t_sys_loginperson where person_id = '" + person_id + "' and b_use = 1"
|
||||
userlist = await find_by_sql(select_password_sql, ())
|
||||
if len(userlist) == 0:
|
||||
return {"success": False, "message": "用户不存在"}
|
||||
else:
|
||||
if userlist[0]["pwdmd5"] != md5_encrypt(old_password):
|
||||
return {"success": False, "message": "旧密码错误"}
|
||||
else:
|
||||
update_password_sql: str = "update t_sys_loginperson set original_pwd = '" + password + "',pwdmd5 = '" + md5_encrypt(password) + "' where person_id = '" + person_id + "'"
|
||||
await execute_sql(update_password_sql, ())
|
||||
return {"success": True, "message": "修改成功"}
|
||||
|
29
dsLightRag/Routes/TeachingModel/auth/dependencies.py
Normal file
29
dsLightRag/Routes/TeachingModel/auth/dependencies.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# dependencies.py
|
||||
|
||||
from fastapi import Request, HTTPException, status, Header
|
||||
from Util.JwtUtil import *
|
||||
|
||||
async def get_current_user(token: str = Header(None, description="Authorization token")):
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未提供令牌,请登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
token = token.split(" ")[1] if token.startswith("Bearer ") else token
|
||||
payload = verify_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无法验证凭据,请登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
user_id = payload.get("user_id")
|
||||
identity_id = payload.get("identity_id")
|
||||
if user_id is None or identity_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无法验证用户,请重新登录后使用!",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user_id
|
102
dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py
Normal file
102
dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
|
||||
from Util.Database import find_by_sql, find_by_id, execute_sql
|
||||
from Util.DocxUtil import get_docx_content_by_pandoc
|
||||
from Util.LightRagUtil import initialize_rag
|
||||
|
||||
# 更详细地控制日志输出
|
||||
logger = logging.getLogger('lightrag')
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
||||
logger.addHandler(handler)
|
||||
|
||||
# 后台任务,监控是否有新的未训练的文档进行训练
|
||||
async def train_document_task():
|
||||
print(datetime.datetime.now(), "线程5秒后开始运行【监控是否有需要处理的文档】")
|
||||
await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep
|
||||
# 这里放置你的线程逻辑
|
||||
while True:
|
||||
print("测试定时任务运行")
|
||||
handle_flag = False
|
||||
if handle_flag:
|
||||
# 这里可以放置你的线程要执行的代码
|
||||
logging.info("开始查询是否有待处理文档:")
|
||||
no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE train_flag in (0,3) ORDER BY create_time DESC"
|
||||
no_train_document_result = await find_by_sql(no_train_document_sql, ())
|
||||
logger.info(no_train_document_result)
|
||||
if not no_train_document_result:
|
||||
print(datetime.datetime.now(), "没有需要处理的文档")
|
||||
else:
|
||||
print(datetime.datetime.now(), "存在未训练的文档" + str(len(no_train_document_result))+"个")
|
||||
# 这里可以根据train_flag的值来判断是训练还是删除
|
||||
document = no_train_document_result[0]
|
||||
theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"])
|
||||
document_name = document["document_name"] + "." + document["document_suffix"]
|
||||
working_dir = "Topic/" + theme["short_name"]
|
||||
document_path = document["document_path"]
|
||||
if document["train_flag"] == 0:
|
||||
# 训练文档
|
||||
update_sql_document: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_sql_document, ())
|
||||
update_sql_theme: str = " UPDATE t_ai_teaching_model_theme SET train_flag = 1 WHERE id = " + str(theme["id"])
|
||||
await execute_sql(update_sql_theme, ())
|
||||
logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
||||
# 训练代码开始
|
||||
# content = get_docx_content_by_pandoc(document_path)
|
||||
try:
|
||||
# 注意:默认设置使用NetworkX
|
||||
rag = await initialize_rag(working_dir)
|
||||
# 获取docx文件的内容
|
||||
content = get_docx_content_by_pandoc(document_path)
|
||||
await rag.ainsert(content, ids=[document_name], file_paths=[document_name])
|
||||
logger.info(f"Inserted content from {document_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: {e}")
|
||||
finally:
|
||||
await rag.finalize_storages()
|
||||
# 训练结束,更新训练状态
|
||||
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_document_sql, ())
|
||||
elif document["train_flag"] == 3:
|
||||
update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_sql, ())
|
||||
logging.info(f"开始删除文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
||||
# 删除文档开始
|
||||
try:
|
||||
# 注意:默认设置使用NetworkX
|
||||
rag = await initialize_rag(working_dir)
|
||||
await rag.adelete_by_doc_id(doc_id = document_name)
|
||||
logger.info(f"Deleted content from {document_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred: {e}")
|
||||
finally:
|
||||
await rag.finalize_storages()
|
||||
# 删除结束,更新训练状态
|
||||
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 5 WHERE id = " + str(document["id"])
|
||||
await execute_sql(update_document_sql, ())
|
||||
|
||||
# 整体更新主题状态
|
||||
select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2) group by train_flag"
|
||||
select_document_result = await find_by_sql(select_document_sql, ())
|
||||
train_document_count_map = {}
|
||||
for item in select_document_result:
|
||||
train_document_count_map[str(item["train_flag"])] = int(item["train_count"])
|
||||
train_document_count_1 = train_document_count_map.get("1", 0)
|
||||
train_document_count_2 = train_document_count_map.get("2", 0)
|
||||
if train_document_count_2 > 0:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 1, train_flag = 2 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
else:
|
||||
if train_document_count_1 > 0:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 1 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
else:
|
||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 0 WHERE id = {theme['id']}"
|
||||
await execute_sql(update_theme_sql, ())
|
||||
# 添加适当的等待时间,避免频繁查询
|
||||
await asyncio.sleep(20) # 开发阶段每20秒一次
|
||||
# await asyncio.sleep(600) # 每十分钟查询一次
|
Reference in New Issue
Block a user