From 12499f1c014494263d75cf19b63311e80362a38c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 26 Nov 2020 13:25:51 +0100 Subject: [PATCH] --image_weights bug fix (#1524) --- utils/datasets.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index d55323f..cd6151e 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -72,12 +72,14 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None - dataloader = InfiniteDataLoader(dataset, - batch_size=batch_size, - num_workers=nw, - sampler=sampler, - pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader() + loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader + # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() + dataloader = loader(dataset, + batch_size=batch_size, + num_workers=nw, + sampler=sampler, + pin_memory=True, + collate_fn=LoadImagesAndLabels.collate_fn) return dataloader, dataset