Browse Source

Update caching (#1496)

5.0
Glenn Jocher GitHub 4 years ago
parent
commit
89c7a5b8dc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 93 additions and 87 deletions
  1. +93
    -87
      utils/datasets.py

+ 93
- 87
utils/datasets.py View File

@@ -22,12 +22,11 @@ from tqdm import tqdm
from utils.general import xyxy2xywh, xywh2xyxy
from utils.torch_utils import torch_distributed_zero_first

logger = logging.getLogger(__name__)

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng'] # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
@@ -168,14 +167,14 @@ class LoadImages: # for inference
ret_val, img0 = self.cap.read()

self.frame += 1
logger.debug('video %g/%g (%g/%g) %s: ', self.count + 1, self.nf, self.frame, self.nframes, path)
print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

else:
# Read image
self.count += 1
img0 = cv2.imread(path) # BGR
assert img0 is not None, 'Image Not Found ' + path
logger.debug('image %g/%g %s: ', self.count, self.nf, path)
print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
@@ -237,7 +236,7 @@ class LoadWebcam: # for inference
# Print
assert ret_val, 'Camera Error %s' % self.pipe
img_path = 'webcam.jpg'
logger.debug('webcam %g: ', self.count)
print('webcam %g: ' % self.count, end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
@@ -268,7 +267,7 @@ class LoadStreams: # multiple IP or RTSP cameras
self.sources = sources
for i, s in enumerate(sources):
# Start the thread to read frames from the video stream
logger.debug('%g/%g: %s... ', i + 1, n, s)
print('%g/%g: %s... ' % (i + 1, n, s), end='')
cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
assert cap.isOpened(), 'Failed to open %s' % s
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -276,14 +275,15 @@ class LoadStreams: # multiple IP or RTSP cameras
fps = cap.get(cv2.CAP_PROP_FPS) % 100
_, self.imgs[i] = cap.read() # guarantee first frame
thread = Thread(target=self.update, args=([i, cap]), daemon=True)
logger.debug(' success (%gx%g at %.2f FPS).', w, h, fps)
print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
thread.start()
print('') # newline

# check for common shapes
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
logger.warning('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

def update(self, index, cap):
# Read next stream frame in a daemon thread
@@ -324,6 +324,12 @@ class LoadStreams: # multiple IP or RTSP cameras
return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years


def img2label_paths(img_paths):
# Define label paths as a function of image paths
sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
return [x.replace(sa, sb, 1).replace('.' + x.split('.')[-1], '.txt') for x in img_paths]


class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
@@ -336,11 +342,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride

def img2label_paths(img_paths):
# Define label paths as a function of image paths
sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
return [x.replace(sa, sb, 1).replace('.' + x.split('.')[-1], '.txt') for x in img_paths]

try:
f = [] # image files
for p in path if isinstance(path, list) else [path]:
@@ -361,14 +362,20 @@ class LoadImagesAndLabels(Dataset): # for training/testing

# Check cache
self.label_files = img2label_paths(self.img_files) # labels
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
if os.path.isfile(cache_path):
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') # cached labels
if cache_path.is_file():
cache = torch.load(cache_path) # load
if cache['hash'] != get_hash(self.label_files + self.img_files): # dataset changed
cache = self.cache_labels(cache_path) # re-cache
else:
cache = self.cache_labels(cache_path) # cache

# Display cache
[nf, nm, ne, nc, n] = cache.pop('results') # found, missing, empty, corrupted, total
desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=desc, total=n, initial=n)
assert nf > 0 or not augment, f'No labels found in {cache_path}. Can not train without labels. See {help_url}'

# Read cache
cache.pop('hash') # remove hash
labels, shapes = zip(*cache.values())
@@ -376,6 +383,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
if single_cls:
for x in self.labels:
x[:, 0] = 0

n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
@@ -407,67 +417,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing

self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

# Check labels
create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
pbar = enumerate(self.label_files)
if rank in [-1, 0]:
pbar = tqdm(pbar)
for i, file in pbar:
l = self.labels[i] # label
if l is not None and l.shape[0]:
assert l.shape[1] == 5, '> 5 label columns: %s' % file
assert (l >= 0).all(), 'negative labels: %s' % file
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
if np.unique(l, axis=0).shape[0] < l.shape[0]: # duplicate rows
nd += 1 # logger.warning('WARNING: duplicate rows in %s', self.label_files[i]) # duplicate rows
if single_cls:
l[:, 0] = 0 # force dataset into single-class mode
self.labels[i] = l
nf += 1 # file found

# Create subdataset (a smaller dataset)
if create_datasubset and ns < 1E4:
if ns == 0:
create_folder(path='./datasubset')
os.makedirs('./datasubset/images')
exclude_classes = 43
if exclude_classes not in l[:, 0]:
ns += 1
# shutil.copy(src=self.img_files[i], dst='./datasubset/images/') # copy image
with open('./datasubset/images.txt', 'a') as f:
f.write(self.img_files[i] + '\n')

# Extract object detection boxes for a second stage classifier
if extract_bounding_boxes:
p = Path(self.img_files[i])
img = cv2.imread(str(p))
h, w = img.shape[:2]
for j, x in enumerate(l):
f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
if not os.path.exists(Path(f).parent):
os.makedirs(Path(f).parent) # make new output folder

b = x[1:] * [w, h, w, h] # box
b[2:] = b[2:].max() # rectangle to square
b[2:] = b[2:] * 1.3 + 30 # pad
b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)

b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
else:
ne += 1 # logger.info('empty labels for image %s', self.img_files[i]) # file empty
# os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove

if rank in [-1, 0]:
pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
cache_path, nf, nm, ne, nd, n)
if nf == 0:
s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
logger.info(s)
assert not augment, '%s. Can not train without labels.' % s

# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
self.imgs = [None] * n
if cache_images:
@@ -480,28 +429,50 @@ class LoadImagesAndLabels(Dataset): # for training/testing
gb += self.imgs[i].nbytes
pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

def cache_labels(self, path='labels.cache'):
def cache_labels(self, path=Path('./labels.cache')):
# Cache dataset labels, check images and read shapes
x = {} # dict
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
for (img, label) in pbar:
for i, (im_file, lb_file) in enumerate(pbar):
try:
l = []
im = Image.open(img)
# verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
if os.path.isfile(label):
with open(label, 'r') as f:

# verify labels
l = []
if os.path.isfile(lb_file):
nf += 1 # label found
with open(lb_file, 'r') as f:
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) # labels
if len(l) == 0:
l = np.zeros((0, 5), dtype=np.float32)
x[img] = [l, shape]
if len(l):
assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
else:
ne += 1 # label empty
l = np.zeros((0, 5), dtype=np.float32)
else:
nm += 1 # label missing
x[im_file] = [l, shape]
except Exception as e:
logger.warning('WARNING: Ignoring corrupted image and/or label %s: %s', img, e)
nc += 1
print('WARNING: Ignoring corrupted image and/or label %s: %s' % (im_file, e))

pbar.desc = f"Scanning '{path.parent / path.stem}' for images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

if nf == 0:
print(f'WARNING: No labels found in {path}. See {help_url}')

x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = [nf, nm, ne, nc, i]
torch.save(x, path) # save for next time
logging.info(f"New cache created: '{path}'")
return x

def __len__(self):
@@ -509,7 +480,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing

# def __iter__(self):
# self.count = -1
# logger.info('ran dataset iter')
# print('ran dataset iter')
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
# return self

@@ -906,6 +877,41 @@ def flatten_recursive(path='../coco128'):
shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path='../coco128/'): # from utils.datasets import *; extract_boxes('../coco128')
# Convert detection dataset into classification dataset, with one directory per class

path = Path(path) # images dir
shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
files = list(path.rglob('*.*'))
n = len(files) # number of files
for im_file in tqdm(files, total=n):
if im_file.suffix[1:] in img_formats:
# image
im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
h, w = im.shape[:2]

# labels
lb_file = Path(img2label_paths([str(im_file)])[0])
if Path(lb_file).exists():
with open(lb_file, 'r') as f:
lb = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) # labels

for j, x in enumerate(lb):
c = int(x[0]) # class
f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
if not f.parent.is_dir():
f.parent.mkdir(parents=True)

b = x[1:] * [w, h, w, h] # box
# b[2:] = b[2:].max() # rectangle to square
b[2:] = b[2:] * 1.2 + 3 # pad
b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)

b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'


def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)): # from utils.datasets import *; autosplit('../coco128')
""" Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
# Arguments

Loading…
Cancel
Save