Explorar el Código

Merge remote-tracking branch 'origin/master'

5.0
Glenn Jocher hace 4 años
padre
commit
3e04d20c7d
Se han modificado 1 ficheros con 42 adiciones y 6 borrados
  1. +42
    -6
      utils/datasets.py

+ 42
- 6
utils/datasets.py Ver fichero

@@ -63,15 +63,51 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=train_sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
dataloader = InfiniteDataLoader (dataset,
batch_size=batch_size,
num_workers=nw,
sampler=train_sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
'''
Dataloader that reuses workers.

Uses same syntax as vanilla DataLoader.
'''

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()

def __len__(self):
return len(self.batch_sampler.sampler)

def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)


class _RepeatSampler(object):
'''
Sampler that repeats forever.

Args:
sampler (Sampler)
'''

def __init__(self, sampler):
self.sampler = sampler

def __iter__(self):
while True:
yield from iter(self.sampler)

class LoadImages: # for inference
def __init__(self, path, img_size=640):
p = str(Path(path)) # os-agnostic

Cargando…
Cancelar
Guardar