Default DataLoader `shuffle=True` for training (#5623)

* Fix shuffle DataLoader argument

* Add shuffle argument

* Disable shuffle when rect

* Cleanup, add rect warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup2

* Cleanup3

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Werner Duvaud 2021-11-13 12:07:32 +00:00 committed by GitHub
parent 7473f0f95d
commit 09d170381c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 19 deletions

View File

@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK, hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad, workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: ')) prefix=colorstr('train: '), shuffle=True)
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
nb = len(train_loader) # number of batches nb = len(train_loader) # number of batches
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

View File

@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import yaml import yaml
from PIL import ExifTags, Image, ImageOps from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm from tqdm import tqdm
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@ -93,13 +93,15 @@ def exif_transpose(image):
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''): rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache if rect and shuffle:
with torch_distributed_zero_first(rank): LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = LoadImagesAndLabels(path, imgsz, batch_size, dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images augment=augment, # augmentation
hyp=hyp, # augmentation hyperparameters hyp=hyp, # hyperparameters
rect=rect, # rectangular training rect=rect, # rectangular batches
cache_images=cache, cache_images=cache,
single_cls=single_cls, single_cls=single_cls,
stride=int(stride), stride=int(stride),
@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() return loader(dataset,
dataloader = loader(dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw, num_workers=nw,
sampler=sampler, sampler=sampler,
pin_memory=True, pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn) collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
return dataloader, dataset
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): class InfiniteDataLoader(dataloader.DataLoader):
""" Dataloader that reuses workers """ Dataloader that reuses workers
Uses same syntax as vanilla DataLoader Uses same syntax as vanilla DataLoader