Browse Source

Add InfiniteDataLoader class (#876)

* Add InfiniteDataLoader

Only initializes at first epoch. Saves time.

* Moved class to a better location
5.0
NanoCode012 GitHub 4 years ago
parent
commit
1e15aad6f9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 6 deletions
  1. +42
    -6
      utils/datasets.py

+ 42
- 6
utils/datasets.py View File

batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=train_sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
dataloader = InfiniteDataLoader (dataset,
batch_size=batch_size,
num_workers=nw,
sampler=train_sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset return dataloader, dataset




class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
'''
Dataloader that reuses workers.

Uses same syntax as vanilla DataLoader.
'''

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()

def __len__(self):
return len(self.batch_sampler.sampler)

def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)


class _RepeatSampler(object):
'''
Sampler that repeats forever.

Args:
sampler (Sampler)
'''

def __init__(self, sampler):
self.sampler = sampler

def __iter__(self):
while True:
yield from iter(self.sampler)

class LoadImages: # for inference class LoadImages: # for inference
def __init__(self, path, img_size=640): def __init__(self, path, img_size=640):
p = str(Path(path)) # os-agnostic p = str(Path(path)) # os-agnostic

Loading…
Cancel
Save