Default DataLoader `shuffle=True` for training (#5623)
* Fix shuffle DataLoader argument * Add shuffle argument * Disable shuffle when rect * Cleanup, add rect warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup2 * Cleanup3 Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
7473f0f95d
commit
09d170381c
2
train.py
2
train.py
|
|
@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|||
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
||||
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
|
||||
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
|
||||
prefix=colorstr('train: '))
|
||||
prefix=colorstr('train: '), shuffle=True)
|
||||
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
|
||||
nb = len(train_loader) # number of batches
|
||||
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import yaml
|
||||
from PIL import ExifTags, Image, ImageOps
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
||||
|
|
@ -93,13 +93,15 @@ def exif_transpose(image):
|
|||
|
||||
|
||||
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
|
||||
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
|
||||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
|
||||
with torch_distributed_zero_first(rank):
|
||||
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
|
||||
if rect and shuffle:
|
||||
LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
|
||||
shuffle = False
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
||||
augment=augment, # augment images
|
||||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=rect, # rectangular training
|
||||
augment=augment, # augmentation
|
||||
hyp=hyp, # hyperparameters
|
||||
rect=rect, # rectangular batches
|
||||
cache_images=cache,
|
||||
single_cls=single_cls,
|
||||
stride=int(stride),
|
||||
|
|
@ -109,19 +111,18 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
|
|||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
||||
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
|
||||
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
|
||||
dataloader = loader(dataset,
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
return loader(dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
|
||||
return dataloader, dataset
|
||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
|
||||
|
||||
|
||||
class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
||||
class InfiniteDataLoader(dataloader.DataLoader):
|
||||
""" Dataloader that reuses workers
|
||||
|
||||
Uses same syntax as vanilla DataLoader
|
||||
|
|
|
|||
Loading…
Reference in New Issue