Browse Source

add stride to datasets.py

5.0
Glenn Jocher 4 years ago
parent
commit
b8557f87e3
3 changed files with 7 additions and 4 deletions
  1. +1
    -0
      test.py
  2. +4
    -2
      train.py
  3. +2
    -2
      utils/datasets.py

+ 1
- 0
test.py View File

batch_size, batch_size,
rect=True, # rectangular inference rect=True, # rectangular inference
single_cls=opt.single_cls, # single class mode single_cls=opt.single_cls, # single class mode
stride=int(max(model.stride)), # model stride
pad=0.5) # padding pad=0.5) # padding
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers

+ 4
- 2
train.py View File

hyp=hyp, # augmentation hyperparameters hyp=hyp, # augmentation hyperparameters
rect=opt.rect, # rectangular training rect=opt.rect, # rectangular training
cache_images=opt.cache_images, cache_images=opt.cache_images,
single_cls=opt.single_cls)
single_cls=opt.single_cls,
stride=gs)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg) assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)


hyp=hyp, hyp=hyp,
rect=True, rect=True,
cache_images=opt.cache_images, cache_images=opt.cache_images,
single_cls=opt.single_cls),
single_cls=opt.single_cls,
stride=gs),
batch_size=batch_size, batch_size=batch_size,
num_workers=nw, num_workers=nw,
pin_memory=True, pin_memory=True,

+ 2
- 2
utils/datasets.py View File



class LoadImagesAndLabels(Dataset): # for training/testing class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, pad=0.0):
cache_images=False, single_cls=False, stride=32, pad=0.0):
try: try:
path = str(Path(path)) # os-agnostic path = str(Path(path)) # os-agnostic
parent = str(Path(path).parent) + os.sep parent = str(Path(path).parent) + os.sep
elif mini > 1: elif mini > 1:
shapes[i] = [1, 1 / mini] shapes[i] = [1, 1 / mini]


self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32. + pad).astype(np.int) * 32
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride


# Cache labels # Cache labels
self.imgs = [None] * n self.imgs = [None] * n

Loading…
Cancel
Save