Files
dsProject/dsLightRag/Util/OCR_URL_3_YoloCut.py
2025-08-14 15:45:08 +08:00

106 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# pip install ultralytics
# Fastest route: a YOLOv8 model already trained on DocLayNet is published, so just download the weights.
r"""
# Download with a browser or with wget:
wget https://huggingface.co/hantian/yolo-doclaynet/resolve/main/yolov8n-doclaynet.pt
After downloading, put the file in your working directory or in a ./weights/ subdirectory.
# Command-line usage
yolo task=detect mode=predict model=yolov8n-doclaynet.pt source=D:\dsWork\dsProject\dsLightRag\Test\split_images\segment_5.png imgsz=1280 save_txt
"""
import logging
from ultralytics import YOLO
import cv2
import os
# Fetch this module's own logger directly (no need to configure logging again here).
logger = logging.getLogger(__name__)
def yolo_cut(image_path, md5, model_path=r"D:\Model\yolov8n-doclaynet.pt"):
    """Detect layout regions in an image with YOLO and save selected ones as PNG crops.

    The image is run through the YOLO model loaded from ``model_path``. Every
    detected box whose class ID is a key of ``TARGET_CLASSES`` is cropped out of
    the original image and written to
    ``extracted/<md5>/<stem>_<n>_<label>.png``, where ``<stem>`` is the image
    file name up to its first dot, ``<n>`` is a running index over all saved
    crops and ``<label>`` is the short class label (``IMG`` or ``TXT``).

    Args:
        image_path: Path of the image to analyse. Both ``/`` and ``\\`` are
            accepted as directory separators when deriving the file stem.
        md5: Identifier used as the name of the output sub-directory under
            ``extracted/``.
        model_path: Location of the YOLO weights file. Defaults to the path
            that used to be hard-coded in this function.

    Returns:
        The output directory ``extracted/<md5>`` as a relative path string.
        The directory is created even when nothing is detected.

    Raises:
        SystemExit: With status 1 when OpenCV cannot read ``image_path``.
    """
    # Directory that receives the cropped images.
    OUTPUT_DIR = f'extracted/{md5}'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # File name without its directory and without anything after the first dot.
    fNum = image_path.replace('\\', '/').split('/')[-1].split('.')[0]

    # Classes worth extracting: detector class ID -> short label used in file names.
    TARGET_CLASSES = {
        6: "IMG",  # geometric figure
        9: "TXT",  # text
    }
    # Number of crops saved so far, per class.
    box_counters = {class_id: 0 for class_id in TARGET_CLASSES}

    model = YOLO(model_path)
    result = model(image_path)

    original_image = cv2.imread(image_path)
    if original_image is None:
        logger.error(f"无法读取图像文件: {image_path}")
        # Same effect as the former exit(1), but does not rely on the `exit`
        # helper that the `site` module injects for interactive sessions.
        raise SystemExit(1)

    if not result:
        logger.info("未获取到有效的检测结果")
        return OUTPUT_DIR

    # `boxes` can be missing or None depending on the result type, so guard
    # both cases before asking for its length.
    boxes = getattr(result[0], 'boxes', None)
    if boxes is None or len(boxes) == 0:
        logger.info("未检测到任何区域")
        return OUTPUT_DIR

    logger.info(f"开始筛选并截取类别ID={list(TARGET_CLASSES.keys())}的区域...")
    img_height, img_width = original_image.shape[:2]
    cnt = 0  # running index over every crop saved, regardless of class
    for box in boxes:
        class_id = int(box.cls[0])
        if class_id not in TARGET_CLASSES:
            continue
        class_name = TARGET_CLASSES[class_id]

        # Bounding box corners, truncated to whole pixels.
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        confidence = box.conf[0].item()

        # Keep the box inside the image.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(img_width, x2)
        y2 = min(img_height, y2)

        # A zero-area box yields an empty array, which cv2.imwrite rejects
        # with an exception; skip it instead of aborting the whole run.
        if x2 <= x1 or y2 <= y1:
            logger.warning(f"跳过无效区域: ({x1}, {y1}) - ({x2}, {y2})")
            continue

        cropped_region = original_image[y1:y2, x1:x2]

        # File name: <stem>_<running index>_<class label>.png
        save_path = os.path.join(OUTPUT_DIR, f"{fNum}_{cnt + 1}_{class_name}.png")
        # cv2.imwrite reports failure through its return value; only count
        # and announce the crop once it is really on disk.
        if not cv2.imwrite(save_path, cropped_region):
            logger.error(f"保存失败: {save_path}")
            continue

        cnt += 1
        box_counters[class_id] += 1
        current_count = box_counters[class_id]

        logger.info(f"{class_name}区域 {current_count}:")
        logger.info(f" 坐标: (左上角: ({x1}, {y1}), 右下角: ({x2}, {y2}))")
        logger.info(f" 尺寸: {x2 - x1}px × {y2 - y1}px")
        logger.info(f" 置信度: {confidence:.4f}")
        logger.info(f" 已保存至: {save_path}\n")

    # Per-class summary.
    for class_id, class_name in TARGET_CLASSES.items():
        count = box_counters[class_id]
        if count == 0:
            logger.info(f"未检测到{class_name}区域 (类别ID={class_id})")
        else:
            logger.info(f"已保存至: {os.path.abspath(OUTPUT_DIR)}")
    logger.info(f"所有结果已保存至: {os.path.abspath(OUTPUT_DIR)}")
    return OUTPUT_DIR
# result[0].show()