Default DataLoader `shuffle=True` for training (#5623)
* Fix shuffle DataLoader argument * Add shuffle argument * Disable shuffle when rect * Cleanup, add rect warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup2 * Cleanup3 Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
7473f0f95d
commit
09d170381c
2
train.py
2
train.py
|
|
@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
||||||
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
|
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
|
||||||
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
|
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
|
||||||
prefix=colorstr('train: '))
|
prefix=colorstr('train: '), shuffle=True)
|
||||||
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
|
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
|
||||||
nb = len(train_loader) # number of batches
|
nb = len(train_loader) # number of batches
|
||||||
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
|
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import ExifTags, Image, ImageOps
|
from PIL import ExifTags, Image, ImageOps
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
||||||
|
|
@ -93,13 +93,15 @@ def exif_transpose(image):
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
||||||
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
|
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
|
||||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
if rect and shuffle:
|
||||||
with torch_distributed_zero_first(rank):
|
LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
|
||||||
|
shuffle = False
|
||||||
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||||
augment=augment, # augment images
|
augment=augment, # augmentation
|
||||||
hyp=hyp, # augmentation hyperparameters
|
hyp=hyp, # hyperparameters
|
||||||
rect=rect, # rectangular training
|
rect=rect, # rectangular batches
|
||||||
cache_images=cache,
|
cache_images=cache,
|
||||||
single_cls=single_cls,
|
single_cls=single_cls,
|
||||||
stride=int(stride),
|
stride=int(stride),
|
||||||
|
|
@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
|
||||||
|
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
|
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||||
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
|
return loader(dataset,
|
||||||
dataloader = loader(dataset,
|
batch_size=batch_size,
|
||||||
batch_size=batch_size,
|
shuffle=shuffle and sampler is None,
|
||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
|
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
|
||||||
return dataloader, dataset
|
|
||||||
|
|
||||||
|
|
||||||
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
class InfiniteDataLoader(dataloader.DataLoader):
|
||||||
""" Dataloader that reuses workers
|
""" Dataloader that reuses workers
|
||||||
|
|
||||||
Uses same syntax as vanilla DataLoader
|
Uses same syntax as vanilla DataLoader
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue