|
|
@@ -72,12 +72,14 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa |
|
|
|
batch_size = min(batch_size, len(dataset)) |
|
|
|
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers |
|
|
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None |
|
|
|
dataloader = InfiniteDataLoader(dataset, |
|
|
|
batch_size=batch_size, |
|
|
|
num_workers=nw, |
|
|
|
sampler=sampler, |
|
|
|
pin_memory=True, |
|
|
|
collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader() |
|
|
|
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader |
|
|
|
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() |
|
|
|
dataloader = loader(dataset, |
|
|
|
batch_size=batch_size, |
|
|
|
num_workers=nw, |
|
|
|
sampler=sampler, |
|
|
|
pin_memory=True, |
|
|
|
collate_fn=LoadImagesAndLabels.collate_fn) |
|
|
|
return dataloader, dataset |
|
|
|
|
|
|
|
|