|
|
@@ -20,7 +20,8 @@ from PIL import Image, ExifTags |
|
|
|
from torch.utils.data import Dataset |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, clean_str |
|
|
|
from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, resample_segments, \ |
|
|
|
clean_str |
|
|
|
from utils.torch_utils import torch_distributed_zero_first |
|
|
|
|
|
|
|
# Parameters |
|
|
@@ -374,21 +375,23 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
self.label_files = img2label_paths(self.img_files) # labels |
|
|
|
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels |
|
|
|
if cache_path.is_file(): |
|
|
|
cache = torch.load(cache_path) # load |
|
|
|
if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache: # changed |
|
|
|
cache = self.cache_labels(cache_path, prefix) # re-cache |
|
|
|
cache, exists = torch.load(cache_path), True # load |
|
|
|
if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache: # changed |
|
|
|
cache, exists = self.cache_labels(cache_path, prefix), False # re-cache |
|
|
|
else: |
|
|
|
cache = self.cache_labels(cache_path, prefix) # cache |
|
|
|
cache, exists = self.cache_labels(cache_path, prefix), False # cache |
|
|
|
|
|
|
|
# Display cache |
|
|
|
[nf, nm, ne, nc, n] = cache.pop('results') # found, missing, empty, corrupted, total |
|
|
|
desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" |
|
|
|
tqdm(None, desc=prefix + desc, total=n, initial=n) |
|
|
|
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total |
|
|
|
if exists: |
|
|
|
d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" |
|
|
|
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results |
|
|
|
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}' |
|
|
|
|
|
|
|
# Read cache |
|
|
|
cache.pop('hash') # remove hash |
|
|
|
labels, shapes = zip(*cache.values()) |
|
|
|
cache.pop('version') # remove version |
|
|
|
labels, shapes, self.segments = zip(*cache.values()) |
|
|
|
self.labels = list(labels) |
|
|
|
self.shapes = np.array(shapes, dtype=np.float64) |
|
|
|
self.img_files = list(cache.keys()) # update |
|
|
@@ -451,6 +454,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
im = Image.open(im_file) |
|
|
|
im.verify() # PIL verify |
|
|
|
shape = exif_size(im) # image size |
|
|
|
segments = [] # instance segments |
|
|
|
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' |
|
|
|
assert im.format.lower() in img_formats, f'invalid image format {im.format}' |
|
|
|
|
|
|
@@ -458,7 +462,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
if os.path.isfile(lb_file): |
|
|
|
nf += 1 # label found |
|
|
|
with open(lb_file, 'r') as f: |
|
|
|
l = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels |
|
|
|
l = [x.split() for x in f.read().strip().splitlines()] |
|
|
|
if any([len(x) > 8 for x in l]): # is segment |
|
|
|
classes = np.array([x[0] for x in l], dtype=np.float32) |
|
|
|
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...) |
|
|
|
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) |
|
|
|
l = np.array(l, dtype=np.float32) |
|
|
|
if len(l): |
|
|
|
assert l.shape[1] == 5, 'labels require 5 columns each' |
|
|
|
assert (l >= 0).all(), 'negative labels' |
|
|
@@ -470,7 +479,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
else: |
|
|
|
nm += 1 # label missing |
|
|
|
l = np.zeros((0, 5), dtype=np.float32) |
|
|
|
x[im_file] = [l, shape] |
|
|
|
x[im_file] = [l, shape, segments] |
|
|
|
except Exception as e: |
|
|
|
nc += 1 |
|
|
|
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') |
|
|
@@ -482,7 +491,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
print(f'{prefix}WARNING: No labels found in {path}. See {help_url}') |
|
|
|
|
|
|
|
x['hash'] = get_hash(self.label_files + self.img_files) |
|
|
|
x['results'] = [nf, nm, ne, nc, i + 1] |
|
|
|
x['results'] = nf, nm, ne, nc, i + 1 |
|
|
|
x['version'] = 0.1 # cache version |
|
|
|
torch.save(x, path) # save for next time |
|
|
|
logging.info(f'{prefix}New cache created: {path}') |
|
|
|
return x |
|
|
@@ -652,7 +662,7 @@ def hist_equalize(img, clahe=True, bgr=False): |
|
|
|
def load_mosaic(self, index): |
|
|
|
# loads images in a 4-mosaic |
|
|
|
|
|
|
|
labels4 = [] |
|
|
|
labels4, segments4 = [], [] |
|
|
|
s = self.img_size |
|
|
|
yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y |
|
|
|
indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)] # 3 additional image indices |
|
|
@@ -680,19 +690,21 @@ def load_mosaic(self, index): |
|
|
|
padh = y1a - y1b |
|
|
|
|
|
|
|
# Labels |
|
|
|
labels = self.labels[index].copy() |
|
|
|
labels, segments = self.labels[index].copy(), self.segments[index].copy() |
|
|
|
if labels.size: |
|
|
|
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format |
|
|
|
segments = [xyn2xy(x, w, h, padw, padh) for x in segments] |
|
|
|
labels4.append(labels) |
|
|
|
segments4.extend(segments) |
|
|
|
|
|
|
|
# Concat/clip labels |
|
|
|
if len(labels4): |
|
|
|
labels4 = np.concatenate(labels4, 0) |
|
|
|
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_perspective |
|
|
|
# img4, labels4 = replicate(img4, labels4) # replicate |
|
|
|
labels4 = np.concatenate(labels4, 0) |
|
|
|
for x in (labels4[:, 1:], *segments4): |
|
|
|
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective() |
|
|
|
# img4, labels4 = replicate(img4, labels4) # replicate |
|
|
|
|
|
|
|
# Augment |
|
|
|
img4, labels4 = random_perspective(img4, labels4, |
|
|
|
img4, labels4 = random_perspective(img4, labels4, segments4, |
|
|
|
degrees=self.hyp['degrees'], |
|
|
|
translate=self.hyp['translate'], |
|
|
|
scale=self.hyp['scale'], |
|
|
@@ -706,7 +718,7 @@ def load_mosaic(self, index): |
|
|
|
def load_mosaic9(self, index): |
|
|
|
# loads images in a 9-mosaic |
|
|
|
|
|
|
|
labels9 = [] |
|
|
|
labels9, segments9 = [], [] |
|
|
|
s = self.img_size |
|
|
|
indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)] # 8 additional image indices |
|
|
|
for i, index in enumerate(indices): |
|
|
@@ -739,30 +751,34 @@ def load_mosaic9(self, index): |
|
|
|
x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords |
|
|
|
|
|
|
|
# Labels |
|
|
|
labels = self.labels[index].copy() |
|
|
|
labels, segments = self.labels[index].copy(), self.segments[index].copy() |
|
|
|
if labels.size: |
|
|
|
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format |
|
|
|
segments = [xyn2xy(x, w, h, padx, pady) for x in segments] |
|
|
|
labels9.append(labels) |
|
|
|
segments9.extend(segments) |
|
|
|
|
|
|
|
# Image |
|
|
|
img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax] |
|
|
|
hp, wp = h, w # height, width previous |
|
|
|
|
|
|
|
# Offset |
|
|
|
yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border] # mosaic center x, y |
|
|
|
yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border] # mosaic center x, y |
|
|
|
img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s] |
|
|
|
|
|
|
|
# Concat/clip labels |
|
|
|
if len(labels9): |
|
|
|
labels9 = np.concatenate(labels9, 0) |
|
|
|
labels9[:, [1, 3]] -= xc |
|
|
|
labels9[:, [2, 4]] -= yc |
|
|
|
labels9 = np.concatenate(labels9, 0) |
|
|
|
labels9[:, [1, 3]] -= xc |
|
|
|
labels9[:, [2, 4]] -= yc |
|
|
|
c = np.array([xc, yc]) # centers |
|
|
|
segments9 = [x - c for x in segments9] |
|
|
|
|
|
|
|
np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:]) # use with random_perspective |
|
|
|
# img9, labels9 = replicate(img9, labels9) # replicate |
|
|
|
for x in (labels9[:, 1:], *segments9): |
|
|
|
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective() |
|
|
|
# img9, labels9 = replicate(img9, labels9) # replicate |
|
|
|
|
|
|
|
# Augment |
|
|
|
img9, labels9 = random_perspective(img9, labels9, |
|
|
|
img9, labels9 = random_perspective(img9, labels9, segments9, |
|
|
|
degrees=self.hyp['degrees'], |
|
|
|
translate=self.hyp['translate'], |
|
|
|
scale=self.hyp['scale'], |
|
|
@@ -823,7 +839,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale |
|
|
|
return img, ratio, (dw, dh) |
|
|
|
|
|
|
|
|
|
|
|
def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)): |
|
|
|
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, |
|
|
|
border=(0, 0)): |
|
|
|
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) |
|
|
|
# targets = [cls, xyxy] |
|
|
|
|
|
|
@@ -875,37 +892,38 @@ def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shea |
|
|
|
# Transform label coordinates |
|
|
|
n = len(targets) |
|
|
|
if n: |
|
|
|
# warp points |
|
|
|
xy = np.ones((n * 4, 3)) |
|
|
|
xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 |
|
|
|
xy = xy @ M.T # transform |
|
|
|
if perspective: |
|
|
|
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale |
|
|
|
else: # affine |
|
|
|
xy = xy[:, :2].reshape(n, 8) |
|
|
|
|
|
|
|
# create new boxes |
|
|
|
x = xy[:, [0, 2, 4, 6]] |
|
|
|
y = xy[:, [1, 3, 5, 7]] |
|
|
|
xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T |
|
|
|
|
|
|
|
# # apply angle-based reduction of bounding boxes |
|
|
|
# radians = a * math.pi / 180 |
|
|
|
# reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5 |
|
|
|
# x = (xy[:, 2] + xy[:, 0]) / 2 |
|
|
|
# y = (xy[:, 3] + xy[:, 1]) / 2 |
|
|
|
# w = (xy[:, 2] - xy[:, 0]) * reduction |
|
|
|
# h = (xy[:, 3] - xy[:, 1]) * reduction |
|
|
|
# xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T |
|
|
|
|
|
|
|
# clip boxes |
|
|
|
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) |
|
|
|
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) |
|
|
|
use_segments = any(x.any() for x in segments) |
|
|
|
new = np.zeros((n, 4)) |
|
|
|
if use_segments: # warp segments |
|
|
|
segments = resample_segments(segments) # upsample |
|
|
|
for i, segment in enumerate(segments): |
|
|
|
xy = np.ones((len(segment), 3)) |
|
|
|
xy[:, :2] = segment |
|
|
|
xy = xy @ M.T # transform |
|
|
|
xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine |
|
|
|
|
|
|
|
# clip |
|
|
|
new[i] = segment2box(xy, width, height) |
|
|
|
|
|
|
|
else: # warp boxes |
|
|
|
xy = np.ones((n * 4, 3)) |
|
|
|
xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 |
|
|
|
xy = xy @ M.T # transform |
|
|
|
xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine |
|
|
|
|
|
|
|
# create new boxes |
|
|
|
x = xy[:, [0, 2, 4, 6]] |
|
|
|
y = xy[:, [1, 3, 5, 7]] |
|
|
|
new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T |
|
|
|
|
|
|
|
# clip |
|
|
|
new[:, [0, 2]] = new[:, [0, 2]].clip(0, width) |
|
|
|
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height) |
|
|
|
|
|
|
|
# filter candidates |
|
|
|
i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T) |
|
|
|
i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10) |
|
|
|
targets = targets[i] |
|
|
|
targets[:, 1:5] = xy[i] |
|
|
|
targets[:, 1:5] = new[i] |
|
|
|
|
|
|
|
return img, targets |
|
|
|
|