115 lines
4.4 KiB
Python
115 lines
4.4 KiB
Python
# pip install ultralytics
|
||
# Fastest option: YOLOv8-DocLayNet (already trained; just download the weights)
|
||
|
||
"""
|
||
# 直接浏览器或 wget 均可
|
||
wget https://huggingface.co/hantian/yolo-doclaynet/resolve/main/yolov8n-doclaynet.pt
|
||
|
||
下载后将文件放到你的工作目录或 ./weights/ 子目录即可 。
|
||
"""
|
||
import hashlib
|
||
import logging
|
||
import os
|
||
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
logger = logging.getLogger(__name__)

# Margin, in pixels, added on every side of a detected box before cropping.
EXTEND_PIXELS = 10
# Class id treated as a picture region in the model output
# (presumably "Picture" in the DocLayNet label set -- verify against model.names).
IMAGE_CLASS_ID = 6
# Class id for text regions; not referenced by the code in this part of the file.
TEXT_CLASS_ID = 9
# Default location of the YOLOv8-DocLayNet weights; override via ``model_path``.
DEFAULT_MODEL_PATH = r"D:\Model\yolov8n-doclaynet.pt"


def _crop_sort_key(file_name):
    """Order ``image_<n>.png`` names by numeric index; other names sort last."""
    index = file_name[len('image_'):-len('.png')]
    try:
        return (0, int(index), file_name)
    except ValueError:
        return (1, 0, file_name)


def _find_cached_crops(output_dir):
    """Return the ``image_*.png`` paths found in *output_dir*, in crop order."""
    names = [
        name for name in os.listdir(output_dir)
        if name.startswith('image_') and name.endswith('.png')
    ]
    # os.listdir order is arbitrary; sort numerically so image_2 precedes image_10.
    names.sort(key=_crop_sort_key)
    return [os.path.join(output_dir, name).replace('\\', '/') for name in names]


def _collect_image_regions(boxes, width, height):
    """Return expanded, clipped (x1, y1, x2, y2) picture boxes, top to bottom."""
    regions = []
    for box in boxes:
        if int(box.cls[0]) != IMAGE_CLASS_ID:
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        regions.append((
            max(0, x1 - EXTEND_PIXELS),
            max(0, y1 - EXTEND_PIXELS),
            min(width, x2 + EXTEND_PIXELS),
            min(height, y2 + EXTEND_PIXELS),
        ))
    # Sort by the top edge so crops are numbered in reading order.
    regions.sort(key=lambda region: region[1])
    return regions


def yoloCut(image_path, model_path=DEFAULT_MODEL_PATH):
    """Crop the picture regions out of a page image with YOLOv8-DocLayNet.

    Every detected box of class ``IMAGE_CLASS_ID`` is expanded by
    ``EXTEND_PIXELS``, clipped to the image, saved as
    ``extracted/<md5>/image_<n>.png`` (numbered top to bottom) and painted
    white in a copy of the page, which is saved as
    ``extracted/<md5>/processed_image.png``.

    ``<md5>`` is the MD5 of the *path string*, not of the file contents. If
    the output directory already holds ``processed_image.png`` and at least
    one ``image_*.png``, those files are returned without running the model,
    so a changed file at the same path yields the earlier results.

    Args:
        image_path: Path of the image to process.
        model_path: Path of the YOLO weights file. Defaults to
            ``DEFAULT_MODEL_PATH``.

    Returns:
        A tuple ``(output_dir, processed_image_path, img_list)`` where
        ``img_list`` holds the crop paths (forward slashes) in crop order.

    Raises:
        SystemExit: With code 1 when OpenCV cannot read ``image_path``.
    """
    digest = hashlib.md5(image_path.encode()).hexdigest()
    output_dir = f'extracted/{digest}'
    processed_image_path = os.path.join(output_dir, 'processed_image.png')

    # Reuse earlier output when it looks complete.
    if os.path.exists(output_dir) and os.path.exists(processed_image_path):
        img_list = _find_cached_crops(output_dir)
        if img_list:
            logger.info(f"检测到已存在的输出文件,直接返回: {output_dir}")
            return output_dir, processed_image_path, img_list

    # Validate the input before the (expensive) model load, and before any
    # directory is created for it.
    original_image = cv2.imread(image_path)
    if original_image is None:
        logger.error(f"无法读取图像文件: {image_path}")
        # Same effect as exit(1) without relying on the site-provided builtin.
        raise SystemExit(1)

    os.makedirs(output_dir, exist_ok=True)

    # A fresh model per call keeps concurrent callers from sharing one instance.
    model = YOLO(model_path)
    result = model(image_path)

    processed_image = original_image.copy()
    img_list = []

    boxes = getattr(result[0], 'boxes', None) if result else None
    if boxes is not None and len(boxes) > 0:
        height, width = original_image.shape[:2]
        regions = _collect_image_regions(boxes, width, height)
        for image_counter, (x1, y1, x2, y2) in enumerate(regions, start=1):
            save_path = os.path.join(output_dir, f"image_{image_counter}.png")
            save_path = save_path.replace('\\', '/')
            cv2.imwrite(save_path, original_image[y1:y2, x1:x2])
            img_list.append(save_path)
            # Blank the cropped area in the page copy.
            cv2.rectangle(processed_image, (x1, y1), (x2, y2), (255, 255, 255), -1)
    else:
        logger.info("未检测到任何区域")

    # Always write the page copy so the returned path refers to a real file.
    cv2.imwrite(processed_image_path, processed_image)
    return output_dir, processed_image_path, img_list
|