
max workers for dataloader (#722)

Marc committed 4 years ago (via GitHub)
commit a925f283a7
2 changed files with 6 additions and 4 deletions
  1. +4 -2  train.py
  2. +2 -2  utils/datasets.py

train.py  (+4 -2)

@@ -159,7 +159,7 @@ def train(hyp, opt, device, tb_writer=None):
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
                                             cache=opt.cache_images, rect=opt.rect, rank=rank,
-                                            world_size=opt.world_size)
+                                            world_size=opt.world_size, workers=opt.workers)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
@@ -168,7 +168,8 @@ def train(hyp, opt, device, tb_writer=None):
     if rank in [-1, 0]:
         # local_rank is set to -1. Because only the first process is expected to do evaluation.
         testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
-                                       cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size)[0]
+                                       cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size,
+                                       workers=opt.workers)[0]

     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -403,6 +404,7 @@ if __name__ == '__main__':
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
+    parser.add_argument('--workers', type=int, default=8, help='maximum number of workers for dataloader')
     opt = parser.parse_args()

     # Set DDP variables
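For context, a minimal standalone sketch of the train.py side of this change. The parser below is a stripped-down stand-in rather than the real train.py parser; only the --workers flag itself is taken from this diff.

import argparse

# Stand-in parser: the real train.py defines many more options.
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, default=8, help='maximum number of workers for dataloader')
opt = parser.parse_args(['--workers', '4'])  # equivalent to running: python train.py --workers 4

# opt.workers is then forwarded to create_dataloader() via the new
# workers= keyword, for both the train and test loaders.
print(opt.workers)  # -> 4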

utils/datasets.py  (+2 -2)

@@ -47,7 +47,7 @@ def exif_size(img):


 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1):
+                      rank=-1, world_size=1, workers=8):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -61,7 +61,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       rank=rank)

     batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8])  # number of workers
+    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
     train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
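The effect of the new workers argument is easiest to see in isolation. A minimal sketch of the capping logic above, where compute_num_workers is a hypothetical helper name and not part of the repository:

import os

def compute_num_workers(batch_size, world_size=1, workers=8):
    # Mirrors the nw formula: never use more workers than CPUs available
    # per DDP process, than images in a batch, or than the --workers cap.
    return min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])

# e.g. on a 16-CPU machine with 2 DDP processes, batch_size=32 and --workers 4:
# min(16 // 2, 32, 4) = 4, so each process spawns 4 dataloader workers.

Before this commit the last term was hard-coded to 8; it now follows the user-supplied --workers value.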
