Files
dsProject/dsLightRag/Util/OCR_URL_3_YoloCut.py

106 lines
4.0 KiB
Python
Raw Normal View History

2025-08-14 15:45:08 +08:00
# pip install ultralytics
# 最快YOLOv8-docLayNet已训练好直接下权重
"""
# 直接浏览器或 wget 均可
wget https://huggingface.co/hantian/yolo-doclaynet/resolve/main/yolov8n-doclaynet.pt
下载后将文件放到你的工作目录或 ./weights/ 子目录即可
#命令
yolo task=detect mode=predict model=yolov8n-doclaynet.pt source=D:\dsWork\dsProject\dsLightRag\Test\split_images\segment_5.png imgsz=1280 save_txt
"""
import logging
from ultralytics import YOLO
import cv2
import os
# 直接获取模块专属日志器(无需重复配置)
logger = logging.getLogger(__name__)
def yolo_cut(image_path, md5, model_path=r"D:\Model\yolov8n-doclaynet.pt"):
    """Crop the picture and text regions that a YOLO layout model detects.

    The model is run on ``image_path``. Every detected box whose class ID is
    6 (file tag ``IMG``) or 9 (file tag ``TXT``) is clipped to the image
    bounds and written to ``extracted/{md5}/{stem}_{n}_{tag}.png``, where
    ``stem`` is the image file name up to its first dot and ``n`` is a
    running index over all crops saved by this call.

    Args:
        image_path: Path of the image to analyse. Forward and backward
            slashes are both accepted when deriving the output file stem.
        md5: Identifier used as the name of the output sub-directory.
        model_path: Path of the YOLO weights file. Defaults to the location
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        The output directory ``extracted/{md5}`` (relative path). It is
        created even when nothing is detected.

    Raises:
        SystemExit: With status 1 when OpenCV cannot read ``image_path``.
    """
    # Directory that receives the cropped images.
    output_dir = f'extracted/{md5}'
    os.makedirs(output_dir, exist_ok=True)

    # File stem: last path component (either separator) up to its first dot.
    file_stem = image_path.replace('\\', '/').split('/')[-1].split('.')[0]

    # Target classes: class ID -> tag used in the output file name.
    target_classes = {
        6: "IMG",  # geometric figure
        9: "TXT",  # text
    }
    # Per-class counters, used for the log output and the final summary.
    box_counters = dict.fromkeys(target_classes, 0)

    model = YOLO(model_path)
    result = model(image_path)

    original_image = cv2.imread(image_path)
    if original_image is None:
        logger.error(f"无法读取图像文件: {image_path}")
        # Same exception type and status as the former ``exit(1)``, but
        # raised directly: ``exit`` is injected by the ``site`` module for
        # interactive use and is not guaranteed to exist (python -S, frozen
        # applications).
        raise SystemExit(1)

    if not result:
        logger.info("未获取到有效的检测结果")
        return output_dir

    # ``boxes`` may be missing or None on some result objects; treat both
    # the same as "nothing detected" instead of failing on len(None).
    boxes = getattr(result[0], 'boxes', None)
    if boxes is None or len(boxes) == 0:
        logger.info("未检测到任何区域")
        return output_dir

    logger.info(f"开始筛选并截取类别ID={list(target_classes)}的区域...")
    img_height, img_width = original_image.shape[:2]
    saved_total = 0  # running index across all saved crops (used in file names)

    for box in boxes:
        class_id = int(box.cls[0])
        if class_id not in target_classes:
            continue

        # Bounding-box corners as integers, clamped to the image bounds.
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        confidence = box.conf[0].item()
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(img_width, x2)
        y2 = min(img_height, y2)

        # A degenerate box yields an empty crop, which cv2.imwrite rejects
        # with an error; skip it without consuming an index.
        if x2 <= x1 or y2 <= y1:
            logger.warning(
                f"跳过空的裁剪区域 (类别ID={class_id}): ({x1}, {y1}) - ({x2}, {y2})"
            )
            continue

        class_name = target_classes[class_id]
        box_counters[class_id] += 1
        current_count = box_counters[class_id]
        saved_total += 1

        # Crop the region and save it as {stem}_{index}_{tag}.png.
        cropped_region = original_image[y1:y2, x1:x2]
        save_path = os.path.join(
            output_dir, f"{file_stem}_{saved_total}_{class_name}.png"
        )
        cv2.imwrite(save_path, cropped_region)

        logger.info(f"{class_name}区域 {current_count}:")
        logger.info(f" 坐标: (左上角: ({x1}, {y1}), 右下角: ({x2}, {y2}))")
        logger.info(f" 尺寸: {x2 - x1}px × {y2 - y1}px")
        logger.info(f" 置信度: {confidence:.4f}")
        logger.info(f" 已保存至: {save_path}\n")

    # Per-class summary.
    for class_id, class_name in target_classes.items():
        if box_counters[class_id] == 0:
            logger.info(f"未检测到{class_name}区域 (类别ID={class_id})")
        else:
            logger.info(f"已保存至: {os.path.abspath(output_dir)}")
    logger.info(f"所有结果已保存至: {os.path.abspath(output_dir)}")

    return output_dir
# result[0].show()