@@ -26,6 +26,11 @@ for orientation in ExifTags.TAGS.keys():
         break
 
 
+def get_hash(files):
+    # Returns a single hash value of a list of files
+    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
+
+
 def exif_size(img):
     # Returns exif-corrected PIL size
     s = img.size  # (width, height)
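Note on get_hash(): the "hash" is just the total size in bytes of every existing file in the list, so adding, deleting, or changing the size of any image or label file invalidates the cache, while a same-size in-place edit (e.g. flipping a class index) goes undetected. A minimal sketch of the behaviour, with hypothetical file names:

    import os
    from pathlib import Path

    def get_hash(files):
        # Same definition as above: sum of sizes, missing files contribute 0
        return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

    files = ['a.txt', 'b.txt']  # hypothetical label files
    Path('a.txt').write_text('0 0.5 0.5 0.1 0.1\n')
    Path('b.txt').write_text('1 0.2 0.2 0.1 0.1\n')

    h0 = get_hash(files)
    Path('a.txt').write_text('0 0.5 0.5 0.1 0.1\n2 0.3 0.3 0.2 0.2\n')  # extra row
    assert get_hash(files) != h0  # size changed, so the cache would be rebuilt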
@@ -280,7 +285,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0):
         try:
-            f = []
+            f = []  # image files
             for p in path if isinstance(path, list) else [path]:
                 p = str(Path(p))  # os-agnostic
                 parent = str(Path(p).parent) + os.sep
@@ -292,7 +297,6 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     f += glob.iglob(p + os.sep + '*.*')
                 else:
                     raise Exception('%s does not exist' % p)
-            path = p  # *.npy dir
             self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
         except Exception as e:
             raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
@@ -314,20 +318,22 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.stride = stride
 
         # Define labels
-        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
-                            for x in self.img_files]
-
-        # Read image shapes (wh)
-        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
-        try:
-            with open(sp, 'r') as f:  # read existing shapefile
-                s = [x.split() for x in f.read().splitlines()]
-                assert len(s) == n, 'Shapefile out of sync'
-        except:
-            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
-            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)
+        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
+                            self.img_files]
+
+        # Check cache
+        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
+        if os.path.isfile(cache_path):
+            cache = torch.load(cache_path)  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
+                cache = self.cache_labels(cache_path)  # re-cache
+        else:
+            cache = self.cache_labels(cache_path)  # cache
 
-        self.shapes = np.array(s, dtype=np.float64)
+        # Get labels
+        labels, shapes = zip(*[cache[x] for x in self.img_files])
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.labels = list(labels)
 
         # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
         if self.rect:
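The cache produced by cache_labels() (added below) maps each image path to [labels, (width, height)], plus one extra 'hash' key; indexing it with self.img_files both restores a deterministic order and skips that key. A toy sketch of the unpacking, with made-up values:

    import numpy as np

    # Hypothetical cache contents: img path -> [labels (n, 5), (width, height)]
    cache = {
        'images/a.jpg': [np.zeros((2, 5), dtype=np.float32), (640, 480)],
        'images/b.jpg': [np.zeros((0, 5), dtype=np.float32), (800, 600)],
        'hash': 12345,  # ignored below because we index by file name
    }
    img_files = ['images/a.jpg', 'images/b.jpg']

    labels, shapes = zip(*[cache[x] for x in img_files])
    print(np.array(shapes, dtype=np.float64))  # [[640. 480.] [800. 600.]]
    print(labels[0].shape)                     # (2, 5): class, x, y, w, h per row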
@@ -337,6 +343,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             irect = ar.argsort()
             self.img_files = [self.img_files[i] for i in irect]
             self.label_files = [self.label_files[i] for i in irect]
+            self.labels = [self.labels[i] for i in irect]
             self.shapes = s[irect]  # wh
             ar = ar[irect]
 
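The one added line here re-applies the aspect-ratio permutation to self.labels: labels are now loaded before rectangular sorting, so omitting it would silently pair sorted images with unsorted labels. The pattern in isolation, with toy values:

    import numpy as np

    img_files = ['wide.jpg', 'tall.jpg', 'square.jpg']  # hypothetical files
    labels = ['l_wide', 'l_tall', 'l_square']           # stand-ins for (n, 5) arrays
    ar = np.array([0.5, 2.0, 1.0])                      # aspect ratios h / w

    irect = ar.argsort()                 # permutation that sorts by aspect ratio
    img_files = [img_files[i] for i in irect]
    labels = [labels[i] for i in irect]  # same permutation keeps pairs aligned
    print(img_files)  # ['wide.jpg', 'square.jpg', 'tall.jpg']
    print(labels)     # ['l_wide', 'l_square', 'l_tall']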
@@ -353,33 +360,11 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 
         # Cache labels
-        self.imgs = [None] * n
-        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
-        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
-        if os.path.isfile(np_labels_path):
-            s = np_labels_path  # print string
-            x = np.load(np_labels_path, allow_pickle=True)
-            if len(x) == n:
-                self.labels = x
-                labels_loaded = True
-        else:
-            s = path.replace('images', 'labels')
-
         pbar = tqdm(self.label_files)
         for i, file in enumerate(pbar):
-            if labels_loaded:
-                l = self.labels[i]
-                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
-            else:
-                try:
-                    with open(file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
-                except:
-                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
-                    continue
-
+            l = self.labels[i]  # label
             if l.shape[0]:
                 assert l.shape[1] == 5, '> 5 label columns: %s' % file
                 assert (l >= 0).all(), 'negative labels: %s' % file
@@ -425,15 +410,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove
 
-            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
-                s, nf, nm, ne, nd, n)
-        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
-        if not labels_loaded and n > 1000:
-            print('Saving labels to %s for faster future loading' % np_labels_path)
-            np.save(np_labels_path, self.labels)  # save for next time
+            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
+                cache_path, nf, nm, ne, nd, n)
+        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
 
         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        if cache_images:  # if training
+        self.imgs = [None] * n
+        if cache_images:
             gb = 0  # Gigabytes of cached images
             pbar = tqdm(range(len(self.img_files)), desc='Caching images')
             self.img_hw0, self.img_hw = [None] * n, [None] * n
@@ -442,15 +425,30 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 gb += self.imgs[i].nbytes
                 pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
 
-        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
-        detect_corrupted_images = False
-        if detect_corrupted_images:
-            from skimage import io  # conda install -c conda-forge scikit-image
-            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
-                try:
-                    _ = io.imread(file)
-                except:
-                    print('Corrupted image detected: %s' % file)
+    def cache_labels(self, path='labels.cache'):
+        # Cache dataset labels, check images and read shapes
+        x = {}  # dict
+        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
+        for (img, label) in pbar:
+            try:
+                l = []
+                image = Image.open(img)
+                image.verify()  # PIL verify
+                # _ = io.imread(img)  # skimage verify (from skimage import io)
+                shape = exif_size(image)  # image size
+                if os.path.isfile(label):
+                    with open(label, 'r') as f:
+                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
+                if len(l) == 0:
+                    l = np.zeros((0, 5), dtype=np.float32)
+                x[img] = [l, shape]
+            except Exception as e:
+                x[img] = None
+                print('WARNING: %s: %s' % (img, e))
+
+        x['hash'] = get_hash(self.label_files + self.img_files)
+        torch.save(x, path)  # save for next time
+        return x
 
     def __len__(self):
         return len(self.img_files)
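
Taken together, the patch replaces the old *.shapes / *.npy side files with one torch-serialized dict per label folder, rebuilt whenever the size-sum hash disagrees with the files on disk. A condensed sketch of that load-or-rebuild round trip (the helper name and signature are illustrative, not from the patch):

    import os
    import torch

    def get_hash(files):
        # Same scheme as the patch: total size in bytes of all existing files
        return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

    def load_or_build_cache(cache_path, files, build):
        # Reuse the cache only if its stored hash still matches the file set
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)
            if cache.get('hash') == get_hash(files):
                return cache  # up to date
        cache = build(files)              # expensive scan, like cache_labels()
        cache['hash'] = get_hash(files)
        torch.save(cache, cache_path)
        return cache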