您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

150 行
6.4KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Logging utils
  4. """
  5. import warnings
  6. from threading import Thread
  7. import torch
  8. from torch.utils.tensorboard import SummaryWriter
  9. from utils.general import colorstr, emojis
  10. from utils.loggers.wandb.wandb_utils import WandbLogger
  11. from utils.plots import plot_images, plot_results
  12. from utils.torch_utils import de_parallel
  13. LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases
  14. try:
  15. import wandb
  16. assert hasattr(wandb, '__version__') # verify package import not local dir
  17. except (ImportError, AssertionError):
  18. wandb = None
  19. class Loggers():
  20. # YOLOv5 Loggers class
  21. def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
  22. self.save_dir = save_dir
  23. self.weights = weights
  24. self.opt = opt
  25. self.hyp = hyp
  26. self.logger = logger # for printing results to console
  27. self.include = include
  28. self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  29. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
  30. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  31. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  32. for k in LOGGERS:
  33. setattr(self, k, None) # init empty logger dictionary
  34. self.csv = True # always log to csv
  35. # Message
  36. if not wandb:
  37. prefix = colorstr('Weights & Biases: ')
  38. s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
  39. print(emojis(s))
  40. # TensorBoard
  41. s = self.save_dir
  42. if 'tb' in self.include and not self.opt.evolve:
  43. prefix = colorstr('TensorBoard: ')
  44. self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
  45. self.tb = SummaryWriter(str(s))
  46. # W&B
  47. if wandb and 'wandb' in self.include:
  48. wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
  49. run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
  50. self.opt.hyp = self.hyp # add hyperparameters
  51. self.wandb = WandbLogger(self.opt, run_id)
  52. else:
  53. self.wandb = None
  54. def on_pretrain_routine_end(self):
  55. # Callback runs on pre-train routine end
  56. paths = self.save_dir.glob('*labels*.jpg') # training labels
  57. if self.wandb:
  58. self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
  59. def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn):
  60. # Callback runs on train batch end
  61. if plots:
  62. if ni == 0:
  63. if not sync_bn: # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754
  64. with warnings.catch_warnings():
  65. warnings.simplefilter('ignore') # suppress jit trace warning
  66. self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
  67. if ni < 3:
  68. f = self.save_dir / f'train_batch{ni}.jpg' # filename
  69. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  70. if self.wandb and ni == 10:
  71. files = sorted(self.save_dir.glob('train*.jpg'))
  72. self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
  73. def on_train_epoch_end(self, epoch):
  74. # Callback runs on train epoch end
  75. if self.wandb:
  76. self.wandb.current_epoch = epoch + 1
  77. def on_val_image_end(self, pred, predn, path, names, im):
  78. # Callback runs on val image end
  79. if self.wandb:
  80. self.wandb.val_one_image(pred, predn, path, names, im)
  81. def on_val_end(self):
  82. # Callback runs on val end
  83. if self.wandb:
  84. files = sorted(self.save_dir.glob('val*.jpg'))
  85. self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
  86. def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
  87. # Callback runs at the end of each fit (train+val) epoch
  88. x = {k: v for k, v in zip(self.keys, vals)} # dict
  89. if self.csv:
  90. file = self.save_dir / 'results.csv'
  91. n = len(x) + 1 # number of cols
  92. s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
  93. with open(file, 'a') as f:
  94. f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
  95. if self.tb:
  96. for k, v in x.items():
  97. self.tb.add_scalar(k, v, epoch)
  98. if self.wandb:
  99. self.wandb.log(x)
  100. self.wandb.end_epoch(best_result=best_fitness == fi)
  101. def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
  102. # Callback runs on model save event
  103. if self.wandb:
  104. if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
  105. self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
  106. def on_train_end(self, last, best, plots, epoch):
  107. # Callback runs on training end
  108. if plots:
  109. plot_results(file=self.save_dir / 'results.csv') # save results.png
  110. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  111. files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
  112. if self.tb:
  113. import cv2
  114. for f in files:
  115. self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
  116. if self.wandb:
  117. self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
  118. # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
  119. if not self.opt.evolve:
  120. wandb.log_artifact(str(best if best.exists() else last), type='model',
  121. name='run_' + self.wandb.wandb_run.id + '_model',
  122. aliases=['latest', 'best', 'stripped'])
  123. self.wandb.finish_run()
  124. else:
  125. self.wandb.finish_run()
  126. self.wandb = WandbLogger(self.opt)