Browse Source

YOLOv5 + Albumentations integration (#3882)

* Albumentations integration

* ToGray p=0.01

* print confirmation

* create instance in dataloader init method

* improved version handling

* transform not defined fix

* assert string update

* create check_version()

* add spaces

* update class comment
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
33202b7f0b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 27 deletions
  1. +1
    -0
      requirements.txt
  2. +29
    -1
      utils/augmentations.py
  3. +21
    -19
      utils/datasets.py
  4. +10
    -7
      utils/general.py

+ 1
- 0
requirements.txt View File

@@ -27,4 +27,5 @@ pandas
# extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
# pycocotools>=2.0 # COCO mAP
# albumentations>=1.0.0
thop # FLOPs computation

+ 29
- 1
utils/augmentations.py View File

@@ -1,15 +1,43 @@
# YOLOv5 image augmentation functions

import logging
import random

import cv2
import math
import numpy as np

from utils.general import segment2box, resample_segments
from utils.general import colorstr, segment2box, resample_segments, check_version
from utils.metrics import bbox_ioa


class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self):
self.transform = None
try:
import albumentations as A
check_version(A.__version__, '1.0.0') # version requirement

self.transform = A.Compose([
A.Blur(p=0.1),
A.MedianBlur(p=0.1),
A.ToGray(p=0.01)],
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
except ImportError: # package not installed, skip
pass
except Exception as e:
logging.info(colorstr('albumentations: ') + f'{e}')

def __call__(self, im, labels, p=1.0):
if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
return im, labels


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation
if hgain or sgain or vgain:

+ 21
- 19
utils/datasets.py View File

@@ -22,7 +22,7 @@ from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segments2boxes, clean_str
from utils.torch_utils import torch_distributed_zero_first
@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
self.albumentations = Albumentations() if augment else None

try:
f = [] # image files
@@ -539,9 +540,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
if labels.size: # normalized xywh to pixel xyxy format
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

if self.augment:
# Augment imagespace
if not mosaic:
if self.augment:
img, labels = random_perspective(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing
shear=hyp['shear'],
perspective=hyp['perspective'])

# Augment colorspace
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Apply cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

nL = len(labels) # number of labels
if nL:
nl = len(labels) # number of labels
if nl:
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized

if self.augment:
# flip up-down
# Albumentations
img, labels = self.albumentations(img, labels)

# HSV color-space
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Flip up-down
if random.random() < hyp['flipud']:
img = np.flipud(img)
if nL:
if nl:
labels[:, 2] = 1 - labels[:, 2]

# flip left-right
# Flip left-right
if random.random() < hyp['fliplr']:
img = np.fliplr(img)
if nL:
if nl:
labels[:, 1] = 1 - labels[:, 1]

labels_out = torch.zeros((nL, 6))
if nL:
# Cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

labels_out = torch.zeros((nl, 6))
if nl:
labels_out[:, 1:] = torch.from_numpy(labels)

# Convert

+ 10
- 7
utils/general.py View File

@@ -3,7 +3,6 @@
import contextlib
import glob
import logging
import math
import os
import platform
import random
@@ -17,6 +16,7 @@ from pathlib import Path
from subprocess import check_output

import cv2
import math
import numpy as np
import pandas as pd
import pkg_resources as pkg
@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
print(f'{e}{err_msg}')


def check_python(minimum='3.6.2', required=True):
def check_python(minimum='3.6.2'):
# Check current python version vs. required python version
current = platform.python_version()
result = pkg.parse_version(current) >= pkg.parse_version(minimum)
if required:
assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
return result
check_version(platform.python_version(), minimum, name='Python ')


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
# Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum)
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'


def check_requirements(requirements='requirements.txt', exclude=()):

Loading…
Cancel
Save