|
|
@@ -22,12 +22,11 @@ from tqdm import tqdm |
|
|
|
from utils.general import xyxy2xywh, xywh2xyxy |
|
|
|
from utils.torch_utils import torch_distributed_zero_first |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
# Parameters |
|
|
|
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data' |
|
|
|
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng'] # acceptable image suffixes |
|
|
|
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
# Get orientation exif tag |
|
|
|
for orientation in ExifTags.TAGS.keys(): |
|
|
@@ -168,14 +167,14 @@ class LoadImages: # for inference |
|
|
|
ret_val, img0 = self.cap.read() |
|
|
|
|
|
|
|
self.frame += 1 |
|
|
|
logger.debug('video %g/%g (%g/%g) %s: ', self.count + 1, self.nf, self.frame, self.nframes, path) |
|
|
|
print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='') |
|
|
|
|
|
|
|
else: |
|
|
|
# Read image |
|
|
|
self.count += 1 |
|
|
|
img0 = cv2.imread(path) # BGR |
|
|
|
assert img0 is not None, 'Image Not Found ' + path |
|
|
|
logger.debug('image %g/%g %s: ', self.count, self.nf, path) |
|
|
|
print('image %g/%g %s: ' % (self.count, self.nf, path), end='') |
|
|
|
|
|
|
|
# Padded resize |
|
|
|
img = letterbox(img0, new_shape=self.img_size)[0] |
|
|
@@ -237,7 +236,7 @@ class LoadWebcam: # for inference |
|
|
|
# Print |
|
|
|
assert ret_val, 'Camera Error %s' % self.pipe |
|
|
|
img_path = 'webcam.jpg' |
|
|
|
logger.debug('webcam %g: ', self.count) |
|
|
|
print('webcam %g: ' % self.count, end='') |
|
|
|
|
|
|
|
# Padded resize |
|
|
|
img = letterbox(img0, new_shape=self.img_size)[0] |
|
|
@@ -268,7 +267,7 @@ class LoadStreams: # multiple IP or RTSP cameras |
|
|
|
self.sources = sources |
|
|
|
for i, s in enumerate(sources): |
|
|
|
# Start the thread to read frames from the video stream |
|
|
|
logger.debug('%g/%g: %s... ', i + 1, n, s) |
|
|
|
print('%g/%g: %s... ' % (i + 1, n, s), end='') |
|
|
|
cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s) |
|
|
|
assert cap.isOpened(), 'Failed to open %s' % s |
|
|
|
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
|
@@ -276,14 +275,15 @@ class LoadStreams: # multiple IP or RTSP cameras |
|
|
|
fps = cap.get(cv2.CAP_PROP_FPS) % 100 |
|
|
|
_, self.imgs[i] = cap.read() # guarantee first frame |
|
|
|
thread = Thread(target=self.update, args=([i, cap]), daemon=True) |
|
|
|
logger.debug(' success (%gx%g at %.2f FPS).', w, h, fps) |
|
|
|
print(' success (%gx%g at %.2f FPS).' % (w, h, fps)) |
|
|
|
thread.start() |
|
|
|
print('') # newline |
|
|
|
|
|
|
|
# check for common shapes |
|
|
|
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes |
|
|
|
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal |
|
|
|
if not self.rect: |
|
|
|
logger.warning('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') |
|
|
|
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') |
|
|
|
|
|
|
|
def update(self, index, cap): |
|
|
|
# Read next stream frame in a daemon thread |
|
|
@@ -324,6 +324,12 @@ class LoadStreams: # multiple IP or RTSP cameras |
|
|
|
return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years |
|
|
|
|
|
|
|
|
|
|
|
def img2label_paths(img_paths):
    """Derive label-file paths from image-file paths.

    Standard YOLO dataset layout: swap the first '/images/' directory
    component for '/labels/' and change the file extension to '.txt'.

    Arguments:
        img_paths: iterable of image path strings.

    Returns:
        list of corresponding label path strings (same order).
    """
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    # rsplit on the final '.' so only the true suffix is swapped; the previous
    # str.replace('.' + ext, '.txt') replaced every occurrence of the extension
    # substring, corrupting names like 'a.jpg.jpg' -> 'a.txt.txt'.
    return [x.replace(sa, sb, 1).rsplit('.', 1)[0] + '.txt' for x in img_paths]
|
|
|
|
|
|
|
|
|
|
|
class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, |
|
|
|
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1): |
|
|
@@ -336,11 +342,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
self.mosaic_border = [-img_size // 2, -img_size // 2] |
|
|
|
self.stride = stride |
|
|
|
|
|
|
|
def img2label_paths(img_paths): |
|
|
|
# Define label paths as a function of image paths |
|
|
|
sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings |
|
|
|
return [x.replace(sa, sb, 1).replace('.' + x.split('.')[-1], '.txt') for x in img_paths] |
|
|
|
|
|
|
|
try: |
|
|
|
f = [] # image files |
|
|
|
for p in path if isinstance(path, list) else [path]: |
|
|
@@ -361,14 +362,20 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
|
|
|
|
# Check cache |
|
|
|
self.label_files = img2label_paths(self.img_files) # labels |
|
|
|
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels |
|
|
|
if os.path.isfile(cache_path): |
|
|
|
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') # cached labels |
|
|
|
if cache_path.is_file(): |
|
|
|
cache = torch.load(cache_path) # load |
|
|
|
if cache['hash'] != get_hash(self.label_files + self.img_files): # dataset changed |
|
|
|
cache = self.cache_labels(cache_path) # re-cache |
|
|
|
else: |
|
|
|
cache = self.cache_labels(cache_path) # cache |
|
|
|
|
|
|
|
# Display cache |
|
|
|
[nf, nm, ne, nc, n] = cache.pop('results') # found, missing, empty, corrupted, total |
|
|
|
desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" |
|
|
|
tqdm(None, desc=desc, total=n, initial=n) |
|
|
|
assert nf > 0 or not augment, f'No labels found in {cache_path}. Can not train without labels. See {help_url}' |
|
|
|
|
|
|
|
# Read cache |
|
|
|
cache.pop('hash') # remove hash |
|
|
|
labels, shapes = zip(*cache.values()) |
|
|
@@ -376,6 +383,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
self.shapes = np.array(shapes, dtype=np.float64) |
|
|
|
self.img_files = list(cache.keys()) # update |
|
|
|
self.label_files = img2label_paths(cache.keys()) # update |
|
|
|
if single_cls: |
|
|
|
for x in self.labels: |
|
|
|
x[:, 0] = 0 |
|
|
|
|
|
|
|
n = len(shapes) # number of images |
|
|
|
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index |
|
|
@@ -407,67 +417,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
|
|
|
|
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride |
|
|
|
|
|
|
|
# Check labels |
|
|
|
create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False |
|
|
|
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate |
|
|
|
pbar = enumerate(self.label_files) |
|
|
|
if rank in [-1, 0]: |
|
|
|
pbar = tqdm(pbar) |
|
|
|
for i, file in pbar: |
|
|
|
l = self.labels[i] # label |
|
|
|
if l is not None and l.shape[0]: |
|
|
|
assert l.shape[1] == 5, '> 5 label columns: %s' % file |
|
|
|
assert (l >= 0).all(), 'negative labels: %s' % file |
|
|
|
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file |
|
|
|
if np.unique(l, axis=0).shape[0] < l.shape[0]: # duplicate rows |
|
|
|
nd += 1 # logger.warning('WARNING: duplicate rows in %s', self.label_files[i]) # duplicate rows |
|
|
|
if single_cls: |
|
|
|
l[:, 0] = 0 # force dataset into single-class mode |
|
|
|
self.labels[i] = l |
|
|
|
nf += 1 # file found |
|
|
|
|
|
|
|
# Create subdataset (a smaller dataset) |
|
|
|
if create_datasubset and ns < 1E4: |
|
|
|
if ns == 0: |
|
|
|
create_folder(path='./datasubset') |
|
|
|
os.makedirs('./datasubset/images') |
|
|
|
exclude_classes = 43 |
|
|
|
if exclude_classes not in l[:, 0]: |
|
|
|
ns += 1 |
|
|
|
# shutil.copy(src=self.img_files[i], dst='./datasubset/images/') # copy image |
|
|
|
with open('./datasubset/images.txt', 'a') as f: |
|
|
|
f.write(self.img_files[i] + '\n') |
|
|
|
|
|
|
|
# Extract object detection boxes for a second stage classifier |
|
|
|
if extract_bounding_boxes: |
|
|
|
p = Path(self.img_files[i]) |
|
|
|
img = cv2.imread(str(p)) |
|
|
|
h, w = img.shape[:2] |
|
|
|
for j, x in enumerate(l): |
|
|
|
f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name) |
|
|
|
if not os.path.exists(Path(f).parent): |
|
|
|
os.makedirs(Path(f).parent) # make new output folder |
|
|
|
|
|
|
|
b = x[1:] * [w, h, w, h] # box |
|
|
|
b[2:] = b[2:].max() # rectangle to square |
|
|
|
b[2:] = b[2:] * 1.3 + 30 # pad |
|
|
|
b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int) |
|
|
|
|
|
|
|
b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image |
|
|
|
b[[1, 3]] = np.clip(b[[1, 3]], 0, h) |
|
|
|
assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes' |
|
|
|
else: |
|
|
|
ne += 1 # logger.info('empty labels for image %s', self.img_files[i]) # file empty |
|
|
|
# os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove |
|
|
|
|
|
|
|
if rank in [-1, 0]: |
|
|
|
pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % ( |
|
|
|
cache_path, nf, nm, ne, nd, n) |
|
|
|
if nf == 0: |
|
|
|
s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url) |
|
|
|
logger.info(s) |
|
|
|
assert not augment, '%s. Can not train without labels.' % s |
|
|
|
|
|
|
|
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) |
|
|
|
self.imgs = [None] * n |
|
|
|
if cache_images: |
|
|
@@ -480,28 +429,50 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
gb += self.imgs[i].nbytes |
|
|
|
pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9) |
|
|
|
|
|
|
|
def cache_labels(self, path='labels.cache'): |
|
|
|
def cache_labels(self, path=Path('./labels.cache')): |
|
|
|
# Cache dataset labels, check images and read shapes |
|
|
|
x = {} # dict |
|
|
|
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate |
|
|
|
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files)) |
|
|
|
for (img, label) in pbar: |
|
|
|
for i, (im_file, lb_file) in enumerate(pbar): |
|
|
|
try: |
|
|
|
l = [] |
|
|
|
im = Image.open(img) |
|
|
|
# verify images |
|
|
|
im = Image.open(im_file) |
|
|
|
im.verify() # PIL verify |
|
|
|
shape = exif_size(im) # image size |
|
|
|
assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels' |
|
|
|
if os.path.isfile(label): |
|
|
|
with open(label, 'r') as f: |
|
|
|
|
|
|
|
# verify labels |
|
|
|
l = [] |
|
|
|
if os.path.isfile(lb_file): |
|
|
|
nf += 1 # label found |
|
|
|
with open(lb_file, 'r') as f: |
|
|
|
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) # labels |
|
|
|
if len(l) == 0: |
|
|
|
l = np.zeros((0, 5), dtype=np.float32) |
|
|
|
x[img] = [l, shape] |
|
|
|
if len(l): |
|
|
|
assert l.shape[1] == 5, 'labels require 5 columns each' |
|
|
|
assert (l >= 0).all(), 'negative labels' |
|
|
|
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels' |
|
|
|
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels' |
|
|
|
else: |
|
|
|
ne += 1 # label empty |
|
|
|
l = np.zeros((0, 5), dtype=np.float32) |
|
|
|
else: |
|
|
|
nm += 1 # label missing |
|
|
|
x[im_file] = [l, shape] |
|
|
|
except Exception as e: |
|
|
|
logger.warning('WARNING: Ignoring corrupted image and/or label %s: %s', img, e) |
|
|
|
nc += 1 |
|
|
|
print('WARNING: Ignoring corrupted image and/or label %s: %s' % (im_file, e)) |
|
|
|
|
|
|
|
pbar.desc = f"Scanning '{path.parent / path.stem}' for images and labels... " \ |
|
|
|
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted" |
|
|
|
|
|
|
|
if nf == 0: |
|
|
|
print(f'WARNING: No labels found in {path}. See {help_url}') |
|
|
|
|
|
|
|
x['hash'] = get_hash(self.label_files + self.img_files) |
|
|
|
x['results'] = [nf, nm, ne, nc, i] |
|
|
|
torch.save(x, path) # save for next time |
|
|
|
logging.info(f"New cache created: '{path}'") |
|
|
|
return x |
|
|
|
|
|
|
|
def __len__(self): |
|
|
@@ -509,7 +480,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing |
|
|
|
|
|
|
|
# def __iter__(self): |
|
|
|
# self.count = -1 |
|
|
|
# logger.info('ran dataset iter') |
|
|
|
# print('ran dataset iter') |
|
|
|
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF) |
|
|
|
# return self |
|
|
|
|
|
|
@@ -906,6 +877,41 @@ def flatten_recursive(path='../coco128'): |
|
|
|
shutil.copyfile(file, new_path / Path(file).name) |
|
|
|
|
|
|
|
|
|
|
|
def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    """Convert a detection dataset into a classification dataset.

    Crops every labelled bounding box out of each image and writes it to
    '<path>/classifier/<class>/' so each class gets its own directory.
    Any pre-existing '<path>/classifier' tree is removed first.

    Arguments:
        path: images directory of the dataset; label files are located via
            img2label_paths().

    Raises:
        AssertionError: if a crop fails to be written to disk.
    """
    path = Path(path)  # images dir
    if (path / 'classifier').is_dir():
        shutil.rmtree(path / 'classifier')  # remove existing output tree
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            # NOTE(review): channels are reversed here ("BGR to RGB") but
            # cv2.imwrite below expects BGR, so saved crops have swapped
            # channels — confirm this is intentional.
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if lb_file.exists():  # lb_file is already a Path; no re-wrap needed
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box: normalized xywh -> pixel xywh
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    # builtin int instead of np.int (deprecated in NumPy 1.20, removed in 1.24)
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
|
|
|
|
|
|
|
|
|
|
|
def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)): # from utils.datasets import *; autosplit('../coco128') |
|
|
|
""" Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files |
|
|
|
# Arguments |