From 045d5d86299a4a724fca40faaf0225ded91a68b4 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 17 Jun 2021 22:12:42 +0200
Subject: [PATCH] Update TensorBoard (#3669)

---
 train.py | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/train.py b/train.py
index 113d084..9d71e70 100644
--- a/train.py
+++ b/train.py
@@ -42,7 +42,6 @@ logger = logging.getLogger(__name__)
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
-          tb_writer=None
           ):
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -74,9 +73,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict

-    # Logging- Doing this before checking the dataset. Might update data_dict
-    loggers = {'wandb': None}  # loggers dict
+    # Loggers
+    loggers = {'wandb': None, 'tb': None}  # loggers dict
     if rank in [-1, 0]:
+        # TensorBoard
+        if not opt.evolve:
+            prefix = colorstr('tensorboard: ')
+            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
+            loggers['tb'] = SummaryWriter(opt.save_dir)
+
+        # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
@@ -219,8 +225,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # model._initialize_biases(cf.to(device))
             if plots:
                 plot_labels(labels, names, save_dir, loggers)
-                if tb_writer:
-                    tb_writer.add_histogram('classes', c, 0)
+                if loggers['tb']:
+                    loggers['tb'].add_histogram('classes', c, 0)  # TensorBoard

         # Anchors
         if not opt.noautoanchor:
@@ -341,10 +347,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if plots and ni < 3:
                 f = save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                if tb_writer and ni == 0:
+                if loggers['tb'] and ni == 0:  # TensorBoard
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
-                        tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+                        loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
             elif plots and ni == 10 and wandb_logger.wandb:
                 wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                               save_dir.glob('train*.jpg') if x.exists()]})
@@ -352,7 +358,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
-        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
+        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

         # DDP process 0 or single-GPU
@@ -385,8 +391,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if tb_writer:
-                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
+                if loggers['tb']:
+                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B

@@ -537,12 +543,7 @@ if __name__ == '__main__':
     # Train
     logger.info(opt)
     if not opt.evolve:
-        tb_writer = None  # init loggers
-        if opt.global_rank in [-1, 0]:
-            prefix = colorstr('tensorboard: ')
-            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(opt.hyp, opt, device, tb_writer)
+        train(opt.hyp, opt, device)

     # Evolve hyperparameters (optional)
     else:
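
For reference, below is a minimal, self-contained sketch of the loggers-dict pattern this patch moves to: the TensorBoard SummaryWriter is created inside train() alongside the W&B handle, and every logging call is guarded by `if loggers['tb']` so TensorBoard stays optional. The log directory, rank values, loop, and metric values here are hypothetical stand-ins for illustration, not code from train.py.

    # Sketch only: 'runs/exp', rank/evolve, and the dummy loss are assumptions.
    import torch
    from torch.utils.tensorboard import SummaryWriter

    loggers = {'wandb': None, 'tb': None}  # loggers dict, as in the patch
    rank, evolve = -1, False  # single-GPU run, not evolving hyperparameters

    if rank in [-1, 0] and not evolve:  # main process only, skipped when evolving
        loggers['tb'] = SummaryWriter('runs/exp')  # hypothetical save_dir

    for epoch in range(3):
        loss = 1.0 / (epoch + 1)  # dummy scalar standing in for mloss/results
        if loggers['tb']:  # guard keeps training unchanged when TB is disabled
            loggers['tb'].add_scalar('train/box_loss', loss, epoch)
            loggers['tb'].add_histogram('classes', torch.randint(0, 80, (100,)), epoch)

    if loggers['tb']:
        loggers['tb'].close()  # flush event files for 'tensorboard --logdir runs'

Keeping the writer in the same dict as the W&B handle is the design point of the patch: train() no longer needs an extra tb_writer argument, and each logging backend is enabled or skipped with the same guard.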