diff --git a/dsAiTeachingModel/api/controller/DmController.py b/dsAiTeachingModel/api/controller/DmController.py index 41583878..2f719317 100644 --- a/dsAiTeachingModel/api/controller/DmController.py +++ b/dsAiTeachingModel/api/controller/DmController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/DmController.py from fastapi import APIRouter, Depends @@ -13,11 +13,11 @@ router = APIRouter(dependencies=[Depends(get_current_user)]) async def get_stage_subject_list(): # 先查询学段list select_stage_sql: str = "select stage_id, stage_name from t_dm_stage where b_use = 1 order by sort_id;" - stage_list = await find_by_sql(select_stage_sql, ()) + stage_list = await find_by_sql(select_stage_sql,()) for stage in stage_list: # 再查询学科list - select_subject_sql: str = "select subject_id, subject_name from t_dm_subject where stage_id = %s order by sort_id;" - subject_list = await find_by_sql(select_subject_sql, (stage["stage_id"],)) + select_subject_sql: str = "select subject_id, subject_name from t_dm_subject where stage_id = " + str(stage["stage_id"]) + " order by sort_id;" + subject_list = await find_by_sql(select_subject_sql,()) stage["subject_list"] = subject_list return {"success": True, "message": "成功!", "data": stage_list} diff --git a/dsAiTeachingModel/api/controller/DocumentController.py b/dsAiTeachingModel/api/controller/DocumentController.py index 34576442..bd0e2998 100644 --- a/dsAiTeachingModel/api/controller/DocumentController.py +++ b/dsAiTeachingModel/api/controller/DocumentController.py @@ -1,13 +1,132 @@ -# routes/LoginController.py +# routes/DocumentController.py +import os -from fastapi import APIRouter, Request, Response, Depends +from fastapi import APIRouter, Request, Response, Depends, UploadFile, File from auth.dependencies import get_current_user +from utils.PageUtil import * +from utils.ParseRequest import * # 创建一个路由实例,需要依赖get_current_user,登录后才能访问 router = APIRouter(dependencies=[Depends(get_current_user)]) +# 创建上传文件的目录 +UPLOAD_DIR = "upload_file" +if not os.path.exists(UPLOAD_DIR): + os.makedirs(UPLOAD_DIR) -@router.get("/") -async def test(request: Request, response: Response): - return {"success": True, "message": "成功!"} +# 合法文件扩展名 +supported_suffix_types = ['doc', 'docx', 'ppt', 'pptx', 'xls', 'xlsx'] + +# 【Document-1】文档管理列表 +@router.get("/list") +async def list(request: Request): + # 获取参数 + person_id = await get_request_str_param(request, "person_id", True, True) + stage_id = await get_request_num_param(request, "stage_id", False, True, -1) + subject_id = await get_request_num_param(request, "subject_id", False, True, -1) + document_suffix = await get_request_str_param(request, "document_suffix", False, True) + document_name = await get_request_str_param(request, "document_name", False, True) + page_number = await get_request_num_param(request, "page_number", False, True, 1) + page_size = await get_request_num_param(request, "page_size", False, True, 10) + + print(person_id, stage_id, subject_id, document_suffix, document_name, page_number, page_size) + + # 拼接查询SQL语句 + + select_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and person_id = '" + person_id + "'" + if stage_id != -1: + select_document_sql += " AND stage_id = " + str(stage_id) + if subject_id != -1: + select_document_sql += " AND subject_id = " + str(subject_id) + if document_suffix != "": + select_document_sql += " AND document_suffix = '" + document_suffix + "'" + if document_name != "": + select_document_sql += " AND document_name = '" + document_name + "'" + 
select_document_sql += " ORDER BY create_time DESC " + + # 查询文档列表 + page = await get_page_data_by_sql(select_document_sql, page_number, page_size) + for item in page["list"]: + theme_info = await find_by_id("t_ai_teaching_model_theme", "id", item["theme_id"]) + item["theme_info"] = theme_info + + return {"success": True, "message": "查询成功!", "data": page} + + +# 【Document-2】保存文档管理 +@router.post("/save") +async def save(request: Request, file: UploadFile = File(...)): + # 获取参数 + id = await get_request_num_param(request, "id", False, True, 0) + stage_id = await get_request_num_param(request, "stage_id", False, True, -1) + subject_id = await get_request_num_param(request, "subject_id", False, True, -1) + theme_id = await get_request_num_param(request, "theme_id", True, True, None) + person_id = await get_request_str_param(request, "person_id", True, True) + bureau_id = await get_request_str_param(request, "bureau_id", True, True) + # 先获取theme主题信息 + theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id) + if theme_object is None: + return {"success": False, "message": "主题不存在!"} + # 获取文件名 + document_name = file.filename + # 检查文件名在该主题下是否重复 + select_theme_document_sql: str = "SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and document_name = '" + document_name + "'" + if id != 0: + select_theme_document_sql += " AND id <> " + id + theme_document = await find_by_sql(select_theme_document_sql, ()) + if theme_document is not None: + return {"success": False, "message": "该主题下文档名称重复!"} + # 获取文件扩展名 + document_suffix = file.filename.split(".")[-1] + # 检查文件扩展名 + if document_suffix not in supported_suffix_types: + return {"success": False, "message": "不支持的文件类型!"} + # 构造文件保存路径 + document_dir = UPLOAD_DIR + os.sep + str(theme_object["short_name"]) + "_" + str(theme_object["id"]) + os.sep + if not os.path.exists(document_dir): + os.makedirs(document_dir) + document_path = os.path.join(document_dir, file.filename) + # 保存文件 + try: + with open(document_path, "wb") as buffer: + buffer.write(await file.read()) + except Exception as e: + return {"success": False, "message": f"文件保存失败!{e}"} + + # 构造保存文档SQL语句 + param = {"stage_id": stage_id, "subject_id": subject_id, "document_name": document_name, "theme_id": theme_id, "document_path": document_path, "document_suffix": document_suffix, "person_id": person_id, "bureau_id": bureau_id} + + # 保存数据 + if id == 0: + param["train_flag"] = 0 + # 插入数据 + id = await insert("t_ai_teaching_model_document", param, False) + return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}} + else: + # 更新数据 + await update("t_ai_teaching_model_document", param, "id", id) + return {"success": True, "message": "更新成功!", "data": {"update_id" : id}} + +# 【Document-3】获取文档信息 +@router.get("/get") +async def get(request: Request): + # 获取参数 + id = await get_request_num_param(request, "id", True, True, None) + # 查询数据 + document_object = await find_by_id("t_ai_teaching_model_document", "id", id) + if document_object is None: + return {"success": False, "message": "未查询到该文档信息!"} + theme_info = await find_by_id("t_ai_teaching_model_theme", "id", document_object["theme_id"]) + document_object["theme_info"] = theme_info + return {"success": True, "message": "查询成功!", "data": {"document": document_object}} + + +@router.post("/delete") +async def delete(request: Request): + # 获取参数 + id = await get_request_num_param(request, "id", True, True, None) + result = await delete_by_id("t_ai_teaching_model_document", "id", id) + if not result: + return {"success": 
False, "message": "删除失败!"} + return {"success": True, "message": "删除成功!"} \ No newline at end of file diff --git a/dsAiTeachingModel/api/controller/LoginController.py b/dsAiTeachingModel/api/controller/LoginController.py index 4a004d45..b3368d35 100644 --- a/dsAiTeachingModel/api/controller/LoginController.py +++ b/dsAiTeachingModel/api/controller/LoginController.py @@ -13,7 +13,7 @@ from utils.CookieUtil import * from utils.Database import * from utils.JwtUtil import * from utils.ParseRequest import * -from config.Config import * +from Config.Config import * # 创建一个路由实例 router = APIRouter() @@ -108,8 +108,9 @@ async def login(request: Request, response: Response): return {"success": False, "message": "用户名和密码不能为空"} password = md5_encrypt(password) - select_user_sql: str = "SELECT person_id, person_name, identity_id, login_name, xb, bureau_id, org_id, pwdmd5 FROM t_sys_loginperson WHERE login_name = %s AND b_use = 1" - user = await find_one_by_sql(select_user_sql, (username,)) + select_user_sql: str = "SELECT person_id, person_name, identity_id, login_name, xb, bureau_id, org_id, pwdmd5 FROM t_sys_loginperson WHERE login_name = '" + username + "' AND b_use = 1" + userlist = await find_by_sql(select_user_sql,()) + user = userlist[0] if userlist else None logging.info(f"查询结果: {user}") if user and user['pwdmd5'] == password: # 验证的cas用户密码,md5加密的版本 token = create_access_token({"user_id": user['person_id'], "identity_id": user['identity_id']}) @@ -128,3 +129,31 @@ async def login(request: Request, response: Response): else: return {"success": False, "message": "用户名或密码错误"} + +# 【Base-Login-3】通过手机号获取Person的ID +@router.get("/getPersonIdByTelephone") +async def get_person_id_by_telephone(request: Request): + telephone = await get_request_str_param(request, "telephone", True, True) + if not telephone: + return {"success": False, "message": "手机号不能为空"} + select_user_sql: str = "SELECT person_id FROM t_sys_loginperson WHERE telephone = '" + telephone + "' and b_use = 1 " + userlist = await find_by_sql(select_user_sql,()) + user = userlist[0] if userlist else None + if user: + return {"success": True, "message": "查询成功", "data": {"person_id": user['person_id']}} + else: + return {"success": False, "message": "未查询到相关信息"} + + + +# 【Base-Login-4】忘记密码重设,不登录的状态 +@router.post("/resetPassword") +async def reset_password(request: Request): + person_id = await get_request_str_param(request, "person_id", True, True) + password = await get_request_str_param(request, "password", True, True) + if not person_id or not password: + return {"success": False, "message": "用户ID和新密码不能为空"} + password_md5 = md5_encrypt(password) + update_user_sql: str = "UPDATE t_sys_loginperson SET original_pwd = '" + password + "', pwdmd5 = '" + password_md5 + "' WHERE person_id = '" + person_id + "'" + await execute_sql(update_user_sql) + return {"success": True, "message": "密码修改成功"} \ No newline at end of file diff --git a/dsAiTeachingModel/api/controller/QuestionController.py b/dsAiTeachingModel/api/controller/QuestionController.py index 89456bc5..48b7ed39 100644 --- a/dsAiTeachingModel/api/controller/QuestionController.py +++ b/dsAiTeachingModel/api/controller/QuestionController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/QuestionController.py from fastapi import APIRouter, Request, Response, Depends from auth.dependencies import * diff --git a/dsAiTeachingModel/api/controller/TestController.py b/dsAiTeachingModel/api/controller/TestController.py index 4a572ff3..5c6a8ed5 100644 --- 
a/dsAiTeachingModel/api/controller/TestController.py +++ b/dsAiTeachingModel/api/controller/TestController.py @@ -1,4 +1,4 @@ -# routes/LoginController.py +# routes/TestController.py from fastapi import APIRouter, Request diff --git a/dsAiTeachingModel/api/controller/ThemeController.py b/dsAiTeachingModel/api/controller/ThemeController.py index 26903275..3bd9fcd5 100644 --- a/dsAiTeachingModel/api/controller/ThemeController.py +++ b/dsAiTeachingModel/api/controller/ThemeController.py @@ -1,14 +1,17 @@ -# routes/LoginController.py +# routes/ThemeController.py from fastapi import APIRouter, Depends from utils.ParseRequest import * from auth.dependencies import * -from utils.Database import * +from utils.PageUtil import * # 创建一个路由实例,需要依赖get_current_user,登录后才能访问 router = APIRouter(dependencies=[Depends(get_current_user)]) - +# 功能:【Theme-1】主题管理列表 +# 作者:Kalman.CHENG ☆ +# 时间:2025-07-14 +# 备注: @router.get("/list") async def list(request: Request): # 获取参数 @@ -24,9 +27,9 @@ async def list(request: Request): # 拼接查询SQL语句 select_theme_sql: str = " SELECT * FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and person_id = '" + person_id + "'" if stage_id != -1: - select_theme_sql += " and stage_id = " + stage_id + select_theme_sql += " and stage_id = " + str(stage_id) if subject_id != -1: - select_theme_sql += " and subject_id = " + subject_id + select_theme_sql += " and subject_id = " + str(subject_id) if theme_name != "": select_theme_sql += " and theme_name = '" + theme_name + "'" select_theme_sql += " ORDER BY create_time DESC" @@ -37,16 +40,76 @@ async def list(request: Request): return {"success": True, "message": "查询成功!", "data": page} +# 功能:【Theme-2】保存主题管理 +# 作者:Kalman.CHENG ☆ +# 时间:2025-07-14 +# 备注: @router.post("/save") async def save(request: Request): # 获取参数 id = await get_request_num_param(request, "id", False, True, 0) theme_name = await get_request_str_param(request, "theme_name", True, True) + short_name = await get_request_str_param(request, "short_name", True, True) theme_icon = await get_request_str_param(request, "theme_icon", False, True) stage_id = await get_request_num_param(request, "stage_id", True, True, None) subject_id = await get_request_num_param(request, "subject_id", True, True, None) person_id = await get_request_str_param(request, "person_id", True, True) bureau_id = await get_request_str_param(request, "bureau_id", True, True) - # 业务逻辑处理 + + # 校验参数 + check_theme_sql = "SELECT theme_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and theme_name = '" + theme_name + "'" + if id != 0: + check_theme_sql += " and id <> " + id + print(check_theme_sql) + check_theme_result = await find_by_sql(check_theme_sql,()) + if check_theme_result: + return {"success": False, "message": "该主题名称已存在!"} + + check_short_name_sql = "SELECT short_name FROM t_ai_teaching_model_theme WHERE is_deleted = 0 and bureau_id = '" + bureau_id + "' and short_name = '" + short_name + "'" + if id != 0: + check_short_name_sql += " and id <> " + id + print(check_short_name_sql) + check_short_name_result = await find_by_sql(check_short_name_sql,()) + if check_short_name_result: + return {"success": False, "message": "该主题英文简称已存在!"} + + # 组装参数 + param = {"theme_name": theme_name,"short_name": short_name,"theme_icon": theme_icon,"stage_id": stage_id,"subject_id": subject_id,"person_id": person_id,"bureau_id": bureau_id} + + # 保存数据 + if id == 0: + param["search_flag"] = 0 + param["train_flag"] = 0 + # 插入数据 + id = await insert("t_ai_teaching_model_theme", param, False) 
+ return {"success": True, "message": "保存成功!", "data": {"insert_id" : id}} + else: + # 更新数据 + await update("t_ai_teaching_model_theme", param, "id", id, False) + return {"success": True, "message": "更新成功!", "data": {"update_id" : id}} + + +# 功能:【Theme-3】获取主题信息 +# 作者:Kalman.CHENG ☆ +# 时间:2025-07-14 +# 备注: +@router.get("/get") +async def get(request: Request): + # 获取参数 + id = await get_request_num_param(request, "id", True, True, None) + theme_obj = await find_by_id("t_ai_teaching_model_theme", "id", id) + if theme_obj is None: + return {"success": False, "message": "未查询到该主题信息!"} + return {"success": True, "message": "查询成功!", "data": {"theme": theme_obj}} + + +@router.post("/delete") +async def delete(request: Request): + # 获取参数 + id = await get_request_num_param(request, "id", True, True, None) + result = await delete_by_id("t_ai_teaching_model_theme", "id", id) + if not result: + return {"success": False, "message": "删除失败!"} + return {"success": True, "message": "删除成功!"} diff --git a/dsAiTeachingModel/api/controller/UserController.py b/dsAiTeachingModel/api/controller/UserController.py new file mode 100644 index 00000000..2b1d1a3f --- /dev/null +++ b/dsAiTeachingModel/api/controller/UserController.py @@ -0,0 +1,50 @@ +# routes/UserController.py +import re + +from fastapi import APIRouter, Request, Response, Depends +from auth.dependencies import * +from utils.CommonUtil import md5_encrypt +from utils.Database import * +from utils.ParseRequest import * + +# 创建一个路由实例,需要依赖get_current_user,登录后才能访问 +router = APIRouter(dependencies=[Depends(get_current_user)]) + +# 【Base-User-1】维护用户手机号 +@router.post("/modifyTelephone") +async def modify_telephone(request: Request): + person_id = await get_request_str_param(request, "person_id", True, True) + telephone = await get_request_str_param(request, "telephone", True, True) + # 校验手机号码格式 + if not re.match(r"^1[3-9]\d{9}$", telephone): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号码格式错误") + # 校验手机号码是否已被注册 + select_telephone_sql: str = "select * from t_sys_loginperson where b_use = 1 and telephone = '" + telephone + "' and person_id <> '" + person_id + "'" + userlist = await find_by_sql(select_telephone_sql, ()) + if userlist is not None: + return {"success": False, "message": "手机号码已被注册"} + else: + update_telephone_sql: str = "update t_sys_loginperson set telephone = '" + telephone + "' where person_id = '" + person_id + "'" + await execute_sql(update_telephone_sql) + return {"success": True, "message": "修改成功"} + + +# 【Base-User-2】维护用户密码 +@router.post("/modifyPassword") +async def modify_password(request: Request): + person_id = await get_request_str_param(request, "person_id", True, True) + old_password = await get_request_str_param(request, "old_password", True, True) + password = await get_request_str_param(request, "password", True, True) + # 校验旧密码是否正确 + select_password_sql: str = "select pwdmd5 from t_sys_loginperson where person_id = '" + person_id + "' and b_use = 1" + userlist = await find_by_sql(select_password_sql, ()) + if len(userlist) == 0: + return {"success": False, "message": "用户不存在"} + else: + if userlist[0]["pwdmd5"] != md5_encrypt(old_password): + return {"success": False, "message": "旧密码错误"} + else: + update_password_sql: str = "update t_sys_loginperson set original_pwd = '" + password + "',pwdmd5 = '" + md5_encrypt(password) + "' where person_id = '" + person_id + "'" + await execute_sql(update_password_sql) + return {"success": True, "message": "修改成功"} + diff --git a/dsAiTeachingModel/config/Config.py 
b/dsAiTeachingModel/config/Config.py index 3d4a460b..1b4ca3f3 100644 --- a/dsAiTeachingModel/config/Config.py +++ b/dsAiTeachingModel/config/Config.py @@ -1,13 +1,18 @@ -# 大模型 【DeepSeek深度求索官方】 -#LLM_API_KEY = "sk-44ae895eeb614aa1a9c6460579e322f1" -#LLM_BASE_URL = "https://api.deepseek.com" -#LLM_MODEL_NAME = "deepseek-chat" +# 阿里云的配置信息 +ALY_AK = 'LTAI5tE4tgpGcKWhbZg6C4bh' +ALY_SK = 'oizcTOZ8izbGUouboC00RcmGE8vBQ1' -# 阿里云提供的大模型服务 -LLM_API_KEY="sk-f6da0c787eff4b0389e4ad03a35a911f" +# 大模型 【DeepSeek深度求索官方】训练时用这个 +# LLM_API_KEY = "sk-44ae895eeb614aa1a9c6460579e322f1" +# LLM_BASE_URL = "https://api.deepseek.com" +# LLM_MODEL_NAME = "deepseek-chat" + +# 阿里云提供的大模型服务 【阿里云在处理文字材料时,容易引发绿网拦截,导致数据上报异常】 +LLM_API_KEY = "sk-f6da0c787eff4b0389e4ad03a35a911f" LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" #LLM_MODEL_NAME = "qwen-plus" # 不要使用通义千问,会导致化学方程式不正确! LLM_MODEL_NAME = "deepseek-v3" +#LLM_MODEL_NAME = "deepseek-r1" # 使用更牛B的r1模型 EMBED_MODEL_NAME = "BAAI/bge-m3" EMBED_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl" @@ -15,21 +20,20 @@ EMBED_BASE_URL = "https://api.siliconflow.cn/v1" EMBED_DIM = 1024 EMBED_MAX_TOKEN_SIZE = 8192 - NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "DsideaL147258369" +NEO4J_AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD) - -# MYSQL配置信息 -MYSQL_HOST = "127.0.0.1" -MYSQL_PORT = 22066 -MYSQL_USER = "root" -MYSQL_PASSWORD = "DsideaL147258369" -MYSQL_DB_NAME = "base_db" -MYSQL_POOL_SIZE = 200 +# POSTGRESQL配置信息 +AGE_GRAPH_NAME = "dickens" +POSTGRES_HOST = "10.10.14.208" +POSTGRES_PORT = 5432 +POSTGRES_USER = "postgres" +POSTGRES_PASSWORD = "postgres" +POSTGRES_DATABASE = "rag" # JWT配置信息 JWT_SECRET_KEY = "ZXZnZWVr5b+r5LmQ5L2g55qE5Ye66KGM" ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 300000 # 访问令牌过期时间(分钟) +ACCESS_TOKEN_EXPIRE_MINUTES = 300000 # 访问令牌过期时间(分钟) \ No newline at end of file diff --git a/dsAiTeachingModel/main.py b/dsAiTeachingModel/main.py index 18a578f1..8f99a901 100644 --- a/dsAiTeachingModel/main.py +++ b/dsAiTeachingModel/main.py @@ -1,6 +1,7 @@ import threading -import logging + import uvicorn +import asyncio from fastapi.middleware.cors import CORSMiddleware from starlette.staticfiles import StaticFiles @@ -18,11 +19,12 @@ logging.basicConfig( ) async def lifespan(app: FastAPI): - # 启动线程 - thread = threading.Thread(target=train_document_task, daemon=True) - thread.start() # 创建数据库连接池 await init_database() + + # 启动异步任务 + asyncio.create_task(train_document_task()) + yield await shutdown_database() @@ -41,8 +43,10 @@ app.add_middleware( app.mount("/static", StaticFiles(directory="Static"), name="static") # 注册路由 -# 登录相关 +# 登录相关(不用登录) app.include_router(login_router, prefix="/api/login", tags=["login"]) +# 用户相关 +app.include_router(user_router, prefix="/api/user", tags=["user"]) # 主题相关 app.include_router(theme_router, prefix="/api/theme", tags=["theme"]) # 文档相关 diff --git a/dsAiTeachingModel/routes/__init__.py b/dsAiTeachingModel/routes/__init__.py index 5bde8674..4fa720b9 100644 --- a/dsAiTeachingModel/routes/__init__.py +++ b/dsAiTeachingModel/routes/__init__.py @@ -5,6 +5,7 @@ from api.controller.ThemeController import router as theme_router from api.controller.QuestionController import router as question_router from api.controller.TestController import router as test_router from api.controller.DmController import router as dm_router +from api.controller.UserController import router as user_router # 导出所有路由 -__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", 
"test_router"] +__all__ = ["login_router", "document_router", "theme_router", "question_router", "dm_router", "test_router", "user_router"] diff --git a/dsAiTeachingModel/tasks/BackgroundTasks.py b/dsAiTeachingModel/tasks/BackgroundTasks.py index e90bdc52..d43dc190 100644 --- a/dsAiTeachingModel/tasks/BackgroundTasks.py +++ b/dsAiTeachingModel/tasks/BackgroundTasks.py @@ -1,12 +1,52 @@ +import asyncio import logging import time +from utils.Database import * +from utils.DocxUtil import get_docx_content_by_pandoc +from utils.LightRagUtil import initialize_pg_rag + +# 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。 +WORKING_DIR = f"./output" + # 后台任务,监控是否有新的未训练的文档进行训练 -def train_document_task(): +async def train_document_task(): print("线程5秒后开始运行【监控是否有新的未训练的文档进行训练】") - time.sleep(5) # 线程5秒后开始运行 + await asyncio.sleep(5) # 使用 asyncio.sleep 而不是 time.sleep # 这里放置你的线程逻辑 while True: # 这里可以放置你的线程要执行的代码 - logging.info("线程正在运行") - time.sleep(1000) # 每隔10秒执行一次 + logging.info("开始查询是否有未训练的文档") + no_train_document_sql: str = " SELECT * FROM t_ai_teaching_model_document WHERE is_deleted = 0 and train_flag = 0 ORDER BY create_time DESC" + no_train_document_result = await find_by_sql(no_train_document_sql, ()) + if not no_train_document_result: + logging.info("没有未训练的文档") + else: + logging.info("存在未训练的文档" + str(len(no_train_document_result))+"个") + # document = no_train_document_result[0] + # print("开始训练文档:" + document["document_name"]) + # theme = await find_by_id("t_ai_teaching_model_theme", "id", document["theme_id"]) + # # 训练开始前,更新训练状态 + # update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 1 WHERE id = " + str(document["id"]) + # execute_sql(update_sql) + # document_name = document["document_name"] + "." + document["document_suffix"] + # logging.info("开始训练文档:" + document_name) + # workspace = theme["short_name"] + # docx_name = document_name + # docx_path = document["document_path"] + # logging.info(f"开始处理文档:{docx_name}, 还有%s个文档需要处理!", len(no_train_document_result) - 1) + # # 训练代码开始 + # try: + # rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + # # 获取docx文件的内容 + # content = get_docx_content_by_pandoc(docx_path) + # await rag.insert(input=content, file_paths=[docx_name]) + # finally: + # if rag: + # await rag.finalize_storages() + # # 训练结束,更新训练状态 + # update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"]) + # execute_sql(update_sql) + + # 添加适当的等待时间,避免频繁查询 + await asyncio.sleep(60) # 每分钟查询一次 diff --git a/dsAiTeachingModel/utils/Database.py b/dsAiTeachingModel/utils/Database.py index db606938..4ac15243 100644 --- a/dsAiTeachingModel/utils/Database.py +++ b/dsAiTeachingModel/utils/Database.py @@ -1,25 +1,23 @@ # Database.py +import datetime import logging -import math +import asyncpg -import aiomysql -import asyncio -from config.Config import * +from Config.Config import * # 创建一个全局的连接池 pool = None -async def create_pool(loop): +async def create_pool(): global pool - pool = await aiomysql.create_pool( - host=MYSQL_HOST, - port=MYSQL_PORT, - user=MYSQL_USER, - password=MYSQL_PASSWORD, - db=MYSQL_DB_NAME, - minsize=1, # 设置连接池最小连接数 - maxsize=MYSQL_POOL_SIZE, # 设置连接池最大连接数 - cursorclass=aiomysql.DictCursor # 指定游标为字典模式 + pool = await asyncpg.create_pool( + host=POSTGRES_HOST, + port=POSTGRES_PORT, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + database=POSTGRES_DATABASE, + min_size=1, # 设置连接池最小连接数 + max_size=100 # 设置连接池最大连接数 ) async def get_connection(): @@ -30,18 +28,17 @@ async def get_connection(): async def 
close_pool(): if pool is not None: - pool.close() - await pool.wait_closed() + await pool.close() # 初始化连接池的函数 async def init_database(): - loop = asyncio.get_event_loop() - await create_pool(loop) + await create_pool() # 关闭连接池的函数 async def shutdown_database(): await close_pool() + # 根据sql语句查询数据 async def find_by_sql(sql: str, params: tuple): if pool is None: @@ -49,79 +46,174 @@ async def find_by_sql(sql: str, params: tuple): return None try: async with pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute(sql, params) - result = await cur.fetchall() - if result: - return result - else: - return None + result = await conn.fetch(sql, *params) + # 将 asyncpg.Record 转换为字典 + result_dict = [dict(record) for record in result] + if result_dict: + return result_dict + else: + return None except Exception as e: logging.error(f"数据库查询错误: {e}") return None +# 插入数据 +async def insert(tableName, param, onlyForParam=False): + current_time = datetime.datetime.now() + columns = [] + values = [] + placeholders = [] + + for key, value in param.items(): + if value is not None: + if isinstance(value, (int, float)): + columns.append(key) + values.append(value) + placeholders.append(f"${len(values)}") + elif isinstance(value, str): + columns.append(key) + values.append(value) + placeholders.append(f"${len(values)}") + else: + columns.append(key) + values.append(None) + placeholders.append("NULL") + + if not onlyForParam: + if 'is_deleted' not in param: + columns.append("is_deleted") + values.append(0) + placeholders.append(f"${len(values)}") + + if 'create_time' not in param: + columns.append("create_time") + values.append(current_time) + placeholders.append(f"${len(values)}") + + if 'update_time' not in param: + columns.append("update_time") + values.append(current_time) + placeholders.append(f"${len(values)}") + + # 构造 SQL 语句 + column_names = ", ".join(columns) + placeholder_names = ", ".join(placeholders) + sql = f"INSERT INTO {tableName} ({column_names}) VALUES ({placeholder_names}) RETURNING id" -# 根据sql语句查询数据 -async def find_one_by_sql(sql: str, params: tuple): - if pool is None: - logging.error("数据库连接池未创建") - return None try: async with pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute(sql, params) - result = await cur.fetchone() + result = await conn.fetchrow(sql, *values) if result: - return result + return result['id'] else: + logging.error("插入数据失败: 未返回ID") return None except Exception as e: logging.error(f"数据库查询错误: {e}") - return None + logging.error(f"执行的SQL语句: {sql}") + logging.error(f"参数: {values}") + raise Exception(f"为表[{tableName}]插入数据失败: {e}") + + +# 更新数据 +async def update(table_name, param, property_name, property_value, only_for_param=False): + current_time = datetime.datetime.now() + set_clauses = [] + values = [] + + # 处理要更新的参数 + for key, value in param.items(): + if value is not None: + if isinstance(value, (int, float)): + set_clauses.append(f"{key} = ${len(values) + 1}") + values.append(value) + elif isinstance(value, str): + set_clauses.append(f"{key} = ${len(values) + 1}") + values.append(value) + else: + set_clauses.append(f"{key} = NULL") + values.append(None) + + if not only_for_param: + if 'update_time' not in param: + set_clauses.append(f"update_time = ${len(values) + 1}") + values.append(current_time) -# 查询数据条数 -async def get_total_data_count(total_data_sql): - total_data_count = 0 - total_data_count_sql = "select count(1) as count from (" + total_data_sql + ") as temp_table" - result = await 
find_one_by_sql(total_data_count_sql, ()) - if result: - total_data_count = result.get("count") - return total_data_count + # 构造 SQL 语句 + set_clause = ", ".join(set_clauses) + sql = f"UPDATE {table_name} SET {set_clause} WHERE {property_name} = ${len(values) + 1} RETURNING id" + print(sql) + + # 添加条件参数 + values.append(property_value) + + try: + async with pool.acquire() as conn: + result = await conn.fetchrow(sql, *values) + if result: + return result['id'] + else: + logging.error("更新数据失败: 未返回ID") + return None + except Exception as e: + logging.error(f"数据库查询错误: {e}") + logging.error(f"执行的SQL语句: {sql}") + logging.error(f"参数: {values}") + raise Exception(f"为表[{table_name}]更新数据失败: {e}") -def get_page_by_total_row(total_data_count, page_number, page_size): - total_page = (page_size != 0) and math.floor((total_data_count + page_size - 1) / page_size) or 0 - if page_number <= 0: - page_number = 1 - if 0 < total_page < page_number: - page_number = total_page - offset = page_size * page_number - page_size - limit = page_size - return total_data_count, total_page, offset, limit +# 获取Bean +# 通过主键查询 +async def find_by_id(table_name, property_name, property_value): + if table_name and property_name and property_value is not None: + # 构造 SQL 语句 + sql = f"SELECT * FROM {table_name} WHERE is_deleted = 0 AND {property_name} = $1" + logging.debug(sql) -async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int): - if pool is None: - logging.error("数据库连接池未创建") - return None - total_row: int = 0 - total_page: int = 0 - total_data_sql = total_data_sql.replace(";", "") - total_data_sql = total_data_sql.replace(" FROM ", " from ") - - # 查询总数 - total_data_count = await get_total_data_count(total_data_sql) - if total_data_count == 0: - return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []} + # 执行查询 + result = await find_by_sql(sql, (property_value,)) + if not result: + logging.error("查询失败: 未找到数据") + return None + # 返回第一条数据 + return result[0] else: - total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, page_number, page_size) - - # 构造执行分页查询的sql语句 - page_data_sql = total_data_sql + " LIMIT %d, %d " % (offset, limit) - print(page_data_sql) - # 执行分页查询 - page_data = await find_by_sql(page_data_sql, ()) - if page_data: - return {"page_number": page_number, "page_size": page_size, "total_row": total_row, "total_page": total_page, "list": page_data} + logging.error("参数不全") + return None + +# 通过主键删除 +# 逻辑删除 +async def delete_by_id(table_name, property_name, property_value): + if table_name and property_name and property_value is not None: + sql = f"UPDATE {table_name} SET is_deleted = 1, update_time = now() WHERE {property_name} = $1 and is_deleted = 0" + logging.debug(sql) + # 执行删除 + try: + async with pool.acquire() as conn: + result = await conn.execute(sql, property_value) + if result: + return True + else: + logging.error("删除失败: 未找到数据") + return False + except Exception as e: + logging.error(f"数据库查询错误: {e}") + logging.error(f"执行的SQL语句: {sql}") + logging.error(f"参数: {property_value}") + raise Exception(f"为表[{table_name}]删除数据失败: {e}") else: - return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []} + logging.error("参数不全") + return False + + +# 执行一个SQL语句 +async def execute_sql(sql): + logging.debug(sql) + try: + async with pool.acquire() as conn: + await conn.fetch(sql) + except Exception as e: + logging.error(f"数据库查询错误: {e}") + logging.error(f"执行的SQL语句: {sql}") + raise 
Exception(f"执行SQL失败: {e}") \ No newline at end of file diff --git a/dsAiTeachingModel/utils/DocxUtil.py b/dsAiTeachingModel/utils/DocxUtil.py index 82e26d2c..6c8051fc 100644 --- a/dsAiTeachingModel/utils/DocxUtil.py +++ b/dsAiTeachingModel/utils/DocxUtil.py @@ -1,8 +1,56 @@ +import logging import os import subprocess import uuid +from PIL import Image +import os + +# 在程序开始时添加以下配置 +logging.basicConfig( + level=logging.INFO, # 设置日志级别为INFO + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# 或者如果你想更详细地控制日志输出 +logger = logging.getLogger('DocxUtil') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +logger.addHandler(handler) +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +def resize_images_in_directory(directory_path, max_width=640, max_height=480): + """ + 遍历目录下所有图片并缩放到指定尺寸 + :param directory_path: 图片目录路径 + :param max_width: 最大宽度 + :param max_height: 最大高度 + """ + # 支持的图片格式 + valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif') + + for root, _, files in os.walk(directory_path): + for filename in files: + if filename.lower().endswith(valid_extensions): + file_path = os.path.join(root, filename) + try: + with Image.open(file_path) as img: + # 计算缩放比例 + width, height = img.size + ratio = min(max_width / width, max_height / height) + # 如果图片已经小于目标尺寸,则跳过 + if ratio >= 1: + continue + # 计算新尺寸并缩放 + new_size = (int(width * ratio), int(height * ratio)) + resized_img = img.resize(new_size, Image.Resampling.LANCZOS) + + # 保存图片(覆盖原文件) + resized_img.save(file_path) + logger.info(f"已缩放: {file_path} -> {new_size}") + except Exception as e: + logger.error(f"处理 {file_path} 时出错: {str(e)}") def get_docx_content_by_pandoc(docx_file): # 最后拼接的内容 content = "" @@ -15,6 +63,9 @@ def get_docx_content_by_pandoc(docx_file): os.mkdir("./static/Images/" + file_name) subprocess.run(['pandoc', docx_file, '-f', 'docx', '-t', 'markdown', '-o', temp_markdown, '--extract-media=./static/Images/' + file_name]) + # 遍历目录 './static/Images/'+file_name 下所有的图片,缩小于640*480的尺寸上 + + resize_images_in_directory('./static/Images/' + file_name+'/media') # 读取然后修改内容,输出到新的文件 img_idx = 0 # 图片索引 with open(temp_markdown, 'r', encoding='utf-8') as f: @@ -23,8 +74,9 @@ def get_docx_content_by_pandoc(docx_file): if not line: continue # 跳过图片高度描述行 - if line.startswith('height=') and line.endswith('in"}'): + if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')): continue + # height="1.91044072615923in" # 使用find()方法安全地检查图片模式 is_img = line.find("![](") >= 0 and ( line.find(".png") > 0 or diff --git a/dsAiTeachingModel/utils/JwtUtil.py b/dsAiTeachingModel/utils/JwtUtil.py index 4118a695..90b30808 100644 --- a/dsAiTeachingModel/utils/JwtUtil.py +++ b/dsAiTeachingModel/utils/JwtUtil.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta from jose import JWTError, jwt -from config.Config import * +from Config.Config import * def create_access_token(data: dict): diff --git a/dsAiTeachingModel/utils/LightRagUtil.py b/dsAiTeachingModel/utils/LightRagUtil.py index 4b038c1d..e791c4a8 100644 --- a/dsAiTeachingModel/utils/LightRagUtil.py +++ b/dsAiTeachingModel/utils/LightRagUtil.py @@ -1,14 +1,12 @@ import logging import logging.config import os - import numpy as np - from lightrag import LightRAG from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import 
EmbeddingFunc, logger, set_verbose_debug -from config.Config import * +from Config.Config import * async def print_stream(stream): @@ -25,7 +23,7 @@ def configure_logging(): log_dir = os.getenv("LOG_DIR", os.getcwd()) log_file_path = os.path.abspath( - os.path.join(log_dir, "./logs/lightrag.log") + os.path.join(log_dir, "./Logs/lightrag.log") ) print(f"\nLightRAG log file: {log_file_path}\n") @@ -97,10 +95,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray: ) -async def initialize_rag(working_dir): +async def initialize_rag(working_dir, graph_storage=None): + if graph_storage is None: + graph_storage = 'NetworkXStorage' rag = LightRAG( working_dir=working_dir, llm_model_func=llm_model_func, + graph_storage=graph_storage, embedding_func=EmbeddingFunc( embedding_dim=EMBED_DIM, max_token_size=EMBED_MAX_TOKEN_SIZE, @@ -139,4 +140,40 @@ def create_embedding_func(): api_key=EMBED_API_KEY, base_url=EMBED_BASE_URL, ), - ) \ No newline at end of file + ) + + +# AGE +os.environ["AGE_GRAPH_NAME"] = AGE_GRAPH_NAME +os.environ["POSTGRES_HOST"] = POSTGRES_HOST +os.environ["POSTGRES_PORT"] = str(POSTGRES_PORT) +os.environ["POSTGRES_USER"] = POSTGRES_USER +os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD +os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE + + +async def initialize_pg_rag(WORKING_DIR, workspace='default'): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + llm_model_name=LLM_MODEL_NAME, + llm_model_max_async=4, + llm_model_max_token_size=32768, + enable_llm_cache_for_entity_extract=True, + embedding_func=EmbeddingFunc( + embedding_dim=EMBED_DIM, + max_token_size=EMBED_MAX_TOKEN_SIZE, + func=embedding_func + ), + kv_storage="PGKVStorage", + doc_status_storage="PGDocStatusStorage", + graph_storage="PGGraphStorage", + vector_storage="PGVectorStorage", + auto_manage_storages_states=False, + vector_db_storage_cls_kwargs={"workspace": workspace} + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag diff --git a/dsAiTeachingModel/utils/PageUtil.py b/dsAiTeachingModel/utils/PageUtil.py new file mode 100644 index 00000000..ae84601e --- /dev/null +++ b/dsAiTeachingModel/utils/PageUtil.py @@ -0,0 +1,48 @@ +import math +from utils.Database import * + + +# 查询数据条数 +async def get_total_data_count(total_data_sql): + total_data_count = 0 + total_data_count_sql = "select count(*) as num from (" + total_data_sql + ") as temp_table" + result = await find_by_sql(total_data_count_sql,()) + row = result[0] if result else None + if row: + total_data_count = row.get("num") + return total_data_count + + +def get_page_by_total_row(total_data_count, page_number, page_size): + total_page = (page_size != 0) and math.floor((total_data_count + page_size - 1) / page_size) or 0 + if page_number <= 0: + page_number = 1 + if 0 < total_page < page_number: + page_number = total_page + offset = page_size * page_number - page_size + limit = page_size + return total_data_count, total_page, offset, limit + + +async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int): + total_row: int = 0 + total_page: int = 0 + total_data_sql = total_data_sql.replace(";", "") + total_data_sql = total_data_sql.replace(" FROM ", " from ") + + # 查询总数 + total_data_count = await get_total_data_count(total_data_sql) + if total_data_count == 0: + return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []} + else: + total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, 
page_number, page_size) + + # 构造执行分页查询的sql语句 + page_data_sql = total_data_sql + " LIMIT %d offset %d " % (limit, offset) + print(page_data_sql) + # 执行分页查询 + page_data = await find_by_sql(page_data_sql, ()) + if page_data: + return {"page_number": page_number, "page_size": page_size, "total_row": total_row, "total_page": total_page, "list": page_data} + else: + return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []} diff --git a/dsLightRag/.idea/dsLightRag.iml b/dsLightRag/.idea/dsLightRag.iml index 4ceb6f94..880d61c1 100644 --- a/dsLightRag/.idea/dsLightRag.iml +++ b/dsLightRag/.idea/dsLightRag.iml @@ -2,7 +2,7 @@ - + diff --git a/dsLightRag/.idea/misc.xml b/dsLightRag/.idea/misc.xml index 0bad5868..0f9b3bc1 100644 --- a/dsLightRag/.idea/misc.xml +++ b/dsLightRag/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/dsLightRag/Config/__pycache__/Config.cpython-310.pyc b/dsLightRag/Config/__pycache__/Config.cpython-310.pyc index 76e8086d..d08d90df 100644 Binary files a/dsLightRag/Config/__pycache__/Config.cpython-310.pyc and b/dsLightRag/Config/__pycache__/Config.cpython-310.pyc differ diff --git a/dsLightRag/Doc/2、Conda维护.txt b/dsLightRag/Doc/2、Conda维护.txt index 021764bf..80bfef40 100644 --- a/dsLightRag/Doc/2、Conda维护.txt +++ b/dsLightRag/Doc/2、Conda维护.txt @@ -5,7 +5,7 @@ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/f conda config --set show_channel_urls yes # 创建虚拟环境 -conda create -n rag python=3.10 +conda create -n py310 python=3.10 # 查看当前存在哪些虚拟环境 conda env list @@ -15,16 +15,16 @@ conda info -e conda list # 激活虚拟环境 -conda activate rag +conda activate py310 # 对虚拟环境中安装额外的包 -conda install -n rag $package_name +conda install -n py310 $package_name # 删除虚拟环境 -conda remove -n rag --all +conda remove -n py310 --all # 删除环境中的某个包 -conda remove --name rag $package_name +conda remove --name py310 $package_name # 恢复默认镜像 conda config --remove-key channels diff --git a/dsLightRag/Doc/9、Postgresql支持工作空间的代码修改/postgres_impl.py b/dsLightRag/Doc/9、Postgresql支持工作空间的代码修改/postgres_impl.py index f02ad79f..d18c6cc0 100644 --- a/dsLightRag/Doc/9、Postgresql支持工作空间的代码修改/postgres_impl.py +++ b/dsLightRag/Doc/9、Postgresql支持工作空间的代码修改/postgres_impl.py @@ -965,8 +965,8 @@ class PGDocStatusStorage(DocStatusStorage): else: exist_keys = [] new_keys = set([s for s in keys if s not in exist_keys]) - print(f"keys: {keys}") - print(f"new_keys: {new_keys}") + #print(f"keys: {keys}") + #print(f"new_keys: {new_keys}") return new_keys except Exception as e: logger.error( diff --git a/dsLightRag/Start.py b/dsLightRag/Start.py index 584315f5..26479237 100644 --- a/dsLightRag/Start.py +++ b/dsLightRag/Start.py @@ -17,13 +17,8 @@ from starlette.staticfiles import StaticFiles from Util.LightRagUtil import * from Util.PostgreSQLUtil import init_postgres_pool -# 在程序开始时添加以下配置 -logging.basicConfig( - level=logging.INFO, # 设置日志级别为INFO - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -# 或者如果你想更详细地控制日志输出 +# 想更详细地控制日志输出 logger = logging.getLogger('lightrag') logger.setLevel(logging.INFO) handler = logging.StreamHandler() @@ -300,5 +295,92 @@ async def render_html(request: fastapi.Request): } +@app.get("/api/sources") +async def get_sources(page: int = 1, limit: int = 10): + try: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: + # 获取总数 + total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_source") + # 获取分页数据 + offset = (page - 1) * limit + rows = await conn.fetch( + """ + SELECT id, 
account_id,account_name, created_at + FROM t_wechat_source + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + """, + limit, offset + ) + + sources = [ + { + "id": row[0], + "name": row[1], + "type": row[2], + "update_time": row[3].strftime("%Y-%m-%d %H:%M:%S") if row[3] else None + } + for row in rows + ] + + return { + "code": 0, + "data": { + "list": sources, + "total": total, + "page": page, + "limit": limit + } + } + except Exception as e: + return {"code": 1, "msg": str(e)} + + +@app.get("/api/articles") +async def get_articles(page: int = 1, limit: int = 10): + try: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: + # 获取总数 + total = await conn.fetchval("SELECT COUNT(*) FROM t_wechat_articles") + # 获取分页数据 + offset = (page - 1) * limit + rows = await conn.fetch( + """ + SELECT a.id, a.title, a.source as name, + a.publish_time, a.collection_time,a.url + FROM t_wechat_articles a + ORDER BY a.collection_time DESC + LIMIT $1 OFFSET $2 + """, + limit, offset + ) + + articles = [ + { + "id": row[0], + "title": row[1], + "source": row[2], + "publish_date": row[3].strftime("%Y-%m-%d") if row[3] else None, + "collect_time": row[4].strftime("%Y-%m-%d %H:%M:%S") if row[4] else None, + "url": row[5], + } + for row in rows + ] + + return { + "code": 0, + "data": { + "list": articles, + "total": total, + "page": page, + "limit": limit + } + } + except Exception as e: + return {"code": 1, "msg": str(e)} + + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/dsLightRag/T1_Train.py b/dsLightRag/T1_Train.py index 1db08183..88080efa 100644 --- a/dsLightRag/T1_Train.py +++ b/dsLightRag/T1_Train.py @@ -4,12 +4,6 @@ import logging from Util.DocxUtil import get_docx_content_by_pandoc from Util.LightRagUtil import initialize_pg_rag -# 在程序开始时添加以下配置 -logging.basicConfig( - level=logging.INFO, # 设置日志级别为INFO - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) - # 或者如果你想更详细地控制日志输出 logger = logging.getLogger('lightrag') logger.setLevel(logging.INFO) diff --git a/dsLightRag/Test/Test/Logs/article_bfc50bb7d7.html b/dsLightRag/Test/Test/Logs/article_bfc50bb7d7.html new file mode 100644 index 00000000..cd460649 --- /dev/null +++ b/dsLightRag/Test/Test/Logs/article_bfc50bb7d7.html @@ -0,0 +1,162 @@
[162 added lines: generated HTML markup for the captured article page; no readable text in this extract]
diff --git a/dsLightRag/Util/WxGzhUtil.py b/dsLightRag/Util/WxGzhUtil.py new file mode 100644 index 00000000..07abec69 --- /dev/null +++ b/dsLightRag/Util/WxGzhUtil.py @@ -0,0 +1,47 @@ +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.chrome.service import Service as ChromeService +from selenium.webdriver.common.by import By + +def init_wechat_browser(): + """初始化微信爬虫浏览器实例""" + options = Options() + options.add_argument('-headless') + service = ChromeService(executable_path=r"C:\Windows\System32\chromedriver.exe") + return webdriver.Chrome(service=service, options=options) + + +def get_article_content(url): + """ + 获取微信公众号文章内容 + :param url: 文章URL + :return: 文章内容文本 + """ + options = Options() + options.add_argument('-headless') + service = ChromeService(executable_path=r"C:\Windows\System32\chromedriver.exe") + driver = webdriver.Chrome(service=service, options=options) + + try: + driver.get(url) + html_content = driver.find_element(By.CLASS_NAME, "rich_media").text + + # 处理内容,提取空行后的文本 + lines = html_content.split('\n') + content_after_empty_line = "" + found_empty_line = False + + for line in lines: + if not found_empty_line and line.strip() == "": + found_empty_line = True + continue + + if found_empty_line: + content_after_empty_line += line + "\n" + + if not found_empty_line: + content_after_empty_line = html_content + + return content_after_empty_line.replace("\n\n", "\n") + finally: + driver.quit() \ No newline at end of file diff --git a/dsLightRag/Util/__pycache__/LightRagUtil.cpython-310.pyc b/dsLightRag/Util/__pycache__/LightRagUtil.cpython-310.pyc index 9bc1e54d..d89e8018 100644 Binary files a/dsLightRag/Util/__pycache__/LightRagUtil.cpython-310.pyc and b/dsLightRag/Util/__pycache__/LightRagUtil.cpython-310.pyc differ diff --git a/dsLightRag/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc b/dsLightRag/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc index aada6b85..a95beed5 100644 Binary files a/dsLightRag/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc and b/dsLightRag/Util/__pycache__/PostgreSQLUtil.cpython-310.pyc differ diff --git a/dsLightRag/Util/__pycache__/WxGzhUtil.cpython-310.pyc b/dsLightRag/Util/__pycache__/WxGzhUtil.cpython-310.pyc new file mode 100644 index 00000000..5075ff6c Binary files /dev/null and b/dsLightRag/Util/__pycache__/WxGzhUtil.cpython-310.pyc differ diff --git a/dsLightRag/WxGzh/T1_LoginGetCookie.py b/dsLightRag/WxGzh/T1_LoginGetCookie.py new file mode 100644 index 00000000..83eb3e0d --- /dev/null +++ b/dsLightRag/WxGzh/T1_LoginGetCookie.py @@ -0,0 +1,78 @@ +# 详解(一)Python + Selenium 批量采集微信公众号,搭建自己的微信公众号每日AI简报,告别信息焦虑 +# https://blog.csdn.net/k352733625/article/details/149222945 + +# 微信爬爬猫---公众号文章抓取代码分析 +# https://blog.csdn.net/yajuanpi4899/article/details/121584268 + +import json +import logging + +from torch.distributed.elastic.timer import expires + +""" +# 查看selenium版本 +pip show selenium +4.34.2 + +# 查看Chrome浏览器版本 +chrome://version/ +138.0.7204.101 (正式版本) (64 位) + +# 下载驱动包 +https://googlechromelabs.github.io/chrome-for-testing/ +https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.94/win64/chromedriver-win64.zip +""" +import time +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.chrome.service import Service as ChromeService + +if __name__ == '__main__': + # 定义一个空的字典,存放cookies内容 + cookies = {} + # 设置headers - 使用微信内置浏览器的User-Agent + header = { 
"HOST": "mp.weixin.qq.com", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2785.116 Safari/537.36 QBCore/4.0.1301.400 QQBrowser/9.0.2524.400 Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2875.116 Safari/537.36 NetType/WIFI MicroMessenger/7.0.20.1781(0x6700143B) WindowsWechat(0x63010200)", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + "Accept-Encoding": "gzip, deflate, br", + "Accept-Language": "zh-CN,zh;q=0.8,en-US;q=0.6,en;q=0.5;q=0.4", + "Connection": "keep-alive" + } + # 用webdriver启动谷歌浏览器 + logging.info("启动浏览器,打开微信公众号登录界面") + options = Options() + + service = ChromeService(executable_path=r"C:\Windows\System32\chromedriver.exe") + driver = webdriver.Chrome(service=service, options=options) + # 打开微信公众号登录页面 + driver.get('https://mp.weixin.qq.com/') + # 等待5秒钟 + time.sleep(2) + # # 拿手机扫二维码! + logging.info("请拿手机扫码二维码登录公众号") + time.sleep(20) + + # 重新载入公众号登录页,登录之后会显示公众号后台首页,从这个返回内容中获取cookies信息 + driver.get('https://mp.weixin.qq.com/') + # 获取cookies + cookie_items = driver.get_cookies() + expiry=-1 + # 获取到的cookies是列表形式,将cookies转成json形式并存入本地名为cookie的文本中 + for cookie_item in cookie_items: + cookies[cookie_item['name']] = cookie_item['value'] + if('expiry' in cookie_item and cookie_item['expiry'] > expiry): + expiry = cookie_item['expiry'] + + if "slave_sid" not in cookies: + logging.info("登录公众号失败,获取cookie失败") + exit() + + # 将cookies写入文件 + cookies["expiry"] = expiry + with open('cookies.txt', mode='w', encoding="utf-8") as f: + f.write(json.dumps(cookies, indent=4, ensure_ascii=False)) + # 关闭浏览器 + driver.quit() + # 输出提示 + print("成功获取了cookies内容!") diff --git a/dsLightRag/WxGzh/T2_CollectArticle.py b/dsLightRag/WxGzh/T2_CollectArticle.py new file mode 100644 index 00000000..a6e073bb --- /dev/null +++ b/dsLightRag/WxGzh/T2_CollectArticle.py @@ -0,0 +1,223 @@ +# 详解(一)Python + Selenium 批量采集微信公众号,搭建自己的微信公众号每日AI简报,告别信息焦虑 +# https://blog.csdn.net/k352733625/article/details/149222945 + +# 微信爬爬猫---公众号文章抓取代码分析 +# https://blog.csdn.net/yajuanpi4899/article/details/121584268 + + +import asyncio +import datetime +import json +import logging +import random +import re + +import requests + +from Util.PostgreSQLUtil import init_postgres_pool +from Util.WxGzhUtil import init_wechat_browser, get_article_content + +# 删除重复的日志配置,只保留以下内容 +logger = logging.getLogger('WeiXinGongZhongHao') +logger.setLevel(logging.INFO) + +# 确保只添加一个handler +if not logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + logger.addHandler(handler) + + +async def get_wechat_sources(): + """从t_wechat_source表获取微信公众号列表""" + try: + pool = await init_postgres_pool() + async with pool.acquire() as conn: + rows = await conn.fetch('SELECT * FROM t_wechat_source') + return [dict(row) for row in rows] + finally: + await pool.close() + + +""" +# 查看selenium版本 +pip show selenium +4.34.2 + +# 查看Chrome浏览器版本 +chrome://version/ +138.0.7204.101 (正式版本) (64 位) + +# 下载驱动包 +https://googlechromelabs.github.io/chrome-for-testing/ +https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.94/win64/chromedriver-win64.zip +""" +import time +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.chrome.service import Service as ChromeService + + +async def is_article_exist(pool, article_url): + """检查文章URL是否已存在数据库中""" + try: + async with pool.acquire() as conn: + row = await conn.fetchrow(''' + SELECT 
1 + FROM t_wechat_articles + WHERE url = $1 LIMIT 1 + ''', article_url) + return row is not None + except Exception as e: + logging.error(f"检查文章存在性失败: {e}") + return False # 出错时默认返回False,避免影响正常流程 + + +async def save_article_to_db(pool, article_title, account_name, article_url, publish_time, content, id): + # 先检查文章是否已存在 + if await is_article_exist(pool, article_url): + logger.info(f"文章已存在,跳过保存: {article_url}") + return + + try: + async with pool.acquire() as conn: + await conn.execute(''' + INSERT INTO t_wechat_articles + (title, source, url, publish_time, content, source_id) + VALUES ($1, $2, $3, $4, $5, $6) + ''', article_title, account_name, article_url, + publish_time, content, id) + except Exception as e: + logging.error(f"保存文章失败: {e}") + + +if __name__ == '__main__': + # 从文件cookies.txt中获取 + with open('cookies.txt', 'r', encoding='utf-8') as f: + content = f.read() + # 使用json还原为json对象 + cookies = json.loads(content) + # "expiry": 1787106233 + # 检查是否有过期时间 + expiry = cookies["expiry"] + if expiry: + # 换算出过期时间 + expiry_time = time.localtime(expiry) + expiry_date = time.strftime("%Y-%m-%d %H:%M:%S", expiry_time) + + # 获取当前时间戳 + current_timestamp = time.time() + # 检查是否已过期 + if current_timestamp > expiry: + logger.error("Cookie已过期") + exit() + # 移除expiry属性 + del cookies["expiry"] + logger.info(f"cookies的过期时间一般是4天,cookies过期时间:%s" % expiry_date) + options = Options() + options.add_argument('-headless') # 无头参数,调试时可以注释掉 + # 设置headers - 使用微信内置浏览器的User-Agent + header = { + "HOST": "mp.weixin.qq.com", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2785.116 Safari/537.36 QBCore/4.0.1301.400 QQBrowser/9.0.2524.400 Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2875.116 Safari/537.36 NetType/WIFI MicroMessenger/7.0.20.1781(0x6700143B) WindowsWechat(0x63010200)", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + "Accept-Encoding": "gzip, deflate, br", + "Accept-Language": "zh-CN,zh;q=0.8,en-US;q=0.6,en;q=0.5;q=0.4", + "Connection": "keep-alive" + } + + service = ChromeService(executable_path=r"C:\Windows\System32\chromedriver.exe") + # 使用统一的初始化方式 + driver = init_wechat_browser() + + # 方法3:使用requests库发送请求获取重定向URL + url = 'https://mp.weixin.qq.com' + response = requests.get(url=url, allow_redirects=False, cookies=cookies) + if 'Location' in response.headers: + redirect_url = response.headers.get("Location") + logger.info(f"重定向URL:%s"%redirect_url) + token_match = re.findall(r'token=(\d+)', redirect_url) + if token_match: + token = token_match[0] + logger.info(f"获取到的token:%s"%token) + + article_urls = [] + + # 获取公众号列表 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + gzlist = loop.run_until_complete(get_wechat_sources()) + finally: + loop.close() + + # 爬取文章 + for item in gzlist: + account_name = item["account_name"] + account_id = item["account_id"] + id = item["id"] + # 搜索微信公众号的接口地址 + search_url = 'https://mp.weixin.qq.com/cgi-bin/searchbiz?' 
+ # 搜索微信公众号接口需要传入的参数,有三个变量:微信公众号token、随机数random、搜索的微信公众号名字 + query_id = { + 'action': 'search_biz', + 'token': token, + 'lang': 'zh_CN', + 'f': 'json', + 'ajax': '1', + 'random': random.random(), + 'query': account_name, + 'begin': '0', + 'count': '5' + } + # 打开搜索微信公众号接口地址,需要传入相关参数信息如:cookies、params、headers + search_response = requests.get(search_url, cookies=cookies, headers=header, params=query_id) + # 取搜索结果中的第一个公众号 + lists = search_response.json().get('list')[0] + # 获取这个公众号的fakeid,后面爬取公众号文章需要此字段 + fakeid = lists.get('fakeid') + logging.info("fakeid:" + fakeid) + # 微信公众号文章接口地址 + appmsg_url = 'https://mp.weixin.qq.com/cgi-bin/appmsg?' + # 搜索文章需要传入几个参数:登录的公众号token、要爬取文章的公众号fakeid、随机数random + query_id_data = { + 'token': token, + 'lang': 'zh_CN', + 'f': 'json', + 'ajax': '1', + 'random': random.random(), + 'action': 'list_ex', + 'begin': '0', # 不同页,此参数变化,变化规则为每页加5 + 'count': '5', + 'query': '', + 'fakeid': fakeid, + 'type': '9' + } + # 打开搜索的微信公众号文章列表页 + query_fakeid_response = requests.get(appmsg_url, cookies=cookies, headers=header, params=query_id_data) + fakeid_list = query_fakeid_response.json().get('app_msg_list') + + for item in fakeid_list: + article_url = item.get('link') + article_title = item.get('title') + publish_time = datetime.datetime.fromtimestamp(int(item.get("update_time"))) + + if '试卷' in article_title: # 过滤掉试卷 + continue + + logger.info(f"正在处理文章: {article_title} ({publish_time})") + content = get_article_content(article_url) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + pool = loop.run_until_complete(init_postgres_pool()) + loop.run_until_complete( + save_article_to_db(pool, article_title, account_name, article_url, publish_time, content, + id)) + finally: + loop.run_until_complete(pool.close()) + loop.close() + + time.sleep(1) + # 关闭浏览器 + driver.quit() diff --git a/dsLightRag/WxGzh/T3_TrainIntoKG.py b/dsLightRag/WxGzh/T3_TrainIntoKG.py new file mode 100644 index 00000000..86473413 --- /dev/null +++ b/dsLightRag/WxGzh/T3_TrainIntoKG.py @@ -0,0 +1,63 @@ +import asyncio +import logging + +from Util.DocxUtil import get_docx_content_by_pandoc +from Util.LightRagUtil import initialize_pg_rag +from Util.PostgreSQLUtil import init_postgres_pool + +logger = logging.getLogger('lightrag') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) +logger.addHandler(handler) +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +# 使用PG库后,这个是没有用的,但目前的项目代码要求必传,就写一个吧。 +WORKING_DIR = f"./output" + + +async def get_unprocessed_articles(): + """从t_wechat_articles表获取未处理的文章""" + try: + pool = await init_postgres_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(''' + SELECT id, source, title, content + FROM t_wechat_articles + WHERE is_finish = 0 + ''') + return [dict(row) for row in rows] + finally: + await pool.close() + +async def main(): + # 获取未处理的文章 + articles = await get_unprocessed_articles() + logger.info(f"共获取到{len(articles)}篇未处理的文章") + + for article in articles: + workspace = 'ChangChun' + docx_name = f"{article['source']}_{article['title']}" # 组合来源和标题作为文档名 + content = article["content"] # 使用文章内容 + + logger.info(f"开始处理文档: {docx_name}") + try: + rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace) + await rag.ainsert(input=content, file_paths=[docx_name]) + + # 标记为已处理 + pool = await init_postgres_pool() + async with pool.acquire() as conn: + await conn.execute(''' + UPDATE 
t_wechat_articles + SET is_finish = 1 + WHERE id = $1 + ''', article["id"]) + finally: + if rag: + await rag.finalize_storages() + if pool: + await pool.close() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dsLightRag/WxGzh/__init__.py b/dsLightRag/WxGzh/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dsLightRag/WxGzh/cookies.txt b/dsLightRag/WxGzh/cookies.txt new file mode 100644 index 00000000..7183301e --- /dev/null +++ b/dsLightRag/WxGzh/cookies.txt @@ -0,0 +1,17 @@ +{ + "_clsk": "2gtve8|1752546228205|1|1|mp.weixin.qq.com/weheat-agent/payload/record", + "xid": "16332bed01be1055e236ad45b33af8df", + "data_bizuin": "3514353238", + "slave_user": "gh_4f88a4e194da", + "slave_sid": "QzBRX1FWTXNMaEdJYnc4ODBaM3FJU3RRbjVJNFE2N2IzMXFyVGlRQ0V5YklvNGFOc3NBWHdjV2J5OVg5U0JBVXdfdGhSU3lObXRheG1TdFUyXzVFcTFYS3E1NTh2aTlnSlBOOUluMUljUnBkYktjeUJDM216WVJNYzJKQkx2eW9Ib1duUk1yWXI3RndTa2dK", + "rand_info": "CAESIFwUSYus3XR5tFa1+b5ytJeuGAQS02d07zNBJNfi+Ftk", + "data_ticket": "9gQ088/vC7+jqxfFxBKS2aRx/JjmzJt+8HyuDLJtQBgpVej1hfSG1A0FQKWBbHQh", + "bizuin": "3514353238", + "mm_lang": "zh_CN", + "slave_bizuin": "3514353238", + "uuid": "8c5dc8e06af66d00a4b8e8596c8662eb", + "ua_id": "y1HZNMSzYCWuaUJDAAAAAApPVJ0a_arX_A5zqoUh6P8=", + "wxuin": "52546211515015", + "_clck": "msq32d|1|fxm|0", + "expiry": 1787106233 +} \ No newline at end of file diff --git a/dsLightRag/static/ChangChun.html b/dsLightRag/static/ChangChun.html index ac938d40..a4105b5f 100644 --- a/dsLightRag/static/ChangChun.html +++ b/dsLightRag/static/ChangChun.html @@ -3,7 +3,7 @@ - 【长春市中考报考知识库】 + 【长春市教育信息资讯库】
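
Editor's note, not part of the patch: utils/Database.py now routes every query through asyncpg, whose conn.fetch() binds positional $1, $2, ... parameters, yet several handlers above (DmController, DocumentController, LoginController, ThemeController, UserController) switch to building SQL by string concatenation. A minimal sketch of the parameterized alternative, reusing the find_by_sql(sql, params) helper defined in this patch; the two wrapper function names are illustrative, not code from the repository:

    from utils.Database import find_by_sql  # helper added in this patch

    async def find_user_by_login_name(login_name: str):
        # Same lookup LoginController builds by concatenation, with a bound parameter
        sql = ("SELECT person_id, person_name, identity_id, login_name, xb, "
               "bureau_id, org_id, pwdmd5 "
               "FROM t_sys_loginperson WHERE login_name = $1 AND b_use = 1")
        rows = await find_by_sql(sql, (login_name,))
        return rows[0] if rows else None

    async def find_subjects_by_stage(stage_id: int):
        # Same lookup DmController builds with str(stage_id) concatenation
        sql = ("SELECT subject_id, subject_name FROM t_dm_subject "
               "WHERE stage_id = $1 ORDER BY sort_id")
        return await find_by_sql(sql, (stage_id,)) or []

With $n placeholders the driver sends values out of band, so quoting mistakes and injection through person_id or document_name style inputs are avoided without changing the call sites much.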
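A second note on utils/Database.py: asyncpg's Connection.execute() resolves to a command-status string such as "UPDATE 1" (or "UPDATE 0" when nothing matched), which is always truthy, so the `if result:` check in delete_by_id cannot distinguish a real soft-delete from a miss. A sketch of a stricter check under the same table layout (is_deleted and update_time columns as used throughout the patch); the function name is illustrative:

    async def soft_delete_by_id(conn, table_name: str, key_column: str, key_value) -> bool:
        status = await conn.execute(
            f"UPDATE {table_name} SET is_deleted = 1, update_time = now() "
            f"WHERE {key_column} = $1 AND is_deleted = 0",
            key_value,
        )
        # status has the form "UPDATE <rowcount>"; require at least one affected row
        return int(status.split()[-1]) > 0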
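On pagination: the helper moved from utils/Database.py into the new utils/PageUtil.py, and the generated SQL changed from MySQL's `LIMIT offset, count` to PostgreSQL's `LIMIT count OFFSET offset`. A small self-contained restatement of the page-window arithmetic, for reference only (PageUtil remains the code of record):

    import math

    def page_window(total_rows: int, page_number: int, page_size: int):
        total_pages = math.ceil(total_rows / page_size) if page_size else 0
        if page_number < 1:
            page_number = 1
        if 0 < total_pages < page_number:
            page_number = total_pages
        offset = (page_number - 1) * page_size
        return total_pages, offset, page_size

    # 23 rows, 10 per page, page 3 -> (3, 20, 10), rendered as "LIMIT 10 offset 20"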
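Finally, main.py replaces the daemon thread with asyncio.create_task(train_document_task()) inside the lifespan hook, so the poller runs on the event loop and can share the asyncpg pool created by init_database(). A minimal standalone sketch of that pattern, with placeholder names where the real module wiring is assumed rather than shown:

    import asyncio
    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    async def poll_untrained_documents():
        # stand-in for train_document_task: poll, then yield the loop back
        while True:
            # ... SELECT documents WHERE train_flag = 0, train them, update train_flag ...
            await asyncio.sleep(60)  # asyncio.sleep, never time.sleep, inside a coroutine

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # await init_database()      # create the connection pool first, as main.py does
        task = asyncio.create_task(poll_untrained_documents())
        yield
        task.cancel()                # optional hygiene; the patch leaves the task running
        # await shutdown_database()

    app = FastAPI(lifespan=lifespan)

Cancelling the task on shutdown avoids a pending-task warning when the server exits; otherwise the behaviour matches what the patch introduces.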