_RepeatSampler outside of InfiniteDataLoader
This commit is contained in:
parent
bb8872ea5f
commit
d49c52eee3
|
|
@ -68,7 +68,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=LoadImagesAndLabels.collate_fn)
|
collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
|
||||||
return dataloader, dataset
|
return dataloader, dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -80,7 +80,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
|
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||||
self.iterator = super().__iter__()
|
self.iterator = super().__iter__()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
@ -90,6 +90,7 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
|
||||||
for i in range(len(self)):
|
for i in range(len(self)):
|
||||||
yield next(self.iterator)
|
yield next(self.iterator)
|
||||||
|
|
||||||
|
|
||||||
class _RepeatSampler(object):
|
class _RepeatSampler(object):
|
||||||
""" Sampler that repeats forever.
|
""" Sampler that repeats forever.
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue