_RepeatSampler outside of InfiniteDataLoader
This commit is contained in:
parent
bb8872ea5f
commit
d49c52eee3
|
|
@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
||||
collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
|
||||
return dataloader, dataset
|
||||
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
|
|
@ -90,6 +90,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
|||
for i in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
|
||||
class _RepeatSampler(object):
|
||||
""" Sampler that repeats forever.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue