瀏覽代碼

Copy-Paste augmentation for YOLOv5 (#3845)

* Copy-paste augmentation initial commit

* if any segments

* Add obscuration rejection

* Add copy_paste hyperparameter

* Update comments
modifyDataloader
Glenn Jocher GitHub 3 年之前
父節點
當前提交
c6c88dc601
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 7 個檔案被更改，包括 58 行新增和 21 行刪除
  1. +1
    -0
      data/hyps/hyp.finetune.yaml
  2. +1
    -0
      data/hyps/hyp.finetune_objects365.yaml
  3. +1
    -0
      data/hyps/hyp.scratch-p6.yaml
  4. +1
    -0
      data/hyps/hyp.scratch.yaml
  5. +3
    -2
      train.py
  6. +26
    -18
      utils/datasets.py
  7. +25
    -1
      utils/metrics.py

+ 1
- 0
data/hyps/hyp.finetune.yaml 查看文件

@@ -36,3 +36,4 @@ flipud: 0.00856
fliplr: 0.5
mosaic: 1.0
mixup: 0.243
copy_paste: 0.0

+ 1
- 0
data/hyps/hyp.finetune_objects365.yaml 查看文件

@@ -26,3 +26,4 @@ flipud: 0.0
fliplr: 0.5
mosaic: 1.0
mixup: 0.0
copy_paste: 0.0

+ 1
- 0
data/hyps/hyp.scratch-p6.yaml 查看文件

@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)

+ 1
- 0
data/hyps/hyp.scratch.yaml 查看文件

@@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)

+ 3
- 2
train.py 查看文件

@@ -6,7 +6,6 @@ Usage:

import argparse
import logging
import math
import os
import random
import sys
@@ -16,6 +15,7 @@ from copy import deepcopy
from pathlib import Path
from threading import Thread

import math
import numpy as np
import torch.distributed as dist
import torch.nn as nn
@@ -591,7 +591,8 @@ def main(opt):
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)

with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict

+ 26
- 18
utils/datasets.py 查看文件

@@ -25,6 +25,7 @@ from tqdm import tqdm

from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
from utils.metrics import bbox_ioa
from utils.torch_utils import torch_distributed_zero_first

# Parameters
@@ -683,6 +684,7 @@ def load_mosaic(self, index):
# img4, labels4 = replicate(img4, labels4) # replicate

# Augment
img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
img4, labels4 = random_perspective(img4, labels4, segments4,
degrees=self.hyp['degrees'],
translate=self.hyp['translate'],
@@ -907,6 +909,30 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s
return img, targets


def copy_paste(img, labels, segments, probability=0.5):
    """Apply Copy-Paste augmentation (https://arxiv.org/abs/2012.07177) to an image.

    A random subset of the labelled segments is mirrored left-right and pasted back
    onto the image, provided the mirrored box does not obscure existing labels too much.

    Args:
        img: image array indexed (height, width, ...); modified in place.
        labels: nx5 np.array of (cls, x1, y1, x2, y2) in pixel coordinates.
        segments: list of n arrays whose columns are (x, y); mirrored copies of the
            pasted segments are appended to this list in place.
        probability: fraction of the n segments sampled as paste candidates; 0 disables.

    Returns:
        Tuple (img, labels, segments) including any pasted objects.
    """
    n = len(segments)
    if probability and n:
        w = img.shape[1]  # image width, used to mirror x-coordinates
        im_new = np.zeros(img.shape, np.uint8)  # mask of the source objects selected for pasting
        for j in random.sample(range(n), k=round(probability * n)):
            label, s = labels[j], segments[j]
            box = w - label[3], label[2], w - label[1], label[4]  # left-right mirrored box
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[label[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        # Select the pasted pixels from the mirrored contour mask itself instead of
        # testing `pixel value > 0`: the value test skipped object pixels that are
        # zero in a channel, leaving holes / colour bleed in the pasted object.
        flipped = cv2.flip(img, 1)  # augment segments (flip left-right)
        mask = cv2.flip(im_new, 1).astype(bool)  # pixels to replace
        img[mask] = flipped[mask]

    return img, labels, segments


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
@@ -919,24 +945,6 @@ def cutout(image, labels):
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552
h, w = image.shape[:2]

def bbox_ioa(box1, box2):
    # Intersection of box1 with each box2 row, divided by that row's area.
    # box1 is 4, box2 is nx4, all boxes are x1y1x2y2
    box2 = box2.transpose()  # rows become the x1, y1, x2, y2 coordinate vectors

    # Corner coordinates
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    # Overlap extent along each axis, clipped to zero when the boxes are disjoint
    overlap_w = (np.minimum(ax2, bx2) - np.maximum(ax1, bx1)).clip(0)
    overlap_h = (np.minimum(ay2, by2) - np.maximum(ay1, by1)).clip(0)
    inter_area = overlap_w * overlap_h

    # Small constant keeps the division finite for zero-area boxes
    box2_area = (bx2 - bx1) * (by2 - by1) + 1e-16

    return inter_area / box2_area

# create random masks
scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
for s in scales:

+ 25
- 1
utils/metrics.py 查看文件

@@ -1,9 +1,9 @@
# Model validation metrics

import math
import warnings
from pathlib import Path

import math
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -253,6 +253,30 @@ def box_iou(box1, box2):
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)


def bbox_ioa(box1, box2, eps=1E-7):
    """ Compute, for every box in box2, the fraction of its area covered by box1.

    Boxes are x1y1x2y2.
    box1: np.array of shape(4)
    box2: np.array of shape(nx4)
    eps: constant added to the box2 area so the division stays finite
    returns: np.array of shape(n)
    """

    box2 = box2.transpose()  # rows become the x1, y1, x2, y2 coordinate vectors

    # Edges of the overlap rectangle between box1 and each box2
    left = np.maximum(box1[0], box2[0])
    top = np.maximum(box1[1], box2[1])
    right = np.minimum(box1[2], box2[2])
    bottom = np.minimum(box1[3], box2[3])

    # Negative extents mean no overlap, so clip them to zero before multiplying
    inter_area = (right - left).clip(0) * (bottom - top).clip(0)

    # Area of each box2
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + eps

    return inter_area / box2_area


def wh_iou(wh1, wh2):
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
wh1 = wh1[:, None] # [N,1,2]

Loading…
取消
儲存