Optimised Callback Class to Reduce Code and Fix Errors (#4688)
* added callbacks * added back callback to main * added save_dir to callback output * reduced code count * updated callbacks * added default callback class to main, added missing parameters to on_model_save * Glenn updates Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
548745181a
commit
2317f86ca4
20
train.py
20
train.py
|
|
@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||||
def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
opt,
|
opt,
|
||||||
device,
|
device,
|
||||||
callbacks=Callbacks()
|
callbacks
|
||||||
):
|
):
|
||||||
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
|
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
|
||||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
||||||
|
|
@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||||
model.half().float() # pre-reduce anchor precision
|
model.half().float() # pre-reduce anchor precision
|
||||||
|
|
||||||
callbacks.on_pretrain_routine_end()
|
callbacks.run('on_pretrain_routine_end')
|
||||||
|
|
||||||
# DDP mode
|
# DDP mode
|
||||||
if cuda and RANK != -1:
|
if cuda and RANK != -1:
|
||||||
|
|
@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||||
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
|
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
|
||||||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
||||||
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn)
|
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
|
|
@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
|
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
# mAP
|
# mAP
|
||||||
callbacks.on_train_epoch_end(epoch=epoch)
|
callbacks.run('on_train_epoch_end', epoch=epoch)
|
||||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
||||||
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
|
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
|
||||||
if not noval or final_epoch: # Calculate mAP
|
if not noval or final_epoch: # Calculate mAP
|
||||||
|
|
@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
if fi > best_fitness:
|
if fi > best_fitness:
|
||||||
best_fitness = fi
|
best_fitness = fi
|
||||||
log_vals = list(mloss) + list(results) + lr
|
log_vals = list(mloss) + list(results) + lr
|
||||||
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)
|
callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
if (not nosave) or (final_epoch and not evolve): # if save
|
if (not nosave) or (final_epoch and not evolve): # if save
|
||||||
|
|
@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
if best_fitness == fi:
|
if best_fitness == fi:
|
||||||
torch.save(ckpt, best)
|
torch.save(ckpt, best)
|
||||||
del ckpt
|
del ckpt
|
||||||
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
|
callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
|
||||||
|
|
||||||
# Stop Single-GPU
|
# Stop Single-GPU
|
||||||
if RANK == -1 and stopper(epoch=epoch, fitness=fi):
|
if RANK == -1 and stopper(epoch=epoch, fitness=fi):
|
||||||
|
|
@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
for f in last, best:
|
for f in last, best:
|
||||||
if f.exists():
|
if f.exists():
|
||||||
strip_optimizer(f) # strip optimizers
|
strip_optimizer(f) # strip optimizers
|
||||||
callbacks.on_train_end(last, best, plots, epoch)
|
callbacks.run('on_train_end', last, best, plots, epoch)
|
||||||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
@ -467,7 +467,7 @@ def parse_opt(known=False):
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt, callbacks=Callbacks()):
|
||||||
# Checks
|
# Checks
|
||||||
set_logging(RANK)
|
set_logging(RANK)
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
|
|
@ -505,7 +505,7 @@ def main(opt):
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
if not opt.evolve:
|
if not opt.evolve:
|
||||||
train(opt.hyp, opt, device)
|
train(opt.hyp, opt, device, callbacks)
|
||||||
if WORLD_SIZE > 1 and RANK == 0:
|
if WORLD_SIZE > 1 and RANK == 0:
|
||||||
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
|
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
|
||||||
|
|
||||||
|
|
@ -585,7 +585,7 @@ def main(opt):
|
||||||
hyp[k] = round(hyp[k], 5) # significant digits
|
hyp[k] = round(hyp[k], 5) # significant digits
|
||||||
|
|
||||||
# Train mutation
|
# Train mutation
|
||||||
results = train(hyp.copy(), opt, device)
|
results = train(hyp.copy(), opt, device, callbacks)
|
||||||
|
|
||||||
# Write mutation results
|
# Write mutation results
|
||||||
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
|
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ class Callbacks:
|
||||||
Handles all registered callbacks for YOLOv5 Hooks
|
Handles all registered callbacks for YOLOv5 Hooks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Define the available callbacks
|
||||||
_callbacks = {
|
_callbacks = {
|
||||||
'on_pretrain_routine_start': [],
|
'on_pretrain_routine_start': [],
|
||||||
'on_pretrain_routine_end': [],
|
'on_pretrain_routine_end': [],
|
||||||
|
|
@ -34,16 +35,13 @@ class Callbacks:
|
||||||
'teardown': [],
|
'teardown': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
def register_action(self, hook, name='', callback=None):
|
def register_action(self, hook, name='', callback=None):
|
||||||
"""
|
"""
|
||||||
Register a new action to a callback hook
|
Register a new action to a callback hook
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hook The callback hook name to register the action to
|
hook The callback hook name to register the action to
|
||||||
name The name of the action
|
name The name of the action for later reference
|
||||||
callback The callback to fire
|
callback The callback to fire
|
||||||
"""
|
"""
|
||||||
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
||||||
|
|
@ -62,118 +60,17 @@ class Callbacks:
|
||||||
else:
|
else:
|
||||||
return self._callbacks
|
return self._callbacks
|
||||||
|
|
||||||
def run_callbacks(self, hook, *args, **kwargs):
|
def run(self, hook, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Loop through the registered actions and fire all callbacks
|
Loop through the registered actions and fire all callbacks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook The name of the hook to check, defaults to all
|
||||||
|
args Arguments to receive from YOLOv5
|
||||||
|
kwargs Keyword Arguments to receive from YOLOv5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
||||||
|
|
||||||
for logger in self._callbacks[hook]:
|
for logger in self._callbacks[hook]:
|
||||||
# print(f"Running callbacks.{logger['callback'].__name__}()")
|
|
||||||
logger['callback'](*args, **kwargs)
|
logger['callback'](*args, **kwargs)
|
||||||
|
|
||||||
def on_pretrain_routine_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of each pretraining routine
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_pretrain_routine_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each pretraining routine
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of each training
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_epoch_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of each training epoch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_epoch_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_batch_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of each training batch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_batch_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def optimizer_step(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks on each optimizer step
|
|
||||||
"""
|
|
||||||
self.run_callbacks('optimizer_step', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_before_zero_grad(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks before zero grad
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_before_zero_grad', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_batch_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each training batch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_batch_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each training epoch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_epoch_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_val_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of the validation
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_val_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_val_batch_start(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the start of each validation batch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_val_batch_start', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_val_image_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each val image
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_val_image_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_val_batch_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each validation batch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_val_batch_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_val_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of the validation
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_val_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_fit_epoch_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of each fit (train+val) epoch
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_model_save(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks after each model save
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_model_save', *args, **kwargs)
|
|
||||||
|
|
||||||
def on_train_end(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks at the end of training
|
|
||||||
"""
|
|
||||||
self.run_callbacks('on_train_end', *args, **kwargs)
|
|
||||||
|
|
||||||
def teardown(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Fires all registered callbacks before teardown
|
|
||||||
"""
|
|
||||||
self.run_callbacks('teardown', *args, **kwargs)
|
|
||||||
|
|
|
||||||
4
val.py
4
val.py
|
|
@ -216,7 +216,7 @@ def run(data,
|
||||||
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
|
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
|
||||||
if save_json:
|
if save_json:
|
||||||
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
|
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
|
||||||
callbacks.on_val_image_end(pred, predn, path, names, img[si])
|
callbacks.run('on_val_image_end', pred, predn, path, names, img[si])
|
||||||
|
|
||||||
# Plot images
|
# Plot images
|
||||||
if plots and batch_i < 3:
|
if plots and batch_i < 3:
|
||||||
|
|
@ -253,7 +253,7 @@ def run(data,
|
||||||
# Plots
|
# Plots
|
||||||
if plots:
|
if plots:
|
||||||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
||||||
callbacks.on_val_end()
|
callbacks.run('on_val_end')
|
||||||
|
|
||||||
# Save JSON
|
# Save JSON
|
||||||
if save_json and len(jdict):
|
if save_json and len(jdict):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue