--image_weights bug fix (#1524)
This commit is contained in:
parent
9728e2b8ae
commit
12499f1c01
|
|
@ -72,12 +72,14 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
||||||
dataloader = InfiniteDataLoader(dataset,
|
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
|
||||||
batch_size=batch_size,
|
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
|
||||||
num_workers=nw,
|
dataloader = loader(dataset,
|
||||||
sampler=sampler,
|
batch_size=batch_size,
|
||||||
pin_memory=True,
|
num_workers=nw,
|
||||||
collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
|
sampler=sampler,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||||
return dataloader, dataset
|
return dataloader, dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue