Files
dsProject/dsLightRag/Util/GGB/GGB_1_CUT.py
2025-08-14 15:45:08 +08:00

115 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

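# pip install ultralytics
# Fastest route: a YOLOv8 model already trained on DocLayNet - just download the weights.
"""
# Either a browser or wget works:
wget https://huggingface.co/hantian/yolo-doclaynet/resolve/main/yolov8n-doclaynet.pt
After downloading, place the file in your working directory or in a ./weights/ subdirectory.
"""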
import hashlib
import logging
import os
import cv2
from ultralytics import YOLO
logger = logging.getLogger(__name__)
# Number of pixels by which each detected box is grown on every side before
# cropping (the result is clamped to the image bounds by the caller).
EXTEND_PIXELS = 10
# Class id that the detector's output uses for picture regions.
IMAGE_CLASS_ID = 6
# Class id that the detector's output uses for text regions.
TEXT_CLASS_ID = 9
# Crop picture regions out of a page image with YOLOv8-DocLayNet.
def _crop_sort_key(file_name):
    """Sort key placing ``image_<n>.png`` names in numeric order of ``n``.

    Names whose middle part is not an integer sort after all numbered crops,
    alphabetically, so they are still returned but never disturb the order of
    the numbered ones.
    """
    stem = file_name[len('image_'):-len('.png')]
    try:
        return 0, int(stem), file_name
    except ValueError:
        return 1, 0, file_name


def _list_cached_crops(output_dir):
    """Return the ``image_*.png`` paths found in *output_dir*, in crop order."""
    names = [
        name for name in os.listdir(output_dir)
        if name.startswith('image_') and name.endswith('.png')
    ]
    # os.listdir() yields entries in arbitrary order; sort so that a cache hit
    # returns the crops in the same image_1, image_2, ... order as a fresh run.
    names.sort(key=_crop_sort_key)
    return [os.path.join(output_dir, name).replace('\\', '/') for name in names]


def _collect_image_regions(boxes, width, height):
    """Return expanded, clamped ``(x1, y1, x2, y2)`` tuples for picture boxes.

    Only boxes whose class id equals ``IMAGE_CLASS_ID`` are kept. Each box is
    grown by ``EXTEND_PIXELS`` on every side and clamped to the image size.
    The regions are returned sorted top-to-bottom by their upper edge.
    """
    regions = []
    for box in boxes:
        if int(box.cls[0]) != IMAGE_CLASS_ID:
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        regions.append((
            max(0, x1 - EXTEND_PIXELS),
            max(0, y1 - EXTEND_PIXELS),
            min(width, x2 + EXTEND_PIXELS),
            min(height, y2 + EXTEND_PIXELS),
        ))
    regions.sort(key=lambda region: region[1])
    return regions


def yoloCut(image_path, model_path=r"D:\Model\yolov8n-doclaynet.pt"):
    """Extract picture regions from a page image and blank them out.

    Every detected picture region is saved as ``image_<n>.png`` (numbered
    top-to-bottom, starting at 1) under ``extracted/<md5 of image_path>``,
    and a copy of the page with those regions filled white is saved there as
    ``processed_image.png``.

    If a previous run already produced ``processed_image.png`` and at least
    one ``image_*.png`` in that directory, the existing files are returned
    without running the model again. Note that the directory name is derived
    from the path string, not from the file contents.

    Args:
        image_path: Path of the page image to process.
        model_path: Path of the YOLOv8-DocLayNet weight file to load.

    Returns:
        A tuple ``(output_dir, processed_image_path, img_list)`` where
        ``img_list`` holds the crop paths (forward slashes) in crop order.

    Raises:
        SystemExit: With code 1 when ``image_path`` cannot be read as an image.
    """
    md5 = hashlib.md5(image_path.encode()).hexdigest()
    output_dir = f'extracted/{md5}'
    processed_image_path = os.path.join(output_dir, 'processed_image.png')

    # Reuse earlier output. The existence of the processed image implies that
    # the directory exists as well.
    if os.path.exists(processed_image_path):
        img_list = _list_cached_crops(output_dir)
        if img_list:
            logger.info(f"检测到已存在的输出文件,直接返回: {output_dir}")
            return output_dir, processed_image_path, img_list

    # Validate the input before paying for model loading and inference, and
    # before creating an output directory for an image that cannot be used.
    original_image = cv2.imread(image_path)
    if original_image is None:
        logger.error(f"无法读取图像文件: {image_path}")
        # Same exception as exit(1), without relying on the site-provided builtin.
        raise SystemExit(1)

    os.makedirs(output_dir, exist_ok=True)
    model = YOLO(model_path)
    result = model(image_path)

    height, width = original_image.shape[:2]
    boxes = getattr(result[0], 'boxes', None) if result else None
    if boxes is None or len(boxes) == 0:
        logger.info("未检测到任何区域")
        image_regions = []
    else:
        image_regions = _collect_image_regions(boxes, width, height)

    processed_image = original_image.copy()
    img_list = []
    for image_counter, (left, top, right, bottom) in enumerate(image_regions, start=1):
        save_path = os.path.join(output_dir, f"image_{image_counter}.png").replace('\\', '/')
        cv2.imwrite(save_path, original_image[top:bottom, left:right])
        img_list.append(save_path)
        # Blank the region out in the page copy (thickness -1 = filled).
        cv2.rectangle(processed_image, (left, top), (right, bottom), (255, 255, 255), -1)

    # Always write the processed page so the returned path is valid even when
    # no picture region was found.
    cv2.imwrite(processed_image_path, processed_image)
    return output_dir, processed_image_path, img_list