Merge branch 'main' of http://10.10.14.176:3000/huanghai/dsProject
This commit is contained in:
@@ -47,19 +47,24 @@ async def train_document_task():
|
|||||||
logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
logging.info(f"开始处理文档:{document_name}, 还有{len(no_train_document_result) - 1}个文档需要处理!")
|
||||||
# 训练代码开始
|
# 训练代码开始
|
||||||
# content = get_docx_content_by_pandoc(document_path)
|
# content = get_docx_content_by_pandoc(document_path)
|
||||||
|
train_result = True
|
||||||
try:
|
try:
|
||||||
# 注意:默认设置使用NetworkX
|
# 注意:默认设置使用NetworkX
|
||||||
rag = await initialize_rag(working_dir)
|
rag = await initialize_rag(working_dir)
|
||||||
# 获取docx文件的内容
|
# 获取docx文件的内容
|
||||||
content = get_docx_content_by_pandoc(document_path)
|
content = get_docx_content_by_pandoc(document_path)
|
||||||
await rag.ainsert(content, ids=[document_name], file_paths=[document_name])
|
if content is not None:
|
||||||
logger.info(f"Inserted content from {document_name}")
|
await rag.ainsert(content, ids=[document_name], file_paths=[document_name])
|
||||||
|
logger.info(f"Inserted content from {document_name}")
|
||||||
|
else:
|
||||||
|
train_result = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An error occurred: {e}")
|
logger.error(f"An error occurred: {e}")
|
||||||
finally:
|
finally:
|
||||||
await rag.finalize_storages()
|
await rag.finalize_storages()
|
||||||
# 训练结束,更新训练状态
|
# 训练结束,更新训练状态
|
||||||
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 2 WHERE id = " + str(document["id"])
|
train_flag = "2" if train_result else "7"
|
||||||
|
update_document_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = " + train_flag + " WHERE id = " + str(document["id"])
|
||||||
await execute_sql(update_document_sql, ())
|
await execute_sql(update_document_sql, ())
|
||||||
elif document["train_flag"] == 3:
|
elif document["train_flag"] == 3:
|
||||||
update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"])
|
update_sql: str = " UPDATE t_ai_teaching_model_document SET train_flag = 4 WHERE id = " + str(document["id"])
|
||||||
@@ -80,22 +85,30 @@ async def train_document_task():
|
|||||||
await execute_sql(update_document_sql, ())
|
await execute_sql(update_document_sql, ())
|
||||||
|
|
||||||
# 整体更新主题状态
|
# 整体更新主题状态
|
||||||
select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2) group by train_flag"
|
select_document_sql: str = f"select train_flag, count(1) as train_count from t_ai_teaching_model_document where theme_id = {theme['id']} and is_deleted = 0 and train_flag in (0,1,2,7) group by train_flag"
|
||||||
select_document_result = await find_by_sql(select_document_sql, ())
|
select_document_result = await find_by_sql(select_document_sql, ())
|
||||||
train_document_count_map = {}
|
train_document_count_map = {}
|
||||||
for item in select_document_result:
|
for item in select_document_result:
|
||||||
train_document_count_map[str(item["train_flag"])] = int(item["train_count"])
|
train_document_count_map[str(item["train_flag"])] = int(item["train_count"])
|
||||||
|
train_document_count_0 = train_document_count_map.get("0", 0)
|
||||||
train_document_count_1 = train_document_count_map.get("1", 0)
|
train_document_count_1 = train_document_count_map.get("1", 0)
|
||||||
train_document_count_2 = train_document_count_map.get("2", 0)
|
train_document_count_2 = train_document_count_map.get("2", 0)
|
||||||
|
train_document_count_7 = train_document_count_map.get("7", 0)
|
||||||
|
search_flag = 0
|
||||||
if train_document_count_2 > 0:
|
if train_document_count_2 > 0:
|
||||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 1, train_flag = 2 WHERE id = {theme['id']}"
|
search_flag = 1
|
||||||
|
# 训练未开始:初始化| 或者没有is_deleted=0的文档
|
||||||
|
if train_document_count_0 == 0 and train_document_count_1 == 0 and train_document_count_2 == 0 and train_document_count_7 == 0:
|
||||||
|
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 0 WHERE id = {theme['id']}"
|
||||||
await execute_sql(update_theme_sql, ())
|
await execute_sql(update_theme_sql, ())
|
||||||
else:
|
# 训练进行中:单个文档训练中|同时存在训练完成+未训练的文档
|
||||||
if train_document_count_1 > 0:
|
if train_document_count_2 > 0 and (train_document_count_1 > 0 or train_document_count_7 > 0):
|
||||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 1 WHERE id = {theme['id']}"
|
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 1 WHERE id = {theme['id']}"
|
||||||
await execute_sql(update_theme_sql, ())
|
await execute_sql(update_theme_sql, ())
|
||||||
else:
|
# 训练已完成:所有is_deleted=0的文档都训练完成
|
||||||
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = 0, train_flag = 0 WHERE id = {theme['id']}"
|
if train_document_count_2 > 0 and train_document_count_0 == 0 and train_document_count_1 == 0 and train_document_count_7 == 0:
|
||||||
await execute_sql(update_theme_sql, ())
|
update_theme_sql: str = f"UPDATE t_ai_teaching_model_theme SET search_flag = {search_flag}, train_flag = 2 WHERE id = {theme['id']}"
|
||||||
|
await execute_sql(update_theme_sql, ())
|
||||||
|
|
||||||
# 添加适当的等待时间,避免频繁查询
|
# 添加适当的等待时间,避免频繁查询
|
||||||
await asyncio.sleep(120) # 每二分钟查询一次
|
await asyncio.sleep(120) # 每二分钟查询一次
|
||||||
|
@@ -93,50 +93,53 @@ def get_docx_content_by_pandoc(docx_file):
|
|||||||
resize_images_in_directory('./static/Images/' + md5_value + '/media')
|
resize_images_in_directory('./static/Images/' + md5_value + '/media')
|
||||||
# 读取然后修改内容,输出到新的文件
|
# 读取然后修改内容,输出到新的文件
|
||||||
img_idx = 0 # 图片索引
|
img_idx = 0 # 图片索引
|
||||||
with open(temp_markdown, 'r', encoding='utf-8') as f:
|
if os.path.exists(temp_markdown):
|
||||||
for line in f:
|
with open(temp_markdown, 'r', encoding='utf-8') as f:
|
||||||
line = line.strip()
|
for line in f:
|
||||||
if not line:
|
line = line.strip()
|
||||||
continue
|
if not line:
|
||||||
# 跳过图片高度描述行
|
continue
|
||||||
if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')):
|
# 跳过图片高度描述行
|
||||||
continue
|
if line.startswith('height=') and (line.endswith('in"}') or line.endswith('in"')):
|
||||||
# height="1.91044072615923in"
|
continue
|
||||||
# 使用find()方法安全地检查图片模式
|
# height="1.91044072615923in"
|
||||||
is_img = line.find("![](") >= 0 and (
|
# 使用find()方法安全地检查图片模式
|
||||||
line.find(".png") > 0 or
|
is_img = line.find("![](") >= 0 and (
|
||||||
line.find(".jpg") > 0 or
|
line.find(".png") > 0 or
|
||||||
line.find(".jpeg") > 0
|
line.find(".jpg") > 0 or
|
||||||
)
|
line.find(".jpeg") > 0
|
||||||
if is_img:
|
)
|
||||||
# {width="3.1251607611548557in"
|
if is_img:
|
||||||
# height="3.694634733158355in"}
|
# {width="3.1251607611548557in"
|
||||||
# {width="3.1251607611548557in"
|
# height="3.694634733158355in"}
|
||||||
pos = line.find(")")
|
# {width="3.1251607611548557in"
|
||||||
q = line[:pos + 1]
|
pos = line.find(")")
|
||||||
q = q.replace("./static", ".")
|
q = line[:pos + 1]
|
||||||
# Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/
|
q = q.replace("./static", ".")
|
||||||
left_idx = line.find("(")
|
# Modify by Kalman.CHENG ☆: 增加逻辑对图片路径处理,在(和static之间加上/
|
||||||
static_idx = line.find("static")
|
left_idx = line.find("(")
|
||||||
if left_idx == -1 or static_idx == -1 or left_idx > static_idx:
|
static_idx = line.find("static")
|
||||||
print("路径中不包含(+~+static的已知格式")
|
if left_idx == -1 or static_idx == -1 or left_idx > static_idx:
|
||||||
else:
|
print("路径中不包含(+~+static的已知格式")
|
||||||
between_content = q[left_idx+1:static_idx].strip()
|
|
||||||
if between_content:
|
|
||||||
q = q[:left_idx+1] + '\\' + q[static_idx:]
|
|
||||||
else:
|
else:
|
||||||
q = q[:static_idx] + '\\' + q[static_idx:]
|
between_content = q[left_idx+1:static_idx].strip()
|
||||||
print(f"q3:{q}")
|
if between_content:
|
||||||
#q = q[4:-1]
|
q = q[:left_idx+1] + '\\' + q[static_idx:]
|
||||||
#q='<img src="'+q+'" alt="我是图片">'
|
else:
|
||||||
img_idx += 1
|
q = q[:static_idx] + '\\' + q[static_idx:]
|
||||||
content += q + "\n"
|
print(f"q3:{q}")
|
||||||
else:
|
#q = q[4:-1]
|
||||||
content += line.strip().replace("**", "") + "\n"
|
#q='<img src="'+q+'" alt="我是图片">'
|
||||||
content = content.replace("\phantom", "")
|
img_idx += 1
|
||||||
# 将content回写到markdown文件
|
content += q + "\n"
|
||||||
with open(temp_markdown, 'w', encoding='utf-8') as f:
|
else:
|
||||||
f.write(content)
|
content += line.strip().replace("**", "") + "\n"
|
||||||
# 删除临时文件 output_file
|
content = content.replace("\phantom", "")
|
||||||
# os.remove(temp_markdown)
|
# 将content回写到markdown文件
|
||||||
return content.replace("\n\n", "\n").replace("\\", "/")
|
with open(temp_markdown, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(content)
|
||||||
|
# 删除临时文件 output_file
|
||||||
|
# os.remove(temp_markdown)
|
||||||
|
return content.replace("\n\n", "\n").replace("\\", "/")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
Reference in New Issue
Block a user