_RepeatSampler outside of InfiniteDataLoader

This commit is contained in:
Glenn Jocher 2020-09-10 12:27:35 -07:00
parent bb8872ea5f
commit d49c52eee3
1 changed files with 13 additions and 12 deletions

View File

@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
num_workers=nw, num_workers=nw,
sampler=sampler, sampler=sampler,
pin_memory=True, pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn) collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
return dataloader, dataset return dataloader, dataset
@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler)) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
@ -90,19 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
for i in range(len(self)): for i in range(len(self)):
yield next(self.iterator) yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args: class _RepeatSampler(object):
sampler (Sampler) """ Sampler that repeats forever.
"""
def __init__(self, sampler): Args:
self.sampler = sampler sampler (Sampler)
"""
def __iter__(self): def __init__(self, sampler):
while True: self.sampler = sampler
yield from iter(self.sampler)
def __iter__(self):
while True:
yield from iter(self.sampler)
class LoadImages: # for inference class LoadImages: # for inference