From d49c52eee39afebe8c64d230a30311cf2901e578 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 10 Sep 2020 12:27:35 -0700 Subject: [PATCH] _RepeatSampler outside of InfiniteDataLoader --- utils/datasets.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index 16d2fa8..a5396f0 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa num_workers=nw, sampler=sampler, pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn) + collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader() return dataloader, dataset @@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler)) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): @@ -90,19 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): for i in range(len(self)): yield next(self.iterator) - class _RepeatSampler(object): - """ Sampler that repeats forever. - Args: - sampler (Sampler) - """ +class _RepeatSampler(object): + """ Sampler that repeats forever. - def __init__(self, sampler): - self.sampler = sampler + Args: + sampler (Sampler) + """ - def __iter__(self): - while True: - yield from iter(self.sampler) + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) class LoadImages: # for inference