* Albumentations integration * ToGray p=0.01 * print confirmation * create instance in dataloader init method * improved version handling * transform not defined fix * assert string update * create check_version() * add spaces * update class commentmodifyDataloader
@@ -27,4 +27,5 @@ pandas | |||
# extras -------------------------------------- | |||
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 | |||
# pycocotools>=2.0 # COCO mAP | |||
# albumentations>=1.0.0 | |||
thop # FLOPs computation |
@@ -1,15 +1,43 @@ | |||
# YOLOv5 image augmentation functions | |||
import logging | |||
import random | |||
import cv2 | |||
import math | |||
import numpy as np | |||
from utils.general import segment2box, resample_segments | |||
from utils.general import colorstr, segment2box, resample_segments, check_version | |||
from utils.metrics import bbox_ioa | |||
class Albumentations: | |||
# YOLOv5 Albumentations class (optional, only used if package is installed) | |||
def __init__(self): | |||
self.transform = None | |||
try: | |||
import albumentations as A | |||
check_version(A.__version__, '1.0.0') # version requirement | |||
self.transform = A.Compose([ | |||
A.Blur(p=0.1), | |||
A.MedianBlur(p=0.1), | |||
A.ToGray(p=0.01)], | |||
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])) | |||
logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms)) | |||
except ImportError: # package not installed, skip | |||
pass | |||
except Exception as e: | |||
logging.info(colorstr('albumentations: ') + f'{e}') | |||
def __call__(self, im, labels, p=1.0): | |||
if self.transform and random.random() < p: | |||
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed | |||
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]) | |||
return im, labels | |||
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): | |||
# HSV color-space augmentation | |||
if hgain or sgain or vgain: |
@@ -22,7 +22,7 @@ from PIL import Image, ExifTags | |||
from torch.utils.data import Dataset | |||
from tqdm import tqdm | |||
from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective | |||
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective | |||
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ | |||
xyn2xy, segments2boxes, clean_str | |||
from utils.torch_utils import torch_distributed_zero_first | |||
@@ -372,6 +372,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
self.mosaic_border = [-img_size // 2, -img_size // 2] | |||
self.stride = stride | |||
self.path = path | |||
self.albumentations = Albumentations() if augment else None | |||
try: | |||
f = [] # image files | |||
@@ -539,9 +540,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
if labels.size: # normalized xywh to pixel xyxy format | |||
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1]) | |||
if self.augment: | |||
# Augment imagespace | |||
if not mosaic: | |||
if self.augment: | |||
img, labels = random_perspective(img, labels, | |||
degrees=hyp['degrees'], | |||
translate=hyp['translate'], | |||
@@ -549,32 +548,35 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
shear=hyp['shear'], | |||
perspective=hyp['perspective']) | |||
# Augment colorspace | |||
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v']) | |||
# Apply cutouts | |||
# if random.random() < 0.9: | |||
# labels = cutout(img, labels) | |||
nL = len(labels) # number of labels | |||
if nL: | |||
nl = len(labels) # number of labels | |||
if nl: | |||
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized | |||
if self.augment: | |||
# flip up-down | |||
# Albumentations | |||
img, labels = self.albumentations(img, labels) | |||
# HSV color-space | |||
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v']) | |||
# Flip up-down | |||
if random.random() < hyp['flipud']: | |||
img = np.flipud(img) | |||
if nL: | |||
if nl: | |||
labels[:, 2] = 1 - labels[:, 2] | |||
# flip left-right | |||
# Flip left-right | |||
if random.random() < hyp['fliplr']: | |||
img = np.fliplr(img) | |||
if nL: | |||
if nl: | |||
labels[:, 1] = 1 - labels[:, 1] | |||
labels_out = torch.zeros((nL, 6)) | |||
if nL: | |||
# Cutouts | |||
# if random.random() < 0.9: | |||
# labels = cutout(img, labels) | |||
labels_out = torch.zeros((nl, 6)) | |||
if nl: | |||
labels_out[:, 1:] = torch.from_numpy(labels) | |||
# Convert |
@@ -3,7 +3,6 @@ | |||
import contextlib | |||
import glob | |||
import logging | |||
import math | |||
import os | |||
import platform | |||
import random | |||
@@ -17,6 +16,7 @@ from pathlib import Path | |||
from subprocess import check_output | |||
import cv2 | |||
import math | |||
import numpy as np | |||
import pandas as pd | |||
import pkg_resources as pkg | |||
@@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y | |||
print(f'{e}{err_msg}') | |||
def check_python(minimum='3.6.2', required=True): | |||
def check_python(minimum='3.6.2'): | |||
# Check current python version vs. required python version | |||
current = platform.python_version() | |||
result = pkg.parse_version(current) >= pkg.parse_version(minimum) | |||
if required: | |||
assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed' | |||
return result | |||
check_version(platform.python_version(), minimum, name='Python ') | |||
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False): | |||
# Check version vs. required version | |||
current, minimum = (pkg.parse_version(x) for x in (current, minimum)) | |||
result = (current == minimum) if pinned else (current >= minimum) | |||
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' | |||
def check_requirements(requirements='requirements.txt', exclude=()): |