Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

145 lines
6.1KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Logging utils
  4. """
  5. import warnings
  6. from threading import Thread
  7. import torch
  8. from torch.utils.tensorboard import SummaryWriter
  9. from utils.general import colorstr, emojis
  10. from utils.loggers.wandb.wandb_utils import WandbLogger
  11. from utils.plots import plot_images, plot_results
  12. from utils.torch_utils import de_parallel
  13. LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases
  14. try:
  15. import wandb
  16. assert hasattr(wandb, '__version__') # verify package import not local dir
  17. except (ImportError, AssertionError):
  18. wandb = None
  19. class Loggers():
  20. # YOLOv5 Loggers class
  21. def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
  22. self.save_dir = save_dir
  23. self.weights = weights
  24. self.opt = opt
  25. self.hyp = hyp
  26. self.logger = logger # for printing results to console
  27. self.include = include
  28. self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  29. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
  30. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  31. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  32. for k in LOGGERS:
  33. setattr(self, k, None) # init empty logger dictionary
  34. self.csv = True # always log to csv
  35. # Message
  36. if not wandb:
  37. prefix = colorstr('Weights & Biases: ')
  38. s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
  39. print(emojis(s))
  40. # TensorBoard
  41. s = self.save_dir
  42. if 'tb' in self.include and not self.opt.evolve:
  43. prefix = colorstr('TensorBoard: ')
  44. self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
  45. self.tb = SummaryWriter(str(s))
  46. # W&B
  47. if wandb and 'wandb' in self.include:
  48. wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
  49. run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
  50. self.opt.hyp = self.hyp # add hyperparameters
  51. self.wandb = WandbLogger(self.opt, run_id)
  52. else:
  53. self.wandb = None
  54. def on_pretrain_routine_end(self):
  55. # Callback runs on pre-train routine end
  56. paths = self.save_dir.glob('*labels*.jpg') # training labels
  57. if self.wandb:
  58. self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
  59. def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
  60. # Callback runs on train batch end
  61. if plots:
  62. if ni == 0:
  63. with warnings.catch_warnings():
  64. warnings.simplefilter('ignore') # suppress jit trace warning
  65. self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
  66. if ni < 3:
  67. f = self.save_dir / f'train_batch{ni}.jpg' # filename
  68. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  69. if self.wandb and ni == 10:
  70. files = sorted(self.save_dir.glob('train*.jpg'))
  71. self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
  72. def on_train_epoch_end(self, epoch):
  73. # Callback runs on train epoch end
  74. if self.wandb:
  75. self.wandb.current_epoch = epoch + 1
  76. def on_val_image_end(self, pred, predn, path, names, im):
  77. # Callback runs on val image end
  78. if self.wandb:
  79. self.wandb.val_one_image(pred, predn, path, names, im)
  80. def on_val_end(self):
  81. # Callback runs on val end
  82. if self.wandb:
  83. files = sorted(self.save_dir.glob('val*.jpg'))
  84. self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
  85. def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
  86. # Callback runs at the end of each fit (train+val) epoch
  87. x = {k: v for k, v in zip(self.keys, vals)} # dict
  88. if self.csv:
  89. file = self.save_dir / 'results.csv'
  90. n = len(x) + 1 # number of cols
  91. s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
  92. with open(file, 'a') as f:
  93. f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
  94. if self.tb:
  95. for k, v in x.items():
  96. self.tb.add_scalar(k, v, epoch)
  97. if self.wandb:
  98. self.wandb.log(x)
  99. self.wandb.end_epoch(best_result=best_fitness == fi)
  100. def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
  101. # Callback runs on model save event
  102. if self.wandb:
  103. if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
  104. self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
  105. def on_train_end(self, last, best, plots, epoch):
  106. # Callback runs on training end
  107. if plots:
  108. plot_results(file=self.save_dir / 'results.csv') # save results.png
  109. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  110. files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
  111. if self.tb:
  112. import cv2
  113. for f in files:
  114. self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
  115. if self.wandb:
  116. self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
  117. # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
  118. wandb.log_artifact(str(best if best.exists() else last), type='model',
  119. name='run_' + self.wandb.wandb_run.id + '_model',
  120. aliases=['latest', 'best', 'stripped'])
  121. self.wandb.finish_run()