Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

callbacks.py 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Callback utils
  4. """
  5. class Callbacks:
  6. """"
  7. Handles all registered callbacks for YOLOv5 Hooks
  8. """
  9. _callbacks = {
  10. 'on_pretrain_routine_start': [],
  11. 'on_pretrain_routine_end': [],
  12. 'on_train_start': [],
  13. 'on_train_epoch_start': [],
  14. 'on_train_batch_start': [],
  15. 'optimizer_step': [],
  16. 'on_before_zero_grad': [],
  17. 'on_train_batch_end': [],
  18. 'on_train_epoch_end': [],
  19. 'on_val_start': [],
  20. 'on_val_batch_start': [],
  21. 'on_val_image_end': [],
  22. 'on_val_batch_end': [],
  23. 'on_val_end': [],
  24. 'on_fit_epoch_end': [], # fit = train + val
  25. 'on_model_save': [],
  26. 'on_train_end': [],
  27. 'teardown': [],
  28. }
  29. def __init__(self):
  30. return
  31. def register_action(self, hook, name='', callback=None):
  32. """
  33. Register a new action to a callback hook
  34. Args:
  35. hook The callback hook name to register the action to
  36. name The name of the action
  37. callback The callback to fire
  38. """
  39. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  40. assert callable(callback), f"callback '{callback}' is not callable"
  41. self._callbacks[hook].append({'name': name, 'callback': callback})
  42. def get_registered_actions(self, hook=None):
  43. """"
  44. Returns all the registered actions by callback hook
  45. Args:
  46. hook The name of the hook to check, defaults to all
  47. """
  48. if hook:
  49. return self._callbacks[hook]
  50. else:
  51. return self._callbacks
  52. def run_callbacks(self, hook, *args, **kwargs):
  53. """
  54. Loop through the registered actions and fire all callbacks
  55. """
  56. for logger in self._callbacks[hook]:
  57. # print(f"Running callbacks.{logger['callback'].__name__}()")
  58. logger['callback'](*args, **kwargs)
  59. def on_pretrain_routine_start(self, *args, **kwargs):
  60. """
  61. Fires all registered callbacks at the start of each pretraining routine
  62. """
  63. self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
  64. def on_pretrain_routine_end(self, *args, **kwargs):
  65. """
  66. Fires all registered callbacks at the end of each pretraining routine
  67. """
  68. self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
  69. def on_train_start(self, *args, **kwargs):
  70. """
  71. Fires all registered callbacks at the start of each training
  72. """
  73. self.run_callbacks('on_train_start', *args, **kwargs)
  74. def on_train_epoch_start(self, *args, **kwargs):
  75. """
  76. Fires all registered callbacks at the start of each training epoch
  77. """
  78. self.run_callbacks('on_train_epoch_start', *args, **kwargs)
  79. def on_train_batch_start(self, *args, **kwargs):
  80. """
  81. Fires all registered callbacks at the start of each training batch
  82. """
  83. self.run_callbacks('on_train_batch_start', *args, **kwargs)
  84. def optimizer_step(self, *args, **kwargs):
  85. """
  86. Fires all registered callbacks on each optimizer step
  87. """
  88. self.run_callbacks('optimizer_step', *args, **kwargs)
  89. def on_before_zero_grad(self, *args, **kwargs):
  90. """
  91. Fires all registered callbacks before zero grad
  92. """
  93. self.run_callbacks('on_before_zero_grad', *args, **kwargs)
  94. def on_train_batch_end(self, *args, **kwargs):
  95. """
  96. Fires all registered callbacks at the end of each training batch
  97. """
  98. self.run_callbacks('on_train_batch_end', *args, **kwargs)
  99. def on_train_epoch_end(self, *args, **kwargs):
  100. """
  101. Fires all registered callbacks at the end of each training epoch
  102. """
  103. self.run_callbacks('on_train_epoch_end', *args, **kwargs)
  104. def on_val_start(self, *args, **kwargs):
  105. """
  106. Fires all registered callbacks at the start of the validation
  107. """
  108. self.run_callbacks('on_val_start', *args, **kwargs)
  109. def on_val_batch_start(self, *args, **kwargs):
  110. """
  111. Fires all registered callbacks at the start of each validation batch
  112. """
  113. self.run_callbacks('on_val_batch_start', *args, **kwargs)
  114. def on_val_image_end(self, *args, **kwargs):
  115. """
  116. Fires all registered callbacks at the end of each val image
  117. """
  118. self.run_callbacks('on_val_image_end', *args, **kwargs)
  119. def on_val_batch_end(self, *args, **kwargs):
  120. """
  121. Fires all registered callbacks at the end of each validation batch
  122. """
  123. self.run_callbacks('on_val_batch_end', *args, **kwargs)
  124. def on_val_end(self, *args, **kwargs):
  125. """
  126. Fires all registered callbacks at the end of the validation
  127. """
  128. self.run_callbacks('on_val_end', *args, **kwargs)
  129. def on_fit_epoch_end(self, *args, **kwargs):
  130. """
  131. Fires all registered callbacks at the end of each fit (train+val) epoch
  132. """
  133. self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
  134. def on_model_save(self, *args, **kwargs):
  135. """
  136. Fires all registered callbacks after each model save
  137. """
  138. self.run_callbacks('on_model_save', *args, **kwargs)
  139. def on_train_end(self, *args, **kwargs):
  140. """
  141. Fires all registered callbacks at the end of training
  142. """
  143. self.run_callbacks('on_train_end', *args, **kwargs)
  144. def teardown(self, *args, **kwargs):
  145. """
  146. Fires all registered callbacks before teardown
  147. """
  148. self.run_callbacks('teardown', *args, **kwargs)