@@ -9,7 +9,7 @@ import random
 import shutil
 import time
 from itertools import repeat
-from multiprocessing.pool import ThreadPool
+from multiprocessing.pool import ThreadPool, Pool
 from pathlib import Path
 from threading import Thread
@@ -29,6 +29,7 @@ from utils.torch_utils import torch_distributed_zero_first
 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
 vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
+num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
 logger = logging.getLogger(__name__)
 
 # Get orientation exif tag
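
# The new num_threads constant caps the shared worker count at 8, however many cores the host has
# (note that os.cpu_count() can return None on unusual platforms). A minimal illustration:

import os

print(min(8, os.cpu_count()))  # e.g. 4 on a 4-core laptop, capped at 8 on a 32-core server
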
@@ -447,7 +448,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
+            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
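
# The image cache above stays on a ThreadPool (worker threads fill self.imgs inside the parent
# process), now sized by the shared num_threads constant instead of a hard-coded 8. A small,
# self-contained sketch of the same imap-plus-tqdm pattern, with a made-up slow_task standing in
# for load_image:

from multiprocessing.pool import ThreadPool
from time import sleep
from tqdm import tqdm

def slow_task(x):
    sleep(0.05)  # stand-in for I/O-bound work such as decoding an image
    return x * x

results = ThreadPool(4).imap(slow_task, range(8))  # lazy iterator, input order preserved
for r in tqdm(results, total=8):
    pass  # consume results as they complete
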
@@ -458,53 +459,24 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
         # Cache dataset labels, check images and read shapes
         x = {}  # dict
-        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
-        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
-        for i, (im_file, lb_file) in enumerate(pbar):
-            try:
-                # verify images
-                im = Image.open(im_file)
-                im.verify()  # PIL verify
-                shape = exif_size(im)  # image size
-                segments = []  # instance segments
-                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-                assert im.format.lower() in img_formats, f'invalid image format {im.format}'
-
-                # verify labels
-                if os.path.isfile(lb_file):
-                    nf += 1  # label found
-                    with open(lb_file, 'r') as f:
-                        l = [x.split() for x in f.read().strip().splitlines() if len(x)]
-                        if any([len(x) > 8 for x in l]):  # is segment
-                            classes = np.array([x[0] for x in l], dtype=np.float32)
-                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
-                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
-                        l = np.array(l, dtype=np.float32)
-                    if len(l):
-                        assert l.shape[1] == 5, 'labels require 5 columns each'
-                        assert (l >= 0).all(), 'negative labels'
-                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
-                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
-                    else:
-                        ne += 1  # label empty
-                        l = np.zeros((0, 5), dtype=np.float32)
-                else:
-                    nm += 1  # label missing
-                    l = np.zeros((0, 5), dtype=np.float32)
-                x[im_file] = [l, shape, segments]
-            except Exception as e:
-                nc += 1
-                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
-
-            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
-                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
+        with Pool(num_threads) as pool:
+            pbar = tqdm(pool.imap_unordered(verify_image_label,
+                                            zip(self.img_files, self.label_files, repeat(prefix))),
+                        desc=desc, total=len(self.img_files))
+            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
+                if im_file:
+                    x[im_file] = [l, shape, segments]
+                nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
+                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+        pbar.close()
 
         if nf == 0:
             logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
 
         x['hash'] = get_hash(self.label_files + self.img_files)
-        x['results'] = nf, nm, ne, nc, i + 1
+        x['results'] = nf, nm, ne, nc, len(self.img_files)
+        x['version'] = 0.2  # cache version
         try:
             torch.save(x, path)  # save cache for next time
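
# The rewritten cache_labels() fans the per-file checks out to a process Pool and folds each
# worker's counters back into running totals, updating the progress-bar description as results
# arrive in arbitrary order. A self-contained sketch of that imap_unordered-and-aggregate pattern,
# with a made-up check() standing in for verify_image_label:

from multiprocessing import Pool
from tqdm import tqdm

def check(n):
    # stand-in for the real image/label verification: returns (key_or_None, nf, nc) per item
    ok = n % 3 != 0
    return (n if ok else None), int(ok), int(not ok)

if __name__ == '__main__':
    nf, nc, cache = 0, 0, {}
    with Pool(2) as pool:
        pbar = tqdm(pool.imap_unordered(check, range(9)), total=9)
        for key, nf_f, nc_f in pbar:
            nf, nc = nf + nf_f, nc + nc_f
            if key is not None:
                cache[key] = True
            pbar.desc = f'{nf} ok, {nc} rejected'
    pbar.close()
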
@@ -1069,3 +1041,44 @@ def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
             with open(path / txt[i], 'a') as f:
                 f.write(str(img) + '\n')  # add image to txt file
+
+
+def verify_image_label(params):
+    # Verify one image-label pair
+    im_file, lb_file, prefix = params
+    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+    try:
+        # verify images
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        segments = []  # instance segments
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
+
+        # verify labels
+        if os.path.isfile(lb_file):
+            nf = 1  # label found
+            with open(lb_file, 'r') as f:
+                l = [x.split() for x in f.read().strip().splitlines() if len(x)]
+                if any([len(x) > 8 for x in l]):  # is segment
+                    classes = np.array([x[0] for x in l], dtype=np.float32)
+                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
+                    l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+                l = np.array(l, dtype=np.float32)
+            if len(l):
+                assert l.shape[1] == 5, 'labels require 5 columns each'
+                assert (l >= 0).all(), 'negative labels'
+                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
+                assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
+            else:
+                ne = 1  # label empty
+                l = np.zeros((0, 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            l = np.zeros((0, 5), dtype=np.float32)
+        return im_file, l, shape, segments, nm, nf, ne, nc
+    except Exception as e:
+        nc = 1
+        logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
+        return [None] * 4 + [nm, nf, ne, nc]
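
# verify_image_label() always returns an 8-tuple (im_file, labels, shape, segments, nm, nf, ne, nc),
# with im_file set to None when the pair is corrupt, so the caller can both build the cache dict and
# total the counters. A quick serial sanity check of a single pair, assuming this module's imports
# and hypothetical paths, might look like:

if __name__ == '__main__':
    im_file, l, shape, segments, nm, nf, ne, nc = verify_image_label(
        ('data/images/0001.jpg', 'data/labels/0001.txt', ''))  # hypothetical image/label pair
    print(f'found={nf} missing={nm} empty={ne} corrupt={nc} shape={shape}')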