Update `probability` to `p` (#3980)

modifyDataloader
Glenn Jocher (GitHub), 3 years ago
commit b3dabdcc38
3 changed files with 37 additions and 39 deletions:

  1. models/common.py (+2, -2)
  2. utils/augmentations.py (+32, -33)
  3. utils/datasets.py (+3, -4)

models/common.py (+2, -2)

@@ -215,7 +215,7 @@ class NMS(nn.Module):


 class AutoShape(nn.Module):
-    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
     iou = 0.45  # NMS IoU threshold
     classes = None  # (optional list) filter by class
@@ -287,7 +287,7 @@ class AutoShape(nn.Module):


 class Detections:
-    # detections class for YOLOv5 inference results
+    # YOLOv5 detections class for inference results
     def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
         super(Detections, self).__init__()
         d = pred[0].device  # device

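The two comment changes above only rebrand the descriptive comments on the AutoShape wrapper and the Detections results class. For context, a minimal usage sketch of that wrapper via torch.hub; the 'yolov5s' weights name and the example image URL are illustrative, not part of this commit:

    # Hedged sketch: load a YOLOv5 model wrapped in AutoShape and run inference.
    import torch

    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # AutoShape-wrapped model
    model.conf = 0.25  # NMS confidence threshold (the class attribute shown above)
    model.iou = 0.45   # NMS IoU threshold

    results = model('https://ultralytics.com/images/zidane.jpg')  # returns a Detections object
    results.print()  # per-image detection summary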
utils/augmentations.py (+32, -33)

@@ -50,12 +50,12 @@ def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
     lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
     lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

-    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
-    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed
+    im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+    cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed


 def hist_equalize(im, clahe=True, bgr=False):
-    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
+    # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
     yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
     if clahe:
         c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
@@ -76,7 +76,7 @@ def replicate(im, labels):
         bh, bw = y2b - y1b, x2b - x1b
         yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
         x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
-        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
         labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

     return im, labels
@@ -162,8 +162,8 @@ def random_perspective(im, targets=(), segments=(), degrees=10, translate=.1, sc
     # Visualize
     # import matplotlib.pyplot as plt
     # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
-    # ax[0].imshow(img[:, :, ::-1])  # base
-    # ax[1].imshow(img2[:, :, ::-1])  # warped
+    # ax[0].imshow(im[:, :, ::-1])  # base
+    # ax[1].imshow(im2[:, :, ::-1])  # warped

     # Transform label coordinates
     n = len(targets)
@@ -204,13 +204,13 @@ def random_perspective(im, targets=(), segments=(), degrees=10, translate=.1, sc
     return im, targets


-def copy_paste(im, labels, segments, probability=0.5):
+def copy_paste(im, labels, segments, p=0.5):
     # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
     n = len(segments)
-    if probability and n:
+    if p and n:
         h, w, c = im.shape  # height, width, channels
         im_new = np.zeros(im.shape, np.uint8)
-        for j in random.sample(range(n), k=round(probability * n)):
+        for j in random.sample(range(n), k=round(p * n)):
             l, s = labels[j], segments[j]
             box = w - l[3], l[2], w - l[1], l[4]
             ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
@@ -223,35 +223,34 @@ def copy_paste(im, labels, segments, probability=0.5):
         result = cv2.flip(result, 1)  # augment segments (flip left-right)
         i = result > 0  # pixels to replace
         # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
-        im[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug
+        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

     return im, labels, segments


-def cutout(im, labels):
+def cutout(im, labels, p=0.5):
     # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
-    h, w = im.shape[:2]
-
-    # create random masks
-    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
-    for s in scales:
-        mask_h = random.randint(1, int(h * s))
-        mask_w = random.randint(1, int(w * s))
-
-        # box
-        xmin = max(0, random.randint(0, w) - mask_w // 2)
-        ymin = max(0, random.randint(0, h) - mask_h // 2)
-        xmax = min(w, xmin + mask_w)
-        ymax = min(h, ymin + mask_h)
-
-        # apply random color mask
-        im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
-
-        # return unobscured labels
-        if len(labels) and s > 0.03:
-            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
-            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
-            labels = labels[ioa < 0.60]  # remove >60% obscured labels
+    if random.random() < p:
+        h, w = im.shape[:2]
+        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
+        for s in scales:
+            mask_h = random.randint(1, int(h * s))  # create random masks
+            mask_w = random.randint(1, int(w * s))
+
+            # box
+            xmin = max(0, random.randint(0, w) - mask_w // 2)
+            ymin = max(0, random.randint(0, h) - mask_h // 2)
+            xmax = min(w, xmin + mask_w)
+            ymax = min(h, ymin + mask_h)
+
+            # apply random color mask
+            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
+
+            # return unobscured labels
+            if len(labels) and s > 0.03:
+                box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
+                ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
+                labels = labels[ioa < 0.60]  # remove >60% obscured labels

     return labels


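After this rename, copy_paste() and cutout() both expose their probability as p. A minimal sketch of the updated signatures on dummy data; the array shapes and values are illustrative, and it assumes the YOLOv5 repo root is on the import path:

    # Hedged sketch: call the renamed augmentation signatures on dummy inputs.
    import numpy as np

    from utils.augmentations import copy_paste, cutout

    im = np.zeros((640, 640, 3), dtype=np.uint8)  # dummy BGR image
    labels = np.array([[0, 100, 100, 200, 200]], dtype=np.float32)  # nx5: cls, x1, y1, x2, y2
    segments = [np.array([[100, 100], [200, 100], [200, 200]], dtype=np.float32)]  # one polygon per label

    im, labels, segments = copy_paste(im, labels, segments, p=0.5)  # was probability=0.5
    labels = cutout(im, labels, p=0.5)  # cutout now also takes p and self-gates on random.random() < p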
utils/datasets.py (+3, -4)

@@ -22,7 +22,7 @@ from PIL import Image, ExifTags
 from torch.utils.data import Dataset
 from tqdm import tqdm

-from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
+from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective, cutout
 from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
     xyn2xy, segments2boxes, clean_str
 from utils.torch_utils import torch_distributed_zero_first
@@ -572,8 +572,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
                     labels[:, 1] = 1 - labels[:, 1]

             # Cutouts
-            # if random.random() < 0.9:
-            #     labels = cutout(img, labels)
+            # labels = cutout(img, labels, p=0.5)

         labels_out = torch.zeros((nl, 6))
         if nl:
@@ -682,7 +681,7 @@ def load_mosaic(self, index):
     # img4, labels4 = replicate(img4, labels4)  # replicate

     # Augment
-    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
+    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
     img4, labels4 = random_perspective(img4, labels4, segments4,
                                        degrees=self.hyp['degrees'],
                                        translate=self.hyp['translate'],

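Beyond the keyword change at the copy_paste() call site, datasets.py now also imports cutout, so the commented-out Cutout hook can be enabled with the new argument. A short sketch of the probability semantics after the rename; the hyperparameter value below is an assumption, not the repo default:

    # Hedged sketch: how the renamed keyword is driven by the hyperparameter dict.
    hyp = {'copy_paste': 0.0}  # probability of segment Copy-Paste during mosaic assembly

    # In load_mosaic, p=0.0 makes copy_paste() a no-op thanks to its `if p and n:` guard:
    #   img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=hyp['copy_paste'])

    # The Cutout call in __getitem__ stays commented out; enabling it would now read:
    #   labels = cutout(img, labels, p=0.5)  # applied to roughly half of the images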