Browse Source

Copy-Paste augmentation for YOLOv5 (#3845)

* Copy-paste augmentation initial commit

* if any segments

* Add obscuration rejection

* Add copy_paste hyperparameter

* Update comments
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
c6c88dc601
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 58 additions and 21 deletions
  1. +1
    -0
      data/hyps/hyp.finetune.yaml
  2. +1
    -0
      data/hyps/hyp.finetune_objects365.yaml
  3. +1
    -0
      data/hyps/hyp.scratch-p6.yaml
  4. +1
    -0
      data/hyps/hyp.scratch.yaml
  5. +3
    -2
      train.py
  6. +26
    -18
      utils/datasets.py
  7. +25
    -1
      utils/metrics.py

+ 1
- 0
data/hyps/hyp.finetune.yaml View File

fliplr: 0.5 fliplr: 0.5
mosaic: 1.0 mosaic: 1.0
mixup: 0.243 mixup: 0.243
copy_paste: 0.0

+ 1
- 0
data/hyps/hyp.finetune_objects365.yaml View File

fliplr: 0.5 fliplr: 0.5
mosaic: 1.0 mosaic: 1.0
mixup: 0.0 mixup: 0.0
copy_paste: 0.0

+ 1
- 0
data/hyps/hyp.scratch-p6.yaml View File

fliplr: 0.5 # image flip left-right (probability) fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability) mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)

+ 1
- 0
data/hyps/hyp.scratch.yaml View File

fliplr: 0.5 # image flip left-right (probability) fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability) mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)

+ 3
- 2
train.py View File



import argparse import argparse
import logging import logging
import math
import os import os
import random import random
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread


import math
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
'flipud': (1, 0.0, 1.0), # image flip up-down (probability) 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mosaic (probability) 'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)


with open(opt.hyp) as f: with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict hyp = yaml.safe_load(f) # load hyps dict

+ 26
- 18
utils/datasets.py View File



from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segment2box, segments2boxes, resample_segments, clean_str xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
from utils.metrics import bbox_ioa
from utils.torch_utils import torch_distributed_zero_first from utils.torch_utils import torch_distributed_zero_first


# Parameters # Parameters
# img4, labels4 = replicate(img4, labels4) # replicate # img4, labels4 = replicate(img4, labels4) # replicate


# Augment # Augment
img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
img4, labels4 = random_perspective(img4, labels4, segments4, img4, labels4 = random_perspective(img4, labels4, segments4,
degrees=self.hyp['degrees'], degrees=self.hyp['degrees'],
translate=self.hyp['translate'], translate=self.hyp['translate'],
return img, targets return img, targets




def copy_paste(img, labels, segments, probability=0.5):
    """Copy-Paste augmentation (https://arxiv.org/abs/2012.07177) using left-right mirrored instances.

    A random subset of the instances is mirrored horizontally and pasted back
    into the same image. A mirrored instance is only accepted when it covers
    less than 30% of the area of every box currently in ``labels``.

    Args:
        img: image array whose second axis is the width; modified in place.
        labels: nx5 np.array of (cls, x1, y1, x2, y2) rows in pixels.
        segments: list of n polygons, each an array with x in column 0 and
            y in column 1; accepted mirrored polygons are appended in place.
        probability: fraction of the n instances that are sampled as
            paste candidates; 0 disables the augmentation.

    Returns:
        Tuple ``(img, labels, segments)``; ``labels`` is a new array with one
        extra row per accepted instance, ``img`` and ``segments`` are the
        (mutated) input objects.
    """
    n = len(segments)
    if not (probability and n):
        return img, labels, segments

    w = img.shape[1]  # image width, needed to mirror x coordinates
    im_new = np.zeros(img.shape, np.uint8)  # filled contours of the accepted source instances
    k = min(n, max(0, round(probability * n)))  # clamp so random.sample never sees an invalid size
    pasted = False
    for j in random.sample(range(n), k=k):
        label, segment = labels[j], segments[j]
        box = w - label[3], label[2], w - label[1], label[4]  # mirrored xyxy box
        ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area of each current label
        if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
            labels = np.concatenate((labels, [[label[0], *box]]), 0)
            segments.append(np.concatenate((w - segment[:, 0:1], segment[:, 1:2]), 1))
            cv2.drawContours(im_new, [segment.astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
            pasted = True

    if pasted:
        # Select pixels with the mirrored contour mask rather than with `pixel > 0`,
        # so black pixels/channels inside an instance are pasted as well.
        flipped = cv2.flip(img, 1)  # augment segments (flip left-right)
        mask = cv2.flip(im_new, 1).astype(bool)  # pixels to replace
        img[mask] = flipped[mask]

    return img, labels, segments


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1] w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552 # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
h, w = image.shape[:2] h, w = image.shape[:2]


def bbox_ioa(box1, box2, eps=1e-16):
    """Return the intersection over box2 area for box1 against each box in box2.

    Boxes are in x1y1x2y2 format.

    Args:
        box1: single box, indexable with 4 entries.
        box2: array-like of shape nx4, one box per row.
        eps: small constant added to the box2 area to avoid division by zero.

    Returns:
        Array with n entries holding intersection(box1, box2[i]) / area(box2[i]).
    """
    box2 = np.asarray(box2).transpose()  # 4xn, so each coordinate is one row

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

    # Intersection area (negative overlaps are clipped to zero)
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area

# create random masks # create random masks
scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
for s in scales: for s in scales:

+ 25
- 1
utils/metrics.py View File

# Model validation metrics # Model validation metrics


import math
import warnings import warnings
from pathlib import Path from pathlib import Path


import math
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)




def bbox_ioa(box1, box2, eps=1E-7):
    """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
    box1:       np.array (or any indexable) of shape(4)
    box2:       np.array (or nested sequence) of shape(nx4)
    eps:        constant added to the box2 area so zero-area boxes do not divide by zero
    returns:    np.array of shape(n)
    """

    # Accept array-likes as well as ndarrays; one row per coordinate after the transpose
    cols = np.asarray(box2).transpose()

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = cols[0], cols[1], cols[2], cols[3]

    # Overlap extent along each axis, clipped at zero for disjoint boxes
    overlap_w = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0)
    overlap_h = (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # Intersection area
    inter_area = overlap_w * overlap_h

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area


def wh_iou(wh1, wh2): def wh_iou(wh1, wh2):
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2 # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
wh1 = wh1[:, None] # [N,1,2] wh1 = wh1[:, None] # [N,1,2]

Loading…
Cancel
Save