From c923fbff90f5af100eed55b34aff2481ffe658f0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 16 Dec 2020 17:52:12 -0800 Subject: [PATCH] W&B artifacts feature addition (#1712) * Log artifacts * cleanup --- train.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 374f26c..cb6e21a 100644 --- a/train.py +++ b/train.py @@ -386,10 +386,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if rank in [-1, 0]: # Strip optimizers + final = best if best.exists() else last # final model for f in [last, best]: - if f.exists(): # is *.pt - strip_optimizer(f) # strip optimizer - os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None # upload + if f.exists(): + strip_optimizer(f) # strip optimizers + if opt.bucket: + os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload # Plots if plots: @@ -398,9 +400,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png'] wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files if (save_dir / f).exists()]}) - logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) + if opt.log_artifacts: + wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem) # Test best.pt + logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) if opt.data.endswith('coco.yaml') and nc == 80: # if COCO for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests results, _, _ = test.test(opt.data, @@ -408,7 +412,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): imgsz=imgsz_test, conf_thres=conf, iou_thres=iou, - model=attempt_load(best if best.exists() else last, device).half(), + model=attempt_load(final, device).half(), single_cls=opt.single_cls, dataloader=testloader, save_dir=save_dir, @@ -448,6 +452,7 @@ if __name__ == '__main__': parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100') + parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') parser.add_argument('--project', default='runs/train', help='save to project/name') parser.add_argument('--name', default='exp', help='save to project/name')