Browse Source

Simplify callbacks (#4289)

modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
4103ce9ad0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 21 deletions
  1. +20
    -21
      utils/callbacks.py

+ 20
- 21
utils/callbacks.py View File

else: else:
return self._callbacks return self._callbacks


@staticmethod
def run_callbacks(register, *args, **kwargs):
def run_callbacks(self, hook, *args, **kwargs):
""" """
Loop through the registered actions and fire all callbacks Loop through the registered actions and fire all callbacks
""" """
for logger in register:
for logger in self._callbacks[hook]:
# print(f"Running callbacks.{logger['callback'].__name__}()") # print(f"Running callbacks.{logger['callback'].__name__}()")
logger['callback'](*args, **kwargs) logger['callback'](*args, **kwargs)


""" """
Fires all registered callbacks at the start of each pretraining routine Fires all registered callbacks at the start of each pretraining routine
""" """
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)


def on_pretrain_routine_end(self, *args, **kwargs): def on_pretrain_routine_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each pretraining routine Fires all registered callbacks at the end of each pretraining routine
""" """
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)


def on_train_start(self, *args, **kwargs): def on_train_start(self, *args, **kwargs):
""" """
Fires all registered callbacks at the start of each training Fires all registered callbacks at the start of each training
""" """
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
self.run_callbacks('on_train_start', *args, **kwargs)


def on_train_epoch_start(self, *args, **kwargs): def on_train_epoch_start(self, *args, **kwargs):
""" """
Fires all registered callbacks at the start of each training epoch Fires all registered callbacks at the start of each training epoch
""" """
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
self.run_callbacks('on_train_epoch_start', *args, **kwargs)


def on_train_batch_start(self, *args, **kwargs): def on_train_batch_start(self, *args, **kwargs):
""" """
Fires all registered callbacks at the start of each training batch Fires all registered callbacks at the start of each training batch
""" """
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
self.run_callbacks('on_train_batch_start', *args, **kwargs)


def optimizer_step(self, *args, **kwargs): def optimizer_step(self, *args, **kwargs):
""" """
Fires all registered callbacks on each optimizer step Fires all registered callbacks on each optimizer step
""" """
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
self.run_callbacks('optimizer_step', *args, **kwargs)


def on_before_zero_grad(self, *args, **kwargs): def on_before_zero_grad(self, *args, **kwargs):
""" """
Fires all registered callbacks before zero grad Fires all registered callbacks before zero grad
""" """
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
self.run_callbacks('on_before_zero_grad', *args, **kwargs)


def on_train_batch_end(self, *args, **kwargs): def on_train_batch_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each training batch Fires all registered callbacks at the end of each training batch
""" """
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
self.run_callbacks('on_train_batch_end', *args, **kwargs)


def on_train_epoch_end(self, *args, **kwargs): def on_train_epoch_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each training epoch Fires all registered callbacks at the end of each training epoch
""" """
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
self.run_callbacks('on_train_epoch_end', *args, **kwargs)


def on_val_start(self, *args, **kwargs): def on_val_start(self, *args, **kwargs):
""" """
Fires all registered callbacks at the start of the validation Fires all registered callbacks at the start of the validation
""" """
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
self.run_callbacks('on_val_start', *args, **kwargs)


def on_val_batch_start(self, *args, **kwargs): def on_val_batch_start(self, *args, **kwargs):
""" """
Fires all registered callbacks at the start of each validation batch Fires all registered callbacks at the start of each validation batch
""" """
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
self.run_callbacks('on_val_batch_start', *args, **kwargs)


def on_val_image_end(self, *args, **kwargs): def on_val_image_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each val image Fires all registered callbacks at the end of each val image
""" """
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
self.run_callbacks('on_val_image_end', *args, **kwargs)


def on_val_batch_end(self, *args, **kwargs): def on_val_batch_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each validation batch Fires all registered callbacks at the end of each validation batch
""" """
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
self.run_callbacks('on_val_batch_end', *args, **kwargs)


def on_val_end(self, *args, **kwargs): def on_val_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of the validation Fires all registered callbacks at the end of the validation
""" """
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
self.run_callbacks('on_val_end', *args, **kwargs)


def on_fit_epoch_end(self, *args, **kwargs): def on_fit_epoch_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of each fit (train+val) epoch Fires all registered callbacks at the end of each fit (train+val) epoch
""" """
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)


def on_model_save(self, *args, **kwargs): def on_model_save(self, *args, **kwargs):
""" """
Fires all registered callbacks after each model save Fires all registered callbacks after each model save
""" """
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
self.run_callbacks('on_model_save', *args, **kwargs)


def on_train_end(self, *args, **kwargs): def on_train_end(self, *args, **kwargs):
""" """
Fires all registered callbacks at the end of training Fires all registered callbacks at the end of training
""" """
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
self.run_callbacks('on_train_end', *args, **kwargs)


def teardown(self, *args, **kwargs): def teardown(self, *args, **kwargs):
""" """
Fires all registered callbacks before teardown Fires all registered callbacks before teardown
""" """
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
self.run_callbacks('teardown', *args, **kwargs)

Loading…
Cancel
Save