Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

188 lines
8.0KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Logging utils
  4. """
  5. import os
  6. import warnings
  7. from threading import Thread
  8. import pkg_resources as pkg
  9. import torch
  10. from torch.utils.tensorboard import SummaryWriter
  11. from utils.general import colorstr, cv2, emojis
  12. from utils.loggers.wandb.wandb_utils import WandbLogger
  13. from utils.plots import plot_images, plot_results
  14. from utils.torch_utils import de_parallel
  15. LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases
  16. RANK = int(os.getenv('RANK', -1))
  17. try:
  18. import wandb
  19. assert hasattr(wandb, '__version__') # verify package import not local dir
  20. if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in [0, -1]:
  21. try:
  22. wandb_login_success = wandb.login(timeout=30)
  23. except wandb.errors.UsageError: # known non-TTY terminal issue
  24. wandb_login_success = False
  25. if not wandb_login_success:
  26. wandb = None
  27. except (ImportError, AssertionError):
  28. wandb = None
  29. class Loggers():
  30. # YOLOv5 Loggers class
  31. def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
  32. self.save_dir = save_dir
  33. self.weights = weights
  34. self.opt = opt
  35. self.hyp = hyp
  36. self.logger = logger # for printing results to console
  37. self.include = include
  38. self.keys = [
  39. 'train/box_loss',
  40. 'train/obj_loss',
  41. 'train/cls_loss', # train loss
  42. 'metrics/precision',
  43. 'metrics/recall',
  44. 'metrics/mAP_0.5',
  45. 'metrics/mAP_0.5:0.95', # metrics
  46. 'val/box_loss',
  47. 'val/obj_loss',
  48. 'val/cls_loss', # val loss
  49. 'x/lr0',
  50. 'x/lr1',
  51. 'x/lr2'] # params
  52. self.best_keys = ['best/epoch', 'best/precision', 'best/recall', 'best/mAP_0.5', 'best/mAP_0.5:0.95']
  53. for k in LOGGERS:
  54. setattr(self, k, None) # init empty logger dictionary
  55. self.csv = True # always log to csv
  56. # Message
  57. if not wandb:
  58. prefix = colorstr('Weights & Biases: ')
  59. s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
  60. self.logger.info(emojis(s))
  61. # TensorBoard
  62. s = self.save_dir
  63. if 'tb' in self.include and not self.opt.evolve:
  64. prefix = colorstr('TensorBoard: ')
  65. self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
  66. self.tb = SummaryWriter(str(s))
  67. # W&B
  68. if wandb and 'wandb' in self.include:
  69. wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
  70. run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
  71. self.opt.hyp = self.hyp # add hyperparameters
  72. self.wandb = WandbLogger(self.opt, run_id)
  73. # temp warn. because nested artifacts not supported after 0.12.10
  74. if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'):
  75. self.logger.warning(
  76. "YOLOv5 temporarily requires wandb version 0.12.10 or below. Some features may not work as expected."
  77. )
  78. else:
  79. self.wandb = None
  80. def on_train_start(self):
  81. # Callback runs on train start
  82. pass
  83. def on_pretrain_routine_end(self):
  84. # Callback runs on pre-train routine end
  85. paths = self.save_dir.glob('*labels*.jpg') # training labels
  86. if self.wandb:
  87. self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
  88. def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn):
  89. # Callback runs on train batch end
  90. if plots:
  91. if ni == 0:
  92. if not sync_bn: # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754
  93. with warnings.catch_warnings():
  94. warnings.simplefilter('ignore') # suppress jit trace warning
  95. self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
  96. if ni < 3:
  97. f = self.save_dir / f'train_batch{ni}.jpg' # filename
  98. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  99. if self.wandb and ni == 10:
  100. files = sorted(self.save_dir.glob('train*.jpg'))
  101. self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
  102. def on_train_epoch_end(self, epoch):
  103. # Callback runs on train epoch end
  104. if self.wandb:
  105. self.wandb.current_epoch = epoch + 1
  106. def on_val_image_end(self, pred, predn, path, names, im):
  107. # Callback runs on val image end
  108. if self.wandb:
  109. self.wandb.val_one_image(pred, predn, path, names, im)
  110. def on_val_end(self):
  111. # Callback runs on val end
  112. if self.wandb:
  113. files = sorted(self.save_dir.glob('val*.jpg'))
  114. self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
  115. def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
  116. # Callback runs at the end of each fit (train+val) epoch
  117. x = {k: v for k, v in zip(self.keys, vals)} # dict
  118. if self.csv:
  119. file = self.save_dir / 'results.csv'
  120. n = len(x) + 1 # number of cols
  121. s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
  122. with open(file, 'a') as f:
  123. f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
  124. if self.tb:
  125. for k, v in x.items():
  126. self.tb.add_scalar(k, v, epoch)
  127. if self.wandb:
  128. if best_fitness == fi:
  129. best_results = [epoch] + vals[3:7]
  130. for i, name in enumerate(self.best_keys):
  131. self.wandb.wandb_run.summary[name] = best_results[i] # log best results in the summary
  132. self.wandb.log(x)
  133. self.wandb.end_epoch(best_result=best_fitness == fi)
  134. def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
  135. # Callback runs on model save event
  136. if self.wandb:
  137. if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
  138. self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
  139. def on_train_end(self, last, best, plots, epoch, results):
  140. # Callback runs on training end
  141. if plots:
  142. plot_results(file=self.save_dir / 'results.csv') # save results.png
  143. files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
  144. files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
  145. if self.tb:
  146. for f in files:
  147. self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
  148. if self.wandb:
  149. self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results
  150. self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
  151. # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
  152. if not self.opt.evolve:
  153. wandb.log_artifact(str(best if best.exists() else last),
  154. type='model',
  155. name='run_' + self.wandb.wandb_run.id + '_model',
  156. aliases=['latest', 'best', 'stripped'])
  157. self.wandb.finish_run()
  158. def on_params_update(self, params):
  159. # Update hyperparams or configs of the experiment
  160. # params: A dict containing {param: value} pairs
  161. if self.wandb:
  162. self.wandb.wandb_run.config.update(params, allow_val_change=True)