選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

74 行
2.4KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Callback utils
  4. """
  5. class Callbacks:
  6. """"
  7. Handles all registered callbacks for YOLOv5 Hooks
  8. """
  9. def __init__(self):
  10. # Define the available callbacks
  11. self._callbacks = {
  12. 'on_pretrain_routine_start': [],
  13. 'on_pretrain_routine_end': [],
  14. 'on_train_start': [],
  15. 'on_train_epoch_start': [],
  16. 'on_train_batch_start': [],
  17. 'optimizer_step': [],
  18. 'on_before_zero_grad': [],
  19. 'on_train_batch_end': [],
  20. 'on_train_epoch_end': [],
  21. 'on_val_start': [],
  22. 'on_val_batch_start': [],
  23. 'on_val_image_end': [],
  24. 'on_val_batch_end': [],
  25. 'on_val_end': [],
  26. 'on_fit_epoch_end': [], # fit = train + val
  27. 'on_model_save': [],
  28. 'on_train_end': [],
  29. 'on_params_update': [],
  30. 'teardown': [],}
  31. self.stop_training = False # set True to interrupt training
  32. def register_action(self, hook, name='', callback=None):
  33. """
  34. Register a new action to a callback hook
  35. Args:
  36. hook The callback hook name to register the action to
  37. name The name of the action for later reference
  38. callback The callback to fire
  39. """
  40. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  41. assert callable(callback), f"callback '{callback}' is not callable"
  42. self._callbacks[hook].append({'name': name, 'callback': callback})
  43. def get_registered_actions(self, hook=None):
  44. """"
  45. Returns all the registered actions by callback hook
  46. Args:
  47. hook The name of the hook to check, defaults to all
  48. """
  49. if hook:
  50. return self._callbacks[hook]
  51. else:
  52. return self._callbacks
  53. def run(self, hook, *args, **kwargs):
  54. """
  55. Loop through the registered actions and fire all callbacks
  56. Args:
  57. hook The name of the hook to check, defaults to all
  58. args Arguments to receive from YOLOv5
  59. kwargs Keyword Arguments to receive from YOLOv5
  60. """
  61. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  62. for logger in self._callbacks[hook]:
  63. logger['callback'](*args, **kwargs)