diff --git a/dsAiTeachingModel/api/controller/DmController.py b/dsAiTeachingModel/api/controller/DmController.py index 7dbd0e51..2f719317 100644 --- a/dsAiTeachingModel/api/controller/DmController.py +++ b/dsAiTeachingModel/api/controller/DmController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/DmController.py from fastapi import APIRouter, Depends diff --git a/dsAiTeachingModel/api/controller/DocumentController.py b/dsAiTeachingModel/api/controller/DocumentController.py index d88da710..bd0e2998 100644 --- a/dsAiTeachingModel/api/controller/DocumentController.py +++ b/dsAiTeachingModel/api/controller/DocumentController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/DocumentController.py import os from fastapi import APIRouter, Request, Response, Depends, UploadFile, File diff --git a/dsAiTeachingModel/api/controller/LoginController.py b/dsAiTeachingModel/api/controller/LoginController.py index 307fd3b6..b3368d35 100644 --- a/dsAiTeachingModel/api/controller/LoginController.py +++ b/dsAiTeachingModel/api/controller/LoginController.py @@ -129,3 +129,31 @@ async def login(request: Request, response: Response): else: return {"success": False, "message": "用户名或密码错误"} + +# 【Base-Login-3】通过手机号获取Person的ID +@router.get("/getPersonIdByTelephone") +async def get_person_id_by_telephone(request: Request): + telephone = await get_request_str_param(request, "telephone", True, True) + if not telephone: + return {"success": False, "message": "手机号不能为空"} + select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 " + userlist = await find_by_sql(select_user_sql,()) + user = userlist[0] if userlist else None + if user: + return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}} + else: + return {"success": False, "message": "未查询到相关信息"} + + + +# 【Base-Login-4】忘记密码重设,不登录的状态 +@router.post("/resetPassword") +async def reset_password(request: Request): + person_id = await get_request_str_param(request, "person_id", True, True) + password = await get_request_str_param(request, "password", True, True) + if not person_id or not password: + return {"success": False, "message": "用户ID和新密码不能为空"} + password_md5 = md5_encrypt(password) + update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'" + await execute_sql(update_user_sql) + return {"success": True, "message": "密码修改成功"} \ No newline at end of file diff --git a/dsAiTeachingModel/api/controller/QuestionController.py b/dsAiTeachingModel/api/controller/QuestionController.py index 89456bc5..48b7ed39 100644 --- a/dsAiTeachingModel/api/controller/QuestionController.py +++ b/dsAiTeachingModel/api/controller/QuestionController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/QuestionController.py from fastapi import APIRouter, Request, Response, Depends from auth.dependencies import * diff --git a/dsAiTeachingModel/api/controller/TestController.py b/dsAiTeachingModel/api/controller/TestController.py index 4a572ff3..5c6a8ed5 100644 --- a/dsAiTeachingModel/api/controller/TestController.py +++ b/dsAiTeachingModel/api/controller/TestController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/TestController.py from fastapi import APIRouter, Request diff --git a/dsAiTeachingModel/api/controller/ThemeController.py b/dsAiTeachingModel/api/controller/ThemeController.py index 297817d1..3bd9fcd5 100644 --- a/dsAiTeachingModel/api/controller/ThemeController.py +++ b/dsAiTeachingModel/api/controller/ThemeController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/ThemeController.py from fastapi import APIRouter, Depends from utils.ParseRequest import * diff --git a/dsAiTeachingModel/api/controller/UserController.py b/dsAiTeachingModel/api/controller/UserController.py new file mode 100644 index 00000000..e23d8f5f --- /dev/null +++ b/dsAiTeachingModel/api/controller/UserController.py @@ -0,0 +1,32 @@ +# routes/UserController.py +import re + +from fastapi import APIRouter, Request, Response, Depends +from auth.dependencies import * +from utils.Database import * +from utils.ParseRequest import * + +# 创建一个路由实例,需要依赖get_current_user,登录后才能访问 +router = APIRouter(dependencies=[Depends(get_current_user)]) + +# 【Base-User-1】维护用户手机号 +@router.post("/modifyTelephone") +async def modify_telephone(request: Request): + person_id = await get_request_str_param(request, "person_id", True, True) + telephone = await get_request_str_param(request, "telephone", True, True) + # 校验手机号码格式 + if not re.match(r"^1[3-9]\d{9}$", telephone): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误") + # 校验手机号码是否已被注册 + select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'" + userlist = await find_by_sql(select_telephone_sql, ()) + if len(userlist) > 0: + return {"success": False, "message": "手机号码已被注册"} + else: + update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'" + await execute_sql(update_telephone_sql) + return {"success": True, "message": "修改成功"} + + +# 【Base-User-2】维护用户密码 +# @router.post("/modifyPassword") diff --git a/dsAiTeachingModel/main.py b/dsAiTeachingModel/main.py index 18a578f1..8f99a901 100644 --- a/dsAiTeachingModel/main.py +++ b/dsAiTeachingModel/main.py @@ -1,6 +1,7 @@ import threading -import logging + import uvicorn +import asyncio from fastapi.middleware.cors import CORSMiddleware from starlette.staticfiles import StaticFiles @@ -18,11 +19,12 @@ logging.basicConfig( ) async def lifespan(app: FastAPI): - # 启动线程 - thread = threading.Thread(target=train_document_task, daemon=True) - thread.start() # 创建数据库连接池 await init_database() + + # 启动异步任务 + asyncio.create_task(train_document_task()) + yield await shutdown_database() @@ -41,8 +43,10 @@ app.add_middleware( app.mount("/static", StaticFiles(directory="Static"), name="static") # 注册路由 -# 登录相关 +# 登录相关(不用登录) app.include_router(login_router, prefix="/api/login", tags=["login"]) +# 用户相关 +app.include_router(user_router, prefix="/api/user", tags=["user"]) # 主题相关 app.include_router(theme_router, prefix="/api/theme", tags=["theme"]) # 文档相关 diff --git a/dsAiTeachingModel/routes/__init__.py b/dsAiTeachingModel/routes/__init__.py index 5bde8674..4fa720b9 100644 --- a/dsAiTeachingModel/routes/__init__.py +++ b/dsAiTeachingModel/routes/__init__.py @@ -5,6 +5,7 @@ from api.controller.ThemeController import router as theme_router from api.controller.QuestionController import router as question_router from api.controller.TestController import router as test_router from api.controller.DmController import router as dm_router +from api.controller.UserController import router as user_router # 导出所有路由 -__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", "test_router"] +__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", "test_router", "user_router"] diff --git a/dsAiTeachingModel/tasks/BackgroundTasks.py b/dsAiTeachingModel/tasks/BackgroundTasks.py index e90bdc52..d43dc190 100644 --- a/dsAiTeachingModel/tasks/BackgroundTasks.py +++ b/dsAiTeachingModel/tasks/BackgroundTasks.py @@ -1,12 +1,52 @@ +import asyncio import logging import time +from utils.Database import * +from utils.DocxUtil import get_docx_content_by_pandoc +from utils.LightRagUtil import initialize_pg_rag + +# 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。 +WORKING_DIR = f"./output" + # 后台任务,监控是否有新的未训练的文档进行训练 -def train_document_task(): +async def train_document_task(): print("线程5秒后开始运行【监控是否有新的未训练的文档进行训练】") - time.sleep(5) # 线程5秒后开始运行 + await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep # 这里放置你的线程逻辑 while True: # 这里可以放置你的线程要执行的代码 - logging.info("线程正在运行") - time.sleep(1000) # 每隔10秒执行一次 + logging.info("开始查询是否有未训练的文档") + no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and train_flag = 0 ORDER BY create_time DESC" + no_train_document_result = await find_by_sql(no_train_document_sql, ()) + if not no_train_document_result: + logging.info("没有未训练的文档") + else: + logging.info("存在未训练的文档" + str(len(no_train_document_result))+"个") + # document = no_train_document_result[0] + # print("开始训练文档:" + document["document_name"]) + # theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"]) + # # 训练开始前,更新训练状态 + # update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"]) + # execute_sql(update_sql) + # document_name = document["document_name"] + "." + document["document_suffix"] + # logging.info("开始训练文档:" + document_name) + # workspace = theme["short_name"] + # docx_name = document_name + # docx_path = document["document_path"] + # logging.info(f"开始处理文档:{docx_name}, 还有%s个文档需要处理!", len(no_train_document_result) - 1) + # # 训练代码开始 + # try: + # rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + # # 获取docx文件的内容 + # content = get_docx_content_by_pandoc(docx_path) + # await rag.insert(input=content, file_paths=[docx_name]) + # finally: + # if rag: + # await rag.finalize_storages() + # # 训练结束,更新训练状态 + # update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"]) + # execute_sql(update_sql) + + # 添加适当的等待时间,避免频繁查询 + await asyncio.sleep(60) # 每分钟查询一次 diff --git a/dsAiTeachingModel/utils/Database.py b/dsAiTeachingModel/utils/Database.py index 85580029..4ac15243 100644 --- a/dsAiTeachingModel/utils/Database.py +++ b/dsAiTeachingModel/utils/Database.py @@ -204,4 +204,16 @@ async def delete_by_id(table_name, property_name, property_value): raise Exception(f"为表[{table_name}]删除数据失败: {e}") else: logging.error("参数不全") - return False \ No newline at end of file + return False + + +# 执行一个SQL语句 +async def execute_sql(sql): + logging.debug(sql) + try: + async with pool.acquire() as conn: + await conn.fetch(sql) + except Exception as e: + logging.error(f"数据库查询错误: {e}") + logging.error(f"执行的SQL语句: {sql}") + raise Exception(f"执行SQL失败: {e}") \ No newline at end of file diff --git a/dsAiTeachingModel/utils/DocxUtil.py b/dsAiTeachingModel/utils/DocxUtil.py index 82e26d2c..6c8051fc 100644 --- a/dsAiTeachingModel/utils/DocxUtil.py +++ b/dsAiTeachingModel/utils/DocxUtil.py @@ -1,8 +1,56 @@ +import logging import os import subprocess import uuid +from PIL import Image +import os + +# 在程序开始时添加以下配置 +logging.basicConfig( + level=logging.INFO, # 设置日志级别为INFO + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# 或者如果你想更详细地控制日志输出 +logger = logging.getLogger('DocxUtil') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +logger.addHandler(handler) +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +def resize_images_in_directory(directory_path, max_width=640, max_height=480): + """ + 遍历目录下所有图片并缩放到指定尺寸 + :param directory_path: 图片目录路径 + :param max_width: 最大宽度 + :param max_height: 最大高度 + """ + # 支持的图片格式 + valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif') + + for root, _, files in os.walk(directory_path): + for filename in files: + if filename.lower().endswith(valid_extensions): + file_path = os.path.join(root, filename) + try: + with Image.open(file_path) as img: + # 计算缩放比例 + width, height = img.size + ratio = min(max_width / width, max_height / height) + # 如果图片已经小于目标尺寸,则跳过 + if ratio >= 1: + continue + # 计算新尺寸并缩放 + new_size = (int(width * ratio), int(height * ratio)) + resized_img = img.resize(new_size, Image.Resampling.LANCZOS) + + # 保存图片(覆盖原文件) + resized_img.save(file_path) + logger.info(f"已缩放: {file_path} -> {new_size}") + except Exception as e: + logger.error(f"处理 {file_path} 时出错: {str(e)}") def get_docx_content_by_pandoc(docx_file): # 最后拼接的内容 content = "" @@ -15,6 +63,9 @@ def get_docx_content_by_pandoc(docx_file): os.mkdir("./static/Images/" + file_name) subprocess.run(['pandoc', docx_file, '-f', 'docx', '-t', 'markdown', '-o', temp_markdown, '--extract-media=./static/Images/' + file_name]) + # 遍历目录 './static/Images/'+file_name 下所有的图片,缩小于640*480的尺寸上 + + resize_images_in_directory('./static/Images/' + file_name+'/media') # 读取然后修改内容,输出到新的文件 img_idx = 0 # 图片索引 with open(temp_markdown, 'r', encoding='utf-8') as f: @@ -23,8 +74,9 @@ def get_docx_content_by_pandoc(docx_file): if not line: continue # 跳过图片高度描述行 - if line.startswith('height=') and line.endswith('in"}'): + if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')): continue + # height="1.91044072615923in" # 使用find()方法安全地检查图片模式 is_img = line.find("![](") >= 0 and ( line.find(".png") > 0 or diff --git a/dsAiTeachingModel/utils/LightRagUtil.py b/dsAiTeachingModel/utils/LightRagUtil.py index 528f5963..e791c4a8 100644 --- a/dsAiTeachingModel/utils/LightRagUtil.py +++ b/dsAiTeachingModel/utils/LightRagUtil.py @@ -1,9 +1,7 @@ import logging import logging.config import os - import numpy as np - from lightrag import LightRAG from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.llm.openai import openai_complete_if_cache, openai_embed @@ -25,7 +23,7 @@ def configure_logging(): log_dir = os.getenv("LOG_DIR", os.getcwd()) log_file_path = os.path.abspath( - os.path.join(log_dir, "./logs/lightrag.log") + os.path.join(log_dir, "./Logs/lightrag.log") ) print(f"\nLightRAG log file: {log_file_path}\n") @@ -97,10 +95,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray: ) -async def initialize_rag(working_dir): +async def initialize_rag(working_dir, graph_storage=None): + if graph_storage is None: + graph_storage = 'NetworkXStorage' rag = LightRAG( working_dir=working_dir, llm_model_func=llm_model_func, + graph_storage=graph_storage, embedding_func=EmbeddingFunc( embedding_dim=EMBED_DIM, max_token_size=EMBED_MAX_TOKEN_SIZE, @@ -139,4 +140,40 @@ def create_embedding_func(): api_key=EMBED_API_KEY, base_url=EMBED_BASE_URL, ), - ) \ No newline at end of file + ) + + +# AGE +os.environ["AGE_GRAPH_NAME"] = AGE_GRAPH_NAME +os.environ["POSTGRES_HOST"] = POSTGRES_HOST +os.environ["POSTGRES_PORT"] = str(POSTGRES_PORT) +os.environ["POSTGRES_USER"] = POSTGRES_USER +os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD +os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE + + +async def initialize_pg_rag(WORKING_DIR, workspace='default'): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + llm_model_name=LLM_MODEL_NAME, + llm_model_max_async=4, + llm_model_max_token_size=32768, + enable_llm_cache_for_entity_extract=True, + embedding_func=EmbeddingFunc( + embedding_dim=EMBED_DIM, + max_token_size=EMBED_MAX_TOKEN_SIZE, + func=embedding_func + ), + kv_storage="PGKVStorage", + doc_status_storage="PGDocStatusStorage", + graph_storage="PGGraphStorage", + vector_storage="PGVectorStorage", + auto_manage_storages_states=False, + vector_db_storage_cls_kwargs={"workspace": workspace} + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag