Browse Source

Update DDP for `torch.distributed.run` with `gloo` backend (#3680)

* Update DDP for `torch.distributed.run`

* Add LOCAL_RANK

* remove opt.local_rank

* backend="gloo|nccl"

* print

* print

* debug

* debug

* os.getenv

* gloo

* gloo

* gloo

* cleanup

* fix getenv

* cleanup

* cleanup destroy

* try nccl

* return opt

* add --local_rank

* add timeout

* add init_method

* gloo

* move destroy

* move destroy

* move print(opt) under if RANK

* destroy only RANK 0

* move destroy inside train()

* restore destroy outside train()

* update print(opt)

* cleanup

* nccl

* gloo with 60 second timeout

* update namespace printing
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
fad27c0046
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 61 deletions
  1. +3
    -3
      detect.py
  2. +1
    -1
      models/export.py
  3. +2
    -2
      test.py
  4. +46
    -49
      train.py
  5. +2
    -2
      utils/datasets.py
  6. +3
    -2
      utils/torch_utils.py
  7. +4
    -2
      utils/wandb_logging/wandb_utils.py

+ 3
- 3
detect.py View File



from models.experimental import attempt_load from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import colors, plot_one_box from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized from utils.torch_utils import select_device, load_classifier, time_synchronized






def main(opt): def main(opt):
print(opt)
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop')) check_requirements(exclude=('tensorboard', 'thop'))
detect(**vars(opt)) detect(**vars(opt))



+ 1
- 1
models/export.py View File





def main(opt): def main(opt):
print(opt)
set_logging() set_logging()
print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
export(**vars(opt)) export(**vars(opt))





+ 2
- 2
test.py View File

device = next(model.parameters()).device # get model device device = next(model.parameters()).device # get model device


else: # called directly else: # called directly
set_logging()
device = select_device(device, batch_size=batch_size) device = select_device(device, batch_size=batch_size)


# Directories # Directories




def main(opt): def main(opt):
print(opt)
set_logging()
print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop')) check_requirements(exclude=('tensorboard', 'thop'))


if opt.task in ('train', 'val', 'test'): # run normally if opt.task in ('train', 'val', 'test'): # run normally

+ 46
- 49
train.py View File

from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))




def train(hyp, # path/to/hyp.yaml or hyp dictionary def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt, opt,
device, device,
): ):
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
opt.single_cls
save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls


# Directories # Directories
wdir = save_dir / 'weights' wdir = save_dir / 'weights'
# Configure # Configure
plots = not opt.evolve # create plots plots = not opt.evolve # create plots
cuda = device.type != 'cpu' cuda = device.type != 'cpu'
init_seeds(2 + rank)
init_seeds(2 + RANK)
with open(opt.data) as f: with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict data_dict = yaml.safe_load(f) # data dict


# Loggers # Loggers
loggers = {'wandb': None, 'tb': None} # loggers dict loggers = {'wandb': None, 'tb': None} # loggers dict
if rank in [-1, 0]:
if RANK in [-1, 0]:
# TensorBoard # TensorBoard
if not opt.evolve: if not opt.evolve:
prefix = colorstr('tensorboard: ') prefix = colorstr('tensorboard: ')
# Model # Model
pretrained = weights.endswith('.pt') pretrained = weights.endswith('.pt')
if pretrained: if pretrained:
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
weights = attempt_download(weights) # download if not found locally weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else: else:
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(rank):
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check check_dataset(data_dict) # check
train_path = data_dict['train'] train_path = data_dict['train']
test_path = data_dict['val'] test_path = data_dict['val']
# plot_lr_scheduler(optimizer, scheduler, epochs) # plot_lr_scheduler(optimizer, scheduler, epochs)


# EMA # EMA
ema = ModelEMA(model) if rank in [-1, 0] else None
ema = ModelEMA(model) if RANK in [-1, 0] else None


# Resume # Resume
start_epoch, best_fitness = 0, 0.0 start_epoch, best_fitness = 0, 0.0
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples


# DP mode # DP mode
if cuda and rank == -1 and torch.cuda.device_count() > 1:
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model) model = torch.nn.DataParallel(model)


# SyncBatchNorm # SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
if opt.sync_bn and cuda and RANK != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
logger.info('Using SyncBatchNorm()') logger.info('Using SyncBatchNorm()')


# Trainloader # Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size, workers=opt.workers,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
workers=opt.workers,
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)


# Process 0 # Process 0
if rank in [-1, 0]:
if RANK in [-1, 0]:
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls, testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers,
workers=opt.workers,
pad=0.5, prefix=colorstr('val: '))[0] pad=0.5, prefix=colorstr('val: '))[0]


if not opt.resume: if not opt.resume:
model.half().float() # pre-reduce anchor precision model.half().float() # pre-reduce anchor precision


# DDP mode # DDP mode
if cuda and rank != -1:
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
# nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698 # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules())) find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))


# Update image weights (optional) # Update image weights (optional)
if opt.image_weights: if opt.image_weights:
# Generate indices # Generate indices
if rank in [-1, 0]:
if RANK in [-1, 0]:
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
# Broadcast if DDP # Broadcast if DDP
if rank != -1:
indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
if RANK != -1:
indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
dist.broadcast(indices, 0) dist.broadcast(indices, 0)
if rank != 0:
if RANK != 0:
dataset.indices = indices.cpu().numpy() dataset.indices = indices.cpu().numpy()


# Update mosaic border # Update mosaic border
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders # dataset.mosaic_border = [b - imgsz, -b] # height, width borders


mloss = torch.zeros(4, device=device) # mean losses mloss = torch.zeros(4, device=device) # mean losses
if rank != -1:
if RANK != -1:
dataloader.sampler.set_epoch(epoch) dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader) pbar = enumerate(dataloader)
logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size')) logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
if rank in [-1, 0]:
if RANK in [-1, 0]:
pbar = tqdm(pbar, total=nb) # progress bar pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad() optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
with amp.autocast(enabled=cuda): with amp.autocast(enabled=cuda):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if RANK != -1:
loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
if opt.quad: if opt.quad:
loss *= 4. loss *= 4.


ema.update(model) ema.update(model)


# Print # Print
if rank in [-1, 0]:
if RANK in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % ( s = ('%10s' * 2 + '%10.4g' * 6) % (
scheduler.step() scheduler.step()


# DDP process 0 or single-GPU # DDP process 0 or single-GPU
if rank in [-1, 0]:
if RANK in [-1, 0]:
# mAP # mAP
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights']) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs final_epoch = epoch + 1 == epochs


# end epoch ---------------------------------------------------------------------------------------------------- # end epoch ----------------------------------------------------------------------------------------------------
# end training ----------------------------------------------------------------------------------------------------- # end training -----------------------------------------------------------------------------------------------------
if rank in [-1, 0]:
if RANK in [-1, 0]:
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots: if plots:
plot_results(save_dir=save_dir) # save as results.png plot_results(save_dir=save_dir) # save as results.png
name='run_' + wandb_logger.wandb_run.id + '_model', name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped']) aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run() wandb_logger.finish_run()
else:
dist.destroy_process_group()

torch.cuda.empty_cache() torch.cuda.empty_cache()
return results return results


parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name') parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--entity', default=None, help='W&B entity') parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args() opt = parser.parse_args()

# Set DDP variables
opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
opt.global_rank = int(getattr(os.environ, 'RANK', -1))
return opt return opt




def main(opt): def main(opt):
print(opt)
set_logging(opt.global_rank)
if opt.global_rank in [-1, 0]:
set_logging(RANK)
if RANK in [-1, 0]:
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_git_status() check_git_status()
check_requirements(exclude=['thop']) check_requirements(exclude=['thop'])


if opt.resume and not wandb_run: # resume an interrupted run if opt.resume and not wandb_run: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
apriori = opt.global_rank, opt.local_rank
with open(Path(ckpt).parent.parent / 'opt.yaml') as f: with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.safe_load(f)) # replace opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
'', ckpt, True, opt.total_batch_size, *apriori # reinstate
opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate
logger.info('Resuming training from %s' % ckpt) logger.info('Resuming training from %s' % ckpt)
else: else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
# DDP mode # DDP mode
opt.total_batch_size = opt.batch_size opt.total_batch_size = opt.batch_size
device = select_device(opt.device, batch_size=opt.batch_size) device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
if LOCAL_RANK != -1:
from datetime import timedelta
assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
opt.batch_size = opt.total_batch_size // opt.world_size
opt.batch_size = opt.total_batch_size // WORLD_SIZE


# Train # Train
logger.info(opt)
if not opt.evolve: if not opt.evolve:
train(opt.hyp, opt, device) train(opt.hyp, opt, device)
if WORLD_SIZE > 1 and RANK == 0:
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]


# Evolve hyperparameters (optional) # Evolve hyperparameters (optional)
else: else:


with open(opt.hyp) as f: with open(opt.hyp) as f:
hyp = yaml.safe_load(f) # load hyps dict hyp = yaml.safe_load(f) # load hyps dict
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here

+ 2
- 2
utils/datasets.py View File





def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank): with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size, dataset = LoadImagesAndLabels(path, imgsz, batch_size,
prefix=prefix) prefix=prefix)


batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()

+ 3
- 2
utils/torch_utils.py View File



import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision import torchvision
Decorator to make all processes in distributed training wait for each local_master to do something. Decorator to make all processes in distributed training wait for each local_master to do something.
""" """
if local_rank not in [-1, 0]: if local_rank not in [-1, 0]:
torch.distributed.barrier()
dist.barrier()
yield yield
if local_rank == 0: if local_rank == 0:
torch.distributed.barrier()
dist.barrier()




def init_torch_seeds(seed=0): def init_torch_seeds(seed=0):

+ 4
- 2
utils/wandb_logging/wandb_utils.py View File

"""Utilities and tools for tracking runs with Weights & Biases.""" """Utilities and tools for tracking runs with Weights & Biases."""
import logging import logging
import os
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
except ImportError: except ImportError:
wandb = None wandb = None


RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'








def check_wandb_resume(opt): def check_wandb_resume(opt):
process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
if isinstance(opt.resume, str): if isinstance(opt.resume, str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
if opt.global_rank not in [-1, 0]: # For resuming DDP runs
if RANK not in [-1, 0]: # For resuming DDP runs
entity, project, run_id, model_artifact_name = get_run_info(opt.resume) entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
api = wandb.Api() api = wandb.Api()
artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest') artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')

Loading…
Cancel
Save