Browse Source

Update datasets.py

5.0
Glenn Jocher 4 years ago
parent
commit
d3f9bf2bb7
1 changed files with 18 additions and 25 deletions
  1. +18
    -25
      utils/datasets.py

+ 18
- 25
utils/datasets.py View File

@@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = InfiniteDataLoader (dataset,
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = InfiniteDataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=train_sampler,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
'''
Dataloader that reuses workers.
""" Dataloader that reuses workers.

Uses same syntax as vanilla DataLoader.
'''
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()

def __len__(self):
@@ -91,22 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
for i in range(len(self)):
yield next(self.iterator)

class _RepeatSampler(object):
""" Sampler that repeats forever.

class _RepeatSampler(object):
'''
Sampler that repeats forever.
Args:
sampler (Sampler)
"""

Args:
sampler (Sampler)
'''
def __init__(self, sampler):
self.sampler = sampler

def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)

def __iter__(self):
while True:
yield from iter(self.sampler)

class LoadImages: # for inference
def __init__(self, path, img_size=640):
@@ -684,14 +681,10 @@ def load_mosaic(self, index):
# Concat/clip labels
if len(labels4):
labels4 = np.concatenate(labels4, 0)
# np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine

# Replicate
# img4, labels4 = replicate(img4, labels4)
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_perspective
# img4, labels4 = replicate(img4, labels4) # replicate

# Augment
# img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
img4, labels4 = random_perspective(img4, labels4,
degrees=self.hyp['degrees'],
translate=self.hyp['translate'],

Loading…
Cancel
Save