No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

callbacks.py 2.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Callback utils
"""
  5. class Callbacks:
  6. """"
  7. Handles all registered callbacks for YOLOv5 Hooks
  8. """
  9. def __init__(self):
  10. # Define the available callbacks
  11. self._callbacks = {
  12. 'on_pretrain_routine_start': [],
  13. 'on_pretrain_routine_end': [],
  14. 'on_train_start': [],
  15. 'on_train_epoch_start': [],
  16. 'on_train_batch_start': [],
  17. 'optimizer_step': [],
  18. 'on_before_zero_grad': [],
  19. 'on_train_batch_end': [],
  20. 'on_train_epoch_end': [],
  21. 'on_val_start': [],
  22. 'on_val_batch_start': [],
  23. 'on_val_image_end': [],
  24. 'on_val_batch_end': [],
  25. 'on_val_end': [],
  26. 'on_fit_epoch_end': [], # fit = train + val
  27. 'on_model_save': [],
  28. 'on_train_end': [],
  29. 'on_params_update': [],
  30. 'teardown': [],
  31. }
  32. self.stop_training = False # set True to interrupt training
  33. def register_action(self, hook, name='', callback=None):
  34. """
  35. Register a new action to a callback hook
  36. Args:
  37. hook The callback hook name to register the action to
  38. name The name of the action for later reference
  39. callback The callback to fire
  40. """
  41. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  42. assert callable(callback), f"callback '{callback}' is not callable"
  43. self._callbacks[hook].append({'name': name, 'callback': callback})
  44. def get_registered_actions(self, hook=None):
  45. """"
  46. Returns all the registered actions by callback hook
  47. Args:
  48. hook The name of the hook to check, defaults to all
  49. """
  50. if hook:
  51. return self._callbacks[hook]
  52. else:
  53. return self._callbacks
  54. def run(self, hook, *args, **kwargs):
  55. """
  56. Loop through the registered actions and fire all callbacks
  57. Args:
  58. hook The name of the hook to check, defaults to all
  59. args Arguments to receive from YOLOv5
  60. kwargs Keyword Arguments to receive from YOLOv5
  61. """
  62. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  63. for logger in self._callbacks[hook]:
  64. logger['callback'](*args, **kwargs)