|
|
@@ -539,7 +539,7 @@ def main(opt): |
|
|
|
assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' |
|
|
|
torch.cuda.set_device(LOCAL_RANK) |
|
|
|
device = torch.device('cuda', LOCAL_RANK) |
|
|
|
dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60)) |
|
|
|
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60)) |
|
|
|
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count' |
|
|
|
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' |
|
|
|
|