教育垂直领域大模型平台

modify by Kalman.CHENG ☆
main
Kalman.CHENG 1 week ago
parent b8b4b08b3f
commit fca710e071

@ -1,4 +1,4 @@
# routes/LoginController.py
# routes/DmController.py
from fastapi import APIRouter, Depends

@ -1,4 +1,4 @@
# routes/LoginController.py
# routes/DocumentController.py
import os
from fastapi import APIRouter, Request, Response, Depends, UploadFile, File

@ -129,3 +129,31 @@ async def login(request: Request, response: Response):
else:
return {"success": False, "message": "用户名或密码错误"}
# 【Base-Login-3】通过手机号获取Person的ID
@router.get("/getPersonIdByTelephone")
async def get_person_id_by_telephone(request: Request):
telephone = await get_request_str_param(request, "telephone", True, True)
if not telephone:
return {"success": False, "message": "手机号不能为空"}
select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 "
userlist = await find_by_sql(select_user_sql,())
user = userlist[0] if userlist else None
if user:
return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}}
else:
return {"success": False, "message": "未查询到相关信息"}
# 【Base-Login-4】忘记密码重设不登录的状态
@router.post("/resetPassword")
async def reset_password(request: Request):
person_id = await get_request_str_param(request, "person_id", True, True)
password = await get_request_str_param(request, "password", True, True)
if not person_id or not password:
return {"success": False, "message": "用户ID和新密码不能为空"}
password_md5 = md5_encrypt(password)
update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'"
await execute_sql(update_user_sql)
return {"success": True, "message": "密码修改成功"}

@ -1,4 +1,4 @@
# routes/LoginController.py
# routes/QuestionController.py
from fastapi import APIRouter, Request, Response, Depends
from auth.dependencies import *

@ -1,4 +1,4 @@
# routes/LoginController.py
# routes/TestController.py
from fastapi import APIRouter, Request

@ -1,4 +1,4 @@
# routes/LoginController.py
# routes/ThemeController.py
from fastapi import APIRouter, Depends
from utils.ParseRequest import *

@ -0,0 +1,32 @@
# routes/UserController.py
import re
from fastapi import APIRouter, Request, Response, Depends
from auth.dependencies import *
from utils.Database import *
from utils.ParseRequest import *
# 创建一个路由实例,需要依赖get_current_user,登录后才能访问
router = APIRouter(dependencies=[Depends(get_current_user)])
# 【Base-User-1】维护用户手机号
@router.post("/modifyTelephone")
async def modify_telephone(request: Request):
person_id = await get_request_str_param(request, "person_id", True, True)
telephone = await get_request_str_param(request, "telephone", True, True)
# 校验手机号码格式
if not re.match(r"^1[3-9]\d{9}$", telephone):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误")
# 校验手机号码是否已被注册
select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'"
userlist = await find_by_sql(select_telephone_sql, ())
if len(userlist) > 0:
return {"success": False, "message": "手机号码已被注册"}
else:
update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'"
await execute_sql(update_telephone_sql)
return {"success": True, "message": "修改成功"}
# 【Base-User-2】维护用户密码
# @router.post("/modifyPassword")

@ -1,6 +1,7 @@
import threading
import logging
import uvicorn
import asyncio
from fastapi.middleware.cors import CORSMiddleware
from starlette.staticfiles import StaticFiles
@ -18,11 +19,12 @@ logging.basicConfig(
)
async def lifespan(app: FastAPI):
# 启动线程
thread = threading.Thread(target=train_document_task, daemon=True)
thread.start()
# 创建数据库连接池
await init_database()
# 启动异步任务
asyncio.create_task(train_document_task())
yield
await shutdown_database()
@ -41,8 +43,10 @@ app.add_middleware(
app.mount("/static", StaticFiles(directory="Static"), name="static")
# 注册路由
# 登录相关
# 登录相关(不用登录)
app.include_router(login_router, prefix="/api/login", tags=["login"])
# 用户相关
app.include_router(user_router, prefix="/api/user", tags=["user"])
# 主题相关
app.include_router(theme_router, prefix="/api/theme", tags=["theme"])
# 文档相关

@ -5,6 +5,7 @@ from api.controller.ThemeController import router as theme_router
from api.controller.QuestionController import router as question_router
from api.controller.TestController import router as test_router
from api.controller.DmController import router as dm_router
from api.controller.UserController import router as user_router
# 导出所有路由
__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", "test_router"]
__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", "test_router", "user_router"]

@ -1,12 +1,52 @@
import asyncio
import logging
import time
from utils.Database import *
from utils.DocxUtil import get_docx_content_by_pandoc
from utils.LightRagUtil import initialize_pg_rag
# 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
WORKING_DIR = f"./output"
# 后台任务,监控是否有新的未训练的文档进行训练
def train_document_task():
async def train_document_task():
print("线程5秒后开始运行【监控是否有新的未训练的文档进行训练】")
time.sleep(5) # 线程5秒后开始运行
await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep
# 这里放置你的线程逻辑
while True:
# 这里可以放置你的线程要执行的代码
logging.info("线程正在运行")
time.sleep(1000) # 每隔10秒执行一次
logging.info("开始查询是否有未训练的文档")
no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and train_flag = 0 ORDER BY create_time DESC"
no_train_document_result = await find_by_sql(no_train_document_sql, ())
if not no_train_document_result:
logging.info("没有未训练的文档")
else:
logging.info("存在未训练的文档" + str(len(no_train_document_result))+"")
# document = no_train_document_result[0]
# print("开始训练文档:" + document["document_name"])
# theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"])
# # 训练开始前,更新训练状态
# update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"])
# execute_sql(update_sql)
# document_name = document["document_name"] + "." + document["document_suffix"]
# logging.info("开始训练文档:" + document_name)
# workspace = theme["short_name"]
# docx_name = document_name
# docx_path = document["document_path"]
# logging.info(f"开始处理文档:{docx_name}, 还有%s个文档需要处理", len(no_train_document_result) - 1)
# # 训练代码开始
# try:
# rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
# # 获取docx文件的内容
# content = get_docx_content_by_pandoc(docx_path)
# await rag.insert(input=content, file_paths=[docx_name])
# finally:
# if rag:
# await rag.finalize_storages()
# # 训练结束,更新训练状态
# update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"])
# execute_sql(update_sql)
# 添加适当的等待时间,避免频繁查询
await asyncio.sleep(60) # 每分钟查询一次

@ -204,4 +204,16 @@ async def delete_by_id(table_name, property_name, property_value):
raise Exception(f"为表[{table_name}]删除数据失败: {e}")
else:
logging.error("参数不全")
return False
return False
# 执行一个SQL语句
async def execute_sql(sql):
logging.debug(sql)
try:
async with pool.acquire() as conn:
await conn.fetch(sql)
except Exception as e:
logging.error(f"数据库查询错误: {e}")
logging.error(f"执行的SQL语句: {sql}")
raise Exception(f"执行SQL失败: {e}")

@ -1,8 +1,56 @@
import logging
import os
import subprocess
import uuid
from PIL import Image
import os
# 在程序开始时添加以下配置
logging.basicConfig(
level=logging.INFO, # 设置日志级别为INFO
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 或者如果你想更详细地控制日志输出
logger = logging.getLogger('DocxUtil')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
def resize_images_in_directory(directory_path, max_width=640, max_height=480):
"""
遍历目录下所有图片并缩放到指定尺寸
:param directory_path: 图片目录路径
:param max_width: 最大宽度
:param max_height: 最大高度
"""
# 支持的图片格式
valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
for root, _, files in os.walk(directory_path):
for filename in files:
if filename.lower().endswith(valid_extensions):
file_path = os.path.join(root, filename)
try:
with Image.open(file_path) as img:
# 计算缩放比例
width, height = img.size
ratio = min(max_width / width, max_height / height)
# 如果图片已经小于目标尺寸,则跳过
if ratio >= 1:
continue
# 计算新尺寸并缩放
new_size = (int(width * ratio), int(height * ratio))
resized_img = img.resize(new_size, Image.Resampling.LANCZOS)
# 保存图片(覆盖原文件)
resized_img.save(file_path)
logger.info(f"已缩放: {file_path} -> {new_size}")
except Exception as e:
logger.error(f"处理 {file_path} 时出错: {str(e)}")
def get_docx_content_by_pandoc(docx_file):
# 最后拼接的内容
content = ""
@ -15,6 +63,9 @@ def get_docx_content_by_pandoc(docx_file):
os.mkdir("./static/Images/" + file_name)
subprocess.run(['pandoc', docx_file, '-f', 'docx', '-t', 'markdown', '-o', temp_markdown,
'--extract-media=./static/Images/' + file_name])
# 遍历目录 './static/Images/'+file_name 下所有的图片缩小于640*480的尺寸上
resize_images_in_directory('./static/Images/' + file_name+'/media')
# 读取然后修改内容,输出到新的文件
img_idx = 0 # 图片索引
with open(temp_markdown, 'r', encoding='utf-8') as f:
@ -23,8 +74,9 @@ def get_docx_content_by_pandoc(docx_file):
if not line:
continue
# 跳过图片高度描述行
if line.startswith('height=') and line.endswith('in"}'):
if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')):
continue
# height="1.91044072615923in"
# 使用find()方法安全地检查图片模式
is_img = line.find("![](") >= 0 and (
line.find(".png") > 0 or

@ -1,9 +1,7 @@
import logging
import logging.config
import os
import numpy as np
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
@ -25,7 +23,7 @@ def configure_logging():
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(
os.path.join(log_dir, "./logs/lightrag.log")
os.path.join(log_dir, "./Logs/lightrag.log")
)
print(f"\nLightRAG log file: {log_file_path}\n")
@ -97,10 +95,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
async def initialize_rag(working_dir):
async def initialize_rag(working_dir, graph_storage=None):
if graph_storage is None:
graph_storage = 'NetworkXStorage'
rag = LightRAG(
working_dir=working_dir,
llm_model_func=llm_model_func,
graph_storage=graph_storage,
embedding_func=EmbeddingFunc(
embedding_dim=EMBED_DIM,
max_token_size=EMBED_MAX_TOKEN_SIZE,
@ -139,4 +140,40 @@ def create_embedding_func():
api_key=EMBED_API_KEY,
base_url=EMBED_BASE_URL,
),
)
)
# AGE
os.environ["AGE_GRAPH_NAME"] = AGE_GRAPH_NAME
os.environ["POSTGRES_HOST"] = POSTGRES_HOST
os.environ["POSTGRES_PORT"] = str(POSTGRES_PORT)
os.environ["POSTGRES_USER"] = POSTGRES_USER
os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
async def initialize_pg_rag(WORKING_DIR, workspace='default'):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
llm_model_name=LLM_MODEL_NAME,
llm_model_max_async=4,
llm_model_max_token_size=32768,
enable_llm_cache_for_entity_extract=True,
embedding_func=EmbeddingFunc(
embedding_dim=EMBED_DIM,
max_token_size=EMBED_MAX_TOKEN_SIZE,
func=embedding_func
),
kv_storage="PGKVStorage",
doc_status_storage="PGDocStatusStorage",
graph_storage="PGGraphStorage",
vector_storage="PGVectorStorage",
auto_manage_storages_states=False,
vector_db_storage_cls_kwargs={"workspace": workspace}
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag

Loading…
Cancel
Save