106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
# pip install ultralytics
|
||
# 最快:YOLOv8-docLayNet(已训练好,直接下权重)
|
||
|
||
"""
|
||
# 直接浏览器或 wget 均可
|
||
wget https://huggingface.co/hantian/yolo-doclaynet/resolve/main/yolov8n-doclaynet.pt
|
||
|
||
下载后将文件放到你的工作目录或 ./weights/ 子目录即可。
|
||
|
||
#命令
|
||
yolo task=detect mode=predict model=yolov8n-doclaynet.pt source=D:/dsWork/dsProject/dsLightRag/Test/split_images/segment_5.png imgsz=1280 save_txt
|
||
"""
|
||
import logging
|
||
|
||
from ultralytics import YOLO
|
||
import cv2
|
||
import os
|
||
|
||
# 直接获取模块专属日志器(无需重复配置)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def yolo_cut(image_path, md5, model_path=r"D:\Model\yolov8n-doclaynet.pt"):
    """Detect layout regions in an image and save each target region as a PNG.

    The YOLO weights at ``model_path`` are run on ``image_path``. Detections
    whose class ID is 6 (labelled "IMG") or 9 (labelled "TXT") are clamped to
    the image bounds, cropped from the original image and written to
    ``extracted/<md5>/<stem>_<n>_<label>.png``, where ``<stem>`` is the file
    name up to its first dot and ``<n>`` is a running index shared by all
    saved crops.

    Args:
        image_path: Path of the image to analyse, given as a string.
        md5: Identifier used, exactly as passed, as the output sub-directory
            name.
        model_path: Weights file handed to ``YOLO``. Defaults to the location
            that used to be hard-coded in this function.

    Returns:
        The output directory ``extracted/<md5>``. It is created before
        detection runs, so it exists even when nothing is detected.

    Raises:
        SystemExit: With exit code 1 when OpenCV cannot read ``image_path``.
    """
    # One output folder per document, keyed by the caller-supplied md5.
    output_dir = f'extracted/{md5}'
    os.makedirs(output_dir, exist_ok=True)

    # File name up to its first dot; used as the prefix of every crop file.
    stem = image_path.replace('\\', '/').split('/')[-1].split('.')[0]

    # Class ID -> label used in file names and log lines.
    target_classes = {
        6: "IMG",  # figures
        9: "TXT",  # text
    }
    # Per-class count of regions handled so far.
    box_counters = dict.fromkeys(target_classes, 0)

    model = YOLO(model_path)
    result = model(image_path)

    original_image = cv2.imread(image_path)
    if original_image is None:
        logger.error(f"无法读取图像文件: {image_path}")
        # Same exception as the former exit(1), but does not rely on the
        # `exit` helper that only exists when the `site` module is loaded.
        raise SystemExit(1)

    if not result:
        logger.info("未获取到有效的检测结果")
        return output_dir

    # `boxes` may be missing or None; treat both like "nothing detected".
    boxes = getattr(result[0], 'boxes', None)
    if boxes is None or len(boxes) == 0:
        logger.info("未检测到任何区域")
        return output_dir

    logger.info(f"开始筛选并截取类别ID={list(target_classes.keys())}的区域...")
    image_height, image_width = original_image.shape[:2]
    cnt = 0  # running index across all target classes, used in file names

    for box in boxes:
        class_id = int(box.cls[0])
        if class_id not in target_classes:
            continue
        class_name = target_classes[class_id]

        # Box corners as integers, clamped so the slice stays inside the image.
        x1, y1, x2, y2 = (int(value) for value in box.xyxy[0].tolist())
        confidence = box.conf[0].item()
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image_width, x2)
        y2 = min(image_height, y2)

        # A box that collapses to zero width/height yields an empty crop,
        # which cv2.imwrite rejects with an exception; skip it instead.
        if x2 <= x1 or y2 <= y1:
            logger.warning(
                f"跳过空的{class_name}区域: ({x1}, {y1}) - ({x2}, {y2})"
            )
            continue

        box_counters[class_id] += 1
        current_count = box_counters[class_id]
        cnt += 1

        cropped_region = original_image[y1:y2, x1:x2]
        save_path = os.path.join(output_dir, f"{stem}_{cnt}_{class_name}.png")

        # cv2.imwrite reports failure through its return value; do not claim
        # the file was saved when it was not.
        if not cv2.imwrite(save_path, cropped_region):
            logger.error(f"保存失败: {save_path}")
            continue

        logger.info(f"{class_name}区域 {current_count}:")
        logger.info(f" 坐标: (左上角: ({x1}, {y1}), 右下角: ({x2}, {y2}))")
        logger.info(f" 尺寸: {x2 - x1}px × {y2 - y1}px")
        logger.info(f" 置信度: {confidence:.4f}")
        logger.info(f" 已保存至: {save_path}\n")

    # Per-class summary: report classes with no region, otherwise the folder.
    for class_id, class_name in target_classes.items():
        if box_counters[class_id] == 0:
            logger.info(f"未检测到{class_name}区域 (类别ID={class_id})")
        else:
            logger.info(f"已保存至: {os.path.abspath(output_dir)}")

    logger.info(f"所有结果已保存至: {os.path.abspath(output_dir)}")
    return output_dir
|
||
|
||
# result[0].show()
|