From b3c22ecceaa9c0b52795bfe982f91d9cf46827b4 Mon Sep 17 00:00:00 2001 From: chengminglong <123204464@qq.com> Date: Wed, 20 Aug 2025 11:07:09 +0800 Subject: [PATCH] =?UTF-8?q?commit=20by=20Kalman.CHENG=20=E2=98=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../TeachingModel/tasks/BackgroundTasks.py | 37 +++++--- dsLightRag/Util/DocxUtil.py | 95 ++++++++++--------- 2 files changed, 74 insertions(+), 58 deletions(-) diff --git a/dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py b/dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py index 4a71b606..c6ee9c4a 100644 --- a/dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py +++ b/dsLightRag/Routes/TeachingModel/tasks/BackgroundTasks.py @@ -47,19 +47,24 @@ async def train_document_task(): logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!") # 训练代码开始 # content = get_docx_content_by_pandoc(document_path) + train_result = True try: # 注意:默认设置使用NetworkX rag = await initialize_rag(working_dir) # 获取docx文件的内容 content = get_docx_content_by_pandoc(document_path) - await rag.ainsert(content, ids=[document_name], file_paths=[document_name]) - logger.info(f"Inserted content from {document_name}") + if content is not None: + await rag.ainsert(content, ids=[document_name], file_paths=[document_name]) + logger.info(f"Inserted content from {document_name}") + else: + train_result = False except Exception as e: logger.error(f"An error occurred: {e}") finally: await rag.finalize_storages() # 训练结束,更新训练状态 - update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"]) + train_flag = "2" if train_result else "7" + update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = " + train_flag + " WHERE id = " + str(document["id"]) await execute_sql(update_document_sql, ()) elif document["train_flag"] == 3: update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"]) @@ -80,22 +85,30 @@ async def train_document_task(): await execute_sql(update_document_sql, ()) # 整体更新主题状态 - select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2) group by train_flag" + select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2,7) group by train_flag" select_document_result = await find_by_sql(select_document_sql, ()) train_document_count_map = {} for item in select_document_result: train_document_count_map[str(item["train_flag"])] = int(item["train_count"]) + train_document_count_0 = train_document_count_map.get("0", 0) train_document_count_1 = train_document_count_map.get("1", 0) train_document_count_2 = train_document_count_map.get("2", 0) + train_document_count_7 = train_document_count_map.get("7", 0) + search_flag = 0 if train_document_count_2 > 0: - update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 1, train_flag = 2 WHERE id = {theme['id']}" + search_flag = 1 + # 训练未开始:初始化| 或者没有is_deleted=0的文档 + if train_document_count_0 == 0 and train_document_count_1 == 0 and train_document_count_2 == 0 and train_document_count_7 == 0: + update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 0 WHERE id = {theme['id']}" await execute_sql(update_theme_sql, ()) - else: - if train_document_count_1 > 0: - update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 1 WHERE id = {theme['id']}" - await execute_sql(update_theme_sql, ()) - else: - update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 0 WHERE id = {theme['id']}" - await execute_sql(update_theme_sql, ()) + # 训练进行中:单个文档训练中|同时存在训练完成+未训练的文档 + if train_document_count_2 > 0 and (train_document_count_1 > 0 or train_document_count_7 > 0): + update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 1 WHERE id = {theme['id']}" + await execute_sql(update_theme_sql, ()) + # 训练已完成:所有is_deleted=0的文档都训练完成 + if train_document_count_2 > 0 and train_document_count_0 == 0 and train_document_count_1 == 0 and train_document_count_7 == 0: + update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 2 WHERE id = {theme['id']}" + await execute_sql(update_theme_sql, ()) + # 添加适当的等待时间,避免频繁查询 await asyncio.sleep(120) # 每二分钟查询一次 diff --git a/dsLightRag/Util/DocxUtil.py b/dsLightRag/Util/DocxUtil.py index e777f42a..8d2a9209 100644 --- a/dsLightRag/Util/DocxUtil.py +++ b/dsLightRag/Util/DocxUtil.py @@ -93,50 +93,53 @@ def get_docx_content_by_pandoc(docx_file): resize_images_in_directory('./static/Images/' + md5_value + '/media') # 读取然后修改内容,输出到新的文件 img_idx = 0 # 图片索引 - with open(temp_markdown, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if not line: - continue - # 跳过图片高度描述行 - if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')): - continue - # height="1.91044072615923in" - # 使用find()方法安全地检查图片模式 - is_img = line.find("![](") >= 0 and ( - line.find(".png") > 0 or - line.find(".jpg") > 0 or - line.find(".jpeg") > 0 - ) - if is_img: - # ![](media/image3.png){width="3.1251607611548557in" - # height="3.694634733158355in"} - # ![](../static/Images/01b20e04085e406ea5375791da58a60f/media/image3.png){width="3.1251607611548557in" - pos = line.find(")") - q = line[:pos + 1] - q = q.replace("./static", ".") - # Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/ - left_idx = line.find("(") - static_idx = line.find("static") - if left_idx == -1 or static_idx == -1 or left_idx > static_idx: - print("路径中不包含(+~+static的已知格式") - else: - between_content = q[left_idx+1:static_idx].strip() - if between_content: - q = q[:left_idx+1] + '\\' + q[static_idx:] + if os.path.exists(temp_markdown): + with open(temp_markdown, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + # 跳过图片高度描述行 + if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')): + continue + # height="1.91044072615923in" + # 使用find()方法安全地检查图片模式 + is_img = line.find("![](") >= 0 and ( + line.find(".png") > 0 or + line.find(".jpg") > 0 or + line.find(".jpeg") > 0 + ) + if is_img: + # ![](media/image3.png){width="3.1251607611548557in" + # height="3.694634733158355in"} + # ![](../static/Images/01b20e04085e406ea5375791da58a60f/media/image3.png){width="3.1251607611548557in" + pos = line.find(")") + q = line[:pos + 1] + q = q.replace("./static", ".") + # Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/ + left_idx = line.find("(") + static_idx = line.find("static") + if left_idx == -1 or static_idx == -1 or left_idx > static_idx: + print("路径中不包含(+~+static的已知格式") else: - q = q[:static_idx] + '\\' + q[static_idx:] - print(f"q3:{q}") - #q = q[4:-1] - #q='我是图片' - img_idx += 1 - content += q + "\n" - else: - content += line.strip().replace("**", "") + "\n" - content = content.replace("\phantom", "") - # 将content回写到markdown文件 - with open(temp_markdown, 'w', encoding='utf-8') as f: - f.write(content) - # 删除临时文件 output_file - # os.remove(temp_markdown) - return content.replace("\n\n", "\n").replace("\\", "/") + between_content = q[left_idx+1:static_idx].strip() + if between_content: + q = q[:left_idx+1] + '\\' + q[static_idx:] + else: + q = q[:static_idx] + '\\' + q[static_idx:] + print(f"q3:{q}") + #q = q[4:-1] + #q='我是图片' + img_idx += 1 + content += q + "\n" + else: + content += line.strip().replace("**", "") + "\n" + content = content.replace("\phantom", "") + # 将content回写到markdown文件 + with open(temp_markdown, 'w', encoding='utf-8') as f: + f.write(content) + # 删除临时文件 output_file + # os.remove(temp_markdown) + return content.replace("\n\n", "\n").replace("\\", "/") + else: + return None