TensorRT conversion code

The listing below is the YOLOv5 experiment logging utility (CSV, TensorBoard and Weights & Biases loggers).


# YOLOv5 experiment logging utils

import warnings
from threading import Thread

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.plots import plot_images, plot_results
from utils.torch_utils import de_parallel

LOGGERS = ('csv', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
except (ImportError, AssertionError):
    wandb = None


class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
        self.weights = weights
        self.opt = opt
        self.hyp = hyp
        self.logger = logger  # for printing results to console
        self.include = include
        self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                     'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',  # metrics
                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary
        self.csv = True  # always log to csv

        # Message
        if not wandb:
            prefix = colorstr('Weights & Biases: ')
            s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
            print(emojis(s))

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include and not self.opt.evolve:
            prefix = colorstr('TensorBoard: ')
            self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))

        # W&B
        if wandb and 'wandb' in self.include:
            wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
            run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
            self.opt.hyp = self.hyp  # add hyperparameters
            self.wandb = WandbLogger(self.opt, run_id)
        else:
            self.wandb = None

    def on_pretrain_routine_end(self):
        # Callback runs on pre-train routine end
        paths = self.save_dir.glob('*labels*.jpg')  # training labels
        if self.wandb:
            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})

    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
        # Callback runs on train batch end
        if plots:
            if self.tb and ni == 0:  # guard: self.tb is None when TensorBoard is disabled or --evolve is set
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')  # suppress jit trace warning
                    self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
            if ni < 3:
                f = self.save_dir / f'train_batch{ni}.jpg'  # filename
                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
            if self.wandb and ni == 10:
                files = sorted(self.save_dir.glob('train*.jpg'))
                self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

    def on_train_epoch_end(self, epoch):
        # Callback runs on train epoch end
        if self.wandb:
            self.wandb.current_epoch = epoch + 1

    def on_val_image_end(self, pred, predn, path, names, im):
        # Callback runs on val image end
        if self.wandb:
            self.wandb.val_one_image(pred, predn, path, names, im)

    def on_val_end(self):
        # Callback runs on val end
        if self.wandb:
            files = sorted(self.save_dir.glob('val*.jpg'))
            self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

    def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
        # Callback runs at the end of each fit (train+val) epoch
        x = {k: v for k, v in zip(self.keys, vals)}  # dict
        if self.csv:
            file = self.save_dir / 'results.csv'
            n = len(x) + 1  # number of cols
            s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n')  # add header
            with open(file, 'a') as f:
                f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
        if self.tb:
            for k, v in x.items():
                self.tb.add_scalar(k, v, epoch)
        if self.wandb:
            self.wandb.log(x)
            self.wandb.end_epoch(best_result=best_fitness == fi)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        # Callback runs on model save event
        if self.wandb:
            if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
                self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

    def on_train_end(self, last, best, plots, epoch):
        # Callback runs on training end
        if plots:
            plot_results(file=self.save_dir / 'results.csv')  # save results.png
        files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
        files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
        if self.tb:
            from PIL import Image
            import numpy as np
            for f in files:
                self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
        if self.wandb:
            self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
            # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
            wandb.log_artifact(str(best if best.exists() else last), type='model',
                               name='run_' + self.wandb.wandb_run.id + '_model',
                               aliases=['latest', 'best', 'stripped'])
            self.wandb.finish_run()
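
For orientation, below is a minimal sketch of how this Loggers class is typically constructed and driven from a training loop. The run directory, option namespace and weights path are illustrative assumptions, and the import assumes the file above lives at utils/loggers/__init__.py as in the YOLOv5 repository layout; none of these values are taken from the listing itself.

# Minimal usage sketch -- names and values below are illustrative assumptions
import logging
from pathlib import Path
from types import SimpleNamespace

from utils.loggers import Loggers  # assumes the file above is utils/loggers/__init__.py

save_dir = Path('runs/train/exp')  # hypothetical run directory
save_dir.mkdir(parents=True, exist_ok=True)
opt = SimpleNamespace(evolve=False, resume=False, save_period=-1, hyp={})  # hypothetical options
loggers = Loggers(save_dir=save_dir,
                  weights='yolov5s.pt',                # hypothetical checkpoint path
                  opt=opt,
                  hyp=opt.hyp,
                  logger=logging.getLogger(__name__),
                  include=('csv', 'tb'))               # skip 'wandb' in this sketch

# Inside the training loop, the caller invokes the matching callbacks at each event, e.g.:
#   loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots=True)
#   loggers.on_fit_epoch_end(vals, epoch, best_fitness, fi)
#   loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
#   loggers.on_train_end(last, best, plots=True, epoch=epoch)

The callback design keeps the training loop free of backend-specific calls: each backend (CSV, TensorBoard, Weights & Biases) is initialized once in __init__ and only receives events if it was requested via the include argument.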