
Remove `opt` from `create_dataloader()` (#3552)

Branch: modifyDataloader
Author: Glenn Jocher (GitHub), 3 years ago
Commit: 958ab92dc1
3 changed files with 13 additions and 12 deletions:
  1. test.py            +1 -1
  2. train.py           +9 -8
  3. utils/datasets.py  +3 -3
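In short: `create_dataloader()` previously accepted the entire argparse namespace `opt` only to read `opt.single_cls`, which meant `utils/datasets.py` could not be called without constructing a namespace object. The commit replaces the `opt` parameter with an explicit `single_cls=False` keyword and has `train.py` and `test.py` pass the flag directly. A minimal before/after sketch of a call site (argument values are placeholders taken from the diff below):

    # before: whole argparse namespace passed in, only opt.single_cls used
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp)

    # after: just the boolean the function actually needs
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, hyp=hyp)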

test.py  (+1 -1)

@@ -88,7 +88,7 @@ def test(data,
         if device.type != 'cpu':
             model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
         task = opt.task if opt.task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-        dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
+        dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
                                        prefix=colorstr(f'{task}: '))[0]
 
     seen = 0
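Note that `single_cls` is already in scope at this call site because it is a parameter of `test()` itself; `train.py` forwards it explicitly via the `single_cls=single_cls` keywords in the hunks below.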

train.py  (+9 -8)

@@ -41,8 +41,9 @@ logger = logging.getLogger(__name__)
 
 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
-    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
+    save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
+        opt.single_cls
 
     # Directories
     wdir = save_dir / 'weights'
@@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
         if wandb_logger.wandb:
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
 
-    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
-    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
+    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
+    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
     is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
@@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
         logger.info('Using SyncBatchNorm()')
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                             world_size=opt.world_size, workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
@@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # Process 0
     if rank in [-1, 0]:
-        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
+        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
@@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                  batch_size=batch_size * 2,
                                                  imgsz=imgsz_test,
                                                  model=ema.ema,
-                                                 single_cls=opt.single_cls,
+                                                 single_cls=single_cls,
                                                  dataloader=testloader,
                                                  save_dir=save_dir,
                                                  save_json=is_coco and final_epoch,
@@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
                                          conf_thres=0.001,
                                          iou_thres=0.7,
                                          model=attempt_load(m, device).half(),
-                                         single_cls=opt.single_cls,
+                                         single_cls=single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=True,
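The unpack-once assignment at the top of `train()` is the key move: downstream code reads the local `single_cls` instead of reaching back into `opt`. A self-contained miniature of the same decoupling pattern (all names here are illustrative, not from the repository):

    import argparse

    # Before: helper coupled to the CLI namespace.
    def describe_dataset_old(opt):
        return 'single-class' if opt.single_cls else 'multi-class'

    # After: helper takes only the value it needs, so it is callable
    # without argparse (e.g. from a notebook or a unit test).
    def describe_dataset(single_cls=False):
        return 'single-class' if single_cls else 'multi-class'

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--single-cls', action='store_true')
        opt = parser.parse_args()
        print(describe_dataset_old(opt))         # needs the whole namespace
        print(describe_dataset(opt.single_cls))  # needs only the flag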

utils/datasets.py  (+3 -3)

@@ -62,8 +62,8 @@ def exif_size(img):
     return s
 
 
-def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
+def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
+                      rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       hyp=hyp,  # augmentation hyperparameters
                                       rect=rect,  # rectangular training
                                       cache_images=cache,
-                                      single_cls=opt.single_cls,
+                                      single_cls=single_cls,
                                       stride=int(stride),
                                       pad=pad,
                                       image_weights=image_weights,
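With the new signature, the loader can be built without an argparse namespace at all. A minimal usage sketch (the path and sizes are placeholder values; a valid YOLO-format dataset at that path is assumed):

    from utils.datasets import create_dataloader

    dataloader, dataset = create_dataloader('data/train/images', imgsz=640, batch_size=16, stride=32,
                                            single_cls=True,  # previously read from opt.single_cls
                                            rect=True, workers=4, prefix='train: ')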
