Files
dsProject/dsRagAnything/T1_Train.py
2025-08-27 07:36:59 +08:00

154 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
from raganything import RAGAnything, RAGAnythingConfig
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
from logging.handlers import RotatingFileHandler # 导入RotatingFileHandler用于日志轮转
import Config.Config
# --- Logging configuration -------------------------------------------------
# All handlers below share this single line layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# The LightRAG library's *named* logger ('lightrag'). Despite the variable
# name, this is not Python's root logger. INFO and above go to the console.
root_logger = logging.getLogger('lightrag')
root_logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_LOG_FORMAT))
root_logger.addHandler(handler)

# This script's own 'ragAnything' logger: console output plus a size-rotated
# log file.
logger = logging.getLogger('ragAnything')
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
# NOTE(review): only records handled by the 'ragAnything' logger tree reach
# this file. The 'lightrag' logger above does not write to it, despite the
# file name. Confirm this is intended.
file_handler = RotatingFileHandler(
    'lightrag.log',
    maxBytes=200 * 1024,  # roll over at ~200 KB
    backupCount=5,        # keep at most 5 rotated backups
    encoding='utf-8',
    delay=True,           # do not create/open the file until a record is emitted
)
for _log_handler in (console_handler, file_handler):
    _log_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(_log_handler)
async def train(file_path, output_dir, working_dir):
    """Build a RAGAnything pipeline and run it on a single document.

    Text LLM calls use the ``ALY_LLM_*`` settings, vision calls use the
    ``GLM_*`` settings and embeddings use the ``EMBED_*`` settings, all read
    from ``Config.Config``. The document is then handed to
    ``RAGAnything.process_document_complete``.

    Args:
        file_path: Path of the document to process.
        output_dir: Directory for the parser's (MinerU's) temporary output.
        working_dir: LightRAG working directory (where its database lives).
    """
    # Credentials/endpoint for the text LLM.
    api_key = Config.Config.ALY_LLM_API_KEY
    base_url = Config.Config.ALY_LLM_BASE_URL

    config = RAGAnythingConfig(
        working_dir=working_dir,
        parser="mineru",      # parser: "mineru" or "docling"
        parse_method="auto",  # parse method: "auto", "ocr" or "txt"
        enable_image_processing=False,
        enable_table_processing=False,
        enable_equation_processing=True,
    )

    def llm_model_func(prompt, system_prompt=None, history_messages=None, **kwargs):
        """Call the text LLM; returns whatever ``openai_complete_if_cache`` returns.

        ``history_messages`` defaults to None (normalised to a fresh list) so
        no mutable default list is shared between calls.
        """
        if history_messages is None:
            history_messages = []
        return openai_complete_if_cache(
            Config.Config.ALY_LLM_MODEL_NAME,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            api_key=api_key,
            base_url=base_url,
            **kwargs,
        )

    def vision_model_func(
        prompt, system_prompt=None, history_messages=None, image_data=None, messages=None, **kwargs
    ):
        """Route a request to the vision model, or to the text LLM when there is no image.

        Dispatch order:
          1. ``messages`` given: forward that pre-built (multimodal) message
             list as-is. ``system_prompt`` and ``history_messages`` are ignored
             because ``messages`` is expected to already contain them.
          2. ``image_data`` given: build a system (optional) + user message
             with the prompt text and the image as a base64 JPEG data URL.
          3. Otherwise: plain text, delegated to ``llm_model_func``.
        """
        if messages:
            return openai_complete_if_cache(
                Config.Config.GLM_MODEL_NAME,
                "",
                system_prompt=None,
                history_messages=[],
                messages=messages,
                api_key=Config.Config.GLM_API_KEY,
                base_url=Config.Config.GLM_BASE_URL,
                **kwargs,
            )

        if image_data:
            # Only include a system message when there is a system prompt.
            # The API rejects a bare ``None`` entry in the messages list.
            image_messages = []
            if system_prompt:
                image_messages.append({"role": "system", "content": system_prompt})
            image_messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_data}"
                            },
                        },
                    ],
                }
            )
            return openai_complete_if_cache(
                Config.Config.GLM_MODEL_NAME,
                "",
                system_prompt=None,
                history_messages=[],
                messages=image_messages,
                api_key=Config.Config.GLM_API_KEY,
                base_url=Config.Config.GLM_BASE_URL,
                **kwargs,
            )

        # Text-only request: fall back to the text LLM.
        return llm_model_func(prompt, system_prompt, history_messages, **kwargs)

    embedding_func = EmbeddingFunc(
        embedding_dim=Config.Config.EMBED_DIM,
        max_token_size=Config.Config.EMBED_MAX_TOKEN_SIZE,
        func=lambda texts: openai_embed(
            texts,
            model=Config.Config.EMBED_MODEL_NAME,
            api_key=Config.Config.EMBED_API_KEY,
            base_url=Config.Config.EMBED_BASE_URL,
        ),
    )

    rag = RAGAnything(
        config=config,
        llm_model_func=llm_model_func,
        vision_model_func=vision_model_func,
        embedding_func=embedding_func,
    )

    # Reuse the configured parse method so the two settings cannot drift apart.
    await rag.process_document_complete(
        file_path=file_path,
        output_dir=output_dir,
        parse_method=config.parse_method,
    )
if __name__ == "__main__":
    # Script entry point. Paths can be overridden on the command line; with no
    # arguments it behaves exactly as before (GeoGebra topic).
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Process a document into a RAG-Anything / LightRAG knowledge base."
    )
    arg_parser.add_argument(
        "--file-path",
        default="./Doc/GeoGebra.pdf",
        help="path of the document to process",
    )
    arg_parser.add_argument(
        "--output-dir",
        default="./Output",
        help="directory for MinerU's temporary output files",
    )
    arg_parser.add_argument(
        "--working-dir",
        default="./Topic/Geogebra",
        help="LightRAG database directory (e.g. ./Topic/HuangWanQiao for another topic)",
    )
    args = arg_parser.parse_args()

    asyncio.run(train(args.file_path, args.output_dir, args.working_dir))