TensorRT转化代码
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

176 行
5.6KB

  #!/usr/bin/env python
class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks.

    Each instance owns its own hook registry, so actions registered on one
    ``Callbacks`` object are never fired by another.
    """

    def __init__(self):
        # Build the registry per instance. Defining this dict at class level
        # would make every Callbacks object share (and mutate) the same lists.
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'teardown': [],
        }

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to; must be
                one of the keys created in ``__init__``.
            name: The name of the action (stored alongside the callback).
            callback: The callable to fire when the hook runs.

        Raises:
            AssertionError: If ``hook`` is unknown or ``callback`` is not
                callable.
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Return the registered actions, for one hook or for all hooks.

        Args:
            hook: The name of the hook to check. When falsy (the default),
                the whole registry is returned.

        Returns:
            The list of ``{'name': ..., 'callback': ...}`` dicts registered to
            ``hook``, or the full ``{hook_name: [actions]}`` mapping. These are
            the live containers, not copies.

        Raises:
            KeyError: If ``hook`` is given but is not a known hook name.
        """
        if hook:
            return self._callbacks[hook]
        return self._callbacks

    def run_callbacks(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks.

        Callbacks for ``hook`` are called in registration order, each with the
        given positional and keyword arguments.

        Raises:
            KeyError: If ``hook`` is not a known hook name.
        """
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)

    def on_pretrain_routine_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of each pretraining routine."""
        self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)

    def on_pretrain_routine_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each pretraining routine."""
        self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)

    def on_train_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of each training."""
        self.run_callbacks('on_train_start', *args, **kwargs)

    def on_train_epoch_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of each training epoch."""
        self.run_callbacks('on_train_epoch_start', *args, **kwargs)

    def on_train_batch_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of each training batch."""
        self.run_callbacks('on_train_batch_start', *args, **kwargs)

    def optimizer_step(self, *args, **kwargs):
        """Fires all registered callbacks on each optimizer step."""
        self.run_callbacks('optimizer_step', *args, **kwargs)

    def on_before_zero_grad(self, *args, **kwargs):
        """Fires all registered callbacks before zero grad."""
        self.run_callbacks('on_before_zero_grad', *args, **kwargs)

    def on_train_batch_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each training batch."""
        self.run_callbacks('on_train_batch_end', *args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each training epoch."""
        self.run_callbacks('on_train_epoch_end', *args, **kwargs)

    def on_val_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of the validation."""
        self.run_callbacks('on_val_start', *args, **kwargs)

    def on_val_batch_start(self, *args, **kwargs):
        """Fires all registered callbacks at the start of each validation batch."""
        self.run_callbacks('on_val_batch_start', *args, **kwargs)

    def on_val_image_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each val image."""
        self.run_callbacks('on_val_image_end', *args, **kwargs)

    def on_val_batch_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each validation batch."""
        self.run_callbacks('on_val_batch_end', *args, **kwargs)

    def on_val_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of the validation."""
        self.run_callbacks('on_val_end', *args, **kwargs)

    def on_fit_epoch_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of each fit (train+val) epoch."""
        self.run_callbacks('on_fit_epoch_end', *args, **kwargs)

    def on_model_save(self, *args, **kwargs):
        """Fires all registered callbacks after each model save."""
        self.run_callbacks('on_model_save', *args, **kwargs)

    def on_train_end(self, *args, **kwargs):
        """Fires all registered callbacks at the end of training."""
        self.run_callbacks('on_train_end', *args, **kwargs)

    def teardown(self, *args, **kwargs):
        """Fires all registered callbacks before teardown."""
        self.run_callbacks('teardown', *args, **kwargs)