"""
Batch processing functionality for RAGAnything
Contains methods for processing multiple files in batches
"""
import asyncio
from typing import Optional, List
from pathlib import Path
class BatchMixin:
"""BatchMixin class containing batch processing functionality for RAGAnything"""
    async def process_folder_complete(
        self,
        folder_path: str,
        output_dir: Optional[str] = None,
        parse_method: Optional[str] = None,
        display_stats: Optional[bool] = None,
        split_by_character: Optional[str] = None,
        split_by_character_only: bool = False,
        file_extensions: Optional[List[str]] = None,
        recursive: Optional[bool] = None,
        max_workers: Optional[int] = None,
    ) -> Optional[Dict[str, Any]]:
"""
Process all files in a folder in batch
Args:
folder_path: Path to the folder to process
output_dir: MinerU output directory (defaults to config.mineru_output_dir)
parse_method: Parse method (defaults to config.mineru_parse_method)
display_stats: Whether to display content statistics for each file (defaults to False for batch processing)
split_by_character: Optional character to split text by
split_by_character_only: If True, split only by the specified character
file_extensions: List of file extensions to process (defaults to config.supported_file_extensions)
recursive: Whether to recursively process subfolders (defaults to config.recursive_folder_processing)
max_workers: Maximum number of concurrent workers (defaults to config.max_concurrent_files)
"""
        # Ensure LightRAG is initialized
        await self._ensure_lightrag_initialized()

        # Use config defaults if not provided
        if output_dir is None:
            output_dir = self.config.mineru_output_dir
        if parse_method is None:
            parse_method = self.config.mineru_parse_method
        if display_stats is None:
            display_stats = False  # Default to False for batch processing
        if recursive is None:
            recursive = self.config.recursive_folder_processing
        if max_workers is None:
            max_workers = self.config.max_concurrent_files
        if file_extensions is None:
            file_extensions = self.config.supported_file_extensions
        folder_path = Path(folder_path)
        if not folder_path.exists() or not folder_path.is_dir():
            raise ValueError(
                f"Folder does not exist or is not a valid directory: {folder_path}"
            )

        # Convert file extensions to a set for faster lookup
        target_extensions = {ext.lower().strip() for ext in file_extensions}

        # Log the extensions being used
        self.logger.info(
            f"Processing files with extensions: {sorted(target_extensions)}"
        )
        # Collect all files to process (recursively if configured)
        candidates = folder_path.rglob("*") if recursive else folder_path.glob("*")
        files_to_process = [
            file_path
            for file_path in candidates
            if file_path.is_file() and file_path.suffix.lower() in target_extensions
        ]
        if not files_to_process:
            self.logger.info(f"No files to process found in {folder_path}")
            return None

        self.logger.info(f"Found {len(files_to_process)} files to process")
        self.logger.info("File type distribution:")

        # Count file types
        file_type_count = {}
        for file_path in files_to_process:
            ext = file_path.suffix.lower()
            file_type_count[ext] = file_type_count.get(ext, 0) + 1
        for ext, count in sorted(file_type_count.items()):
            self.logger.info(f"  {ext}: {count} files")
        # Create progress tracking
        processed_count = 0
        failed_files = []

        # Use a semaphore to cap how many files are processed concurrently
        semaphore = asyncio.Semaphore(max_workers)
        async def process_single_file(file_path: Path, index: int) -> None:
            """Process a single file under the concurrency semaphore"""
            nonlocal processed_count
            async with semaphore:
                try:
                    self.logger.info(
                        f"[{index}/{len(files_to_process)}] Processing: {file_path}"
                    )

                    # Create a separate output directory for each file
                    file_output_dir = Path(output_dir) / file_path.stem
                    file_output_dir.mkdir(parents=True, exist_ok=True)

                    # Process the file
                    await self.process_document_complete(
                        file_path=str(file_path),
                        output_dir=str(file_output_dir),
                        parse_method=parse_method,
                        display_stats=display_stats,
                        split_by_character=split_by_character,
                        split_by_character_only=split_by_character_only,
                    )

                    processed_count += 1
                    self.logger.info(
                        f"[{index}/{len(files_to_process)}] Successfully processed: {file_path}"
                    )
                except Exception as e:
                    self.logger.error(
                        f"[{index}/{len(files_to_process)}] Failed to process: {file_path}"
                    )
                    self.logger.error(f"Error: {e}")
                    failed_files.append((file_path, str(e)))
        # Create all processing tasks; return_exceptions=True ensures one
        # unexpected failure does not cancel the remaining tasks
        tasks = [
            process_single_file(file_path, index)
            for index, file_path in enumerate(files_to_process, 1)
        ]
        await asyncio.gather(*tasks, return_exceptions=True)
        # Output processing statistics
        self.logger.info("\n===== Batch Processing Complete =====")
        self.logger.info(f"Total files: {len(files_to_process)}")
        self.logger.info(f"Successfully processed: {processed_count}")
        self.logger.info(f"Failed: {len(failed_files)}")
        if failed_files:
            self.logger.info("\nFailed files:")
            for file_path, error in failed_files:
                self.logger.info(f"  - {file_path}: {error}")

        return {
            "total": len(files_to_process),
            "success": processed_count,
            "failed": len(failed_files),
            "failed_files": failed_files,
        }
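

# Minimal usage sketch (illustrative only): assumes the RAGAnything class in
# this project composes BatchMixin and can be constructed here with default
# arguments; the folder path and extension list below are placeholders.
if __name__ == "__main__":

    async def _demo():
        from raganything import RAGAnything  # assumed import path

        rag = RAGAnything()
        stats = await rag.process_folder_complete(
            folder_path="./documents",
            file_extensions=[".pdf", ".docx", ".txt"],
            recursive=True,
            max_workers=2,
        )
        if stats:
            print(f"Processed {stats['success']}/{stats['total']} files")

    asyncio.run(_demo())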