* added callbacks * added back callback to main * added save_dir to callback output * reduced code count * updated callbacks * added default callback class to main, added missing parameters to on_model_save * Glenn updates Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>modifyDataloader
@@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) | |||
def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
opt, | |||
device, | |||
callbacks=Callbacks() | |||
callbacks | |||
): | |||
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ | |||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ | |||
@@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) | |||
model.half().float() # pre-reduce anchor precision | |||
callbacks.on_pretrain_routine_end() | |||
callbacks.run('on_pretrain_routine_end') | |||
# DDP mode | |||
if cuda and RANK != -1: | |||
@@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) | |||
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( | |||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) | |||
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn) | |||
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn) | |||
# end batch ------------------------------------------------------------------------------------------------ | |||
# Scheduler | |||
@@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
if RANK in [-1, 0]: | |||
# mAP | |||
callbacks.on_train_epoch_end(epoch=epoch) | |||
callbacks.run('on_train_epoch_end', epoch=epoch) | |||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) | |||
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop | |||
if not noval or final_epoch: # Calculate mAP | |||
@@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
if fi > best_fitness: | |||
best_fitness = fi | |||
log_vals = list(mloss) + list(results) + lr | |||
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi) | |||
callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi) | |||
# Save model | |||
if (not nosave) or (final_epoch and not evolve): # if save | |||
@@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
if best_fitness == fi: | |||
torch.save(ckpt, best) | |||
del ckpt | |||
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) | |||
callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) | |||
# Stop Single-GPU | |||
if RANK == -1 and stopper(epoch=epoch, fitness=fi): | |||
@@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |||
for f in last, best: | |||
if f.exists(): | |||
strip_optimizer(f) # strip optimizers | |||
callbacks.on_train_end(last, best, plots, epoch) | |||
callbacks.run('on_train_end', last, best, plots, epoch) | |||
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") | |||
torch.cuda.empty_cache() | |||
@@ -467,7 +467,7 @@ def parse_opt(known=False): | |||
return opt | |||
def main(opt): | |||
def main(opt, callbacks=Callbacks()): | |||
# Checks | |||
set_logging(RANK) | |||
if RANK in [-1, 0]: | |||
@@ -505,7 +505,7 @@ def main(opt): | |||
# Train | |||
if not opt.evolve: | |||
train(opt.hyp, opt, device) | |||
train(opt.hyp, opt, device, callbacks) | |||
if WORLD_SIZE > 1 and RANK == 0: | |||
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] | |||
@@ -585,7 +585,7 @@ def main(opt): | |||
hyp[k] = round(hyp[k], 5) # significant digits | |||
# Train mutation | |||
results = train(hyp.copy(), opt, device) | |||
results = train(hyp.copy(), opt, device, callbacks) | |||
# Write mutation results | |||
print_mutation(results, hyp.copy(), save_dir, opt.bucket) |
@@ -9,6 +9,7 @@ class Callbacks: | |||
Handles all registered callbacks for YOLOv5 Hooks | |||
""" | |||
# Define the available callbacks | |||
_callbacks = { | |||
'on_pretrain_routine_start': [], | |||
'on_pretrain_routine_end': [], | |||
@@ -34,16 +35,13 @@ class Callbacks: | |||
'teardown': [], | |||
} | |||
def __init__(self): | |||
return | |||
def register_action(self, hook, name='', callback=None): | |||
""" | |||
Register a new action to a callback hook | |||
Args: | |||
hook The callback hook name to register the action to | |||
name The name of the action | |||
name The name of the action for later reference | |||
callback The callback to fire | |||
""" | |||
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" | |||
@@ -62,118 +60,17 @@ class Callbacks: | |||
else: | |||
return self._callbacks | |||
def run_callbacks(self, hook, *args, **kwargs): | |||
def run(self, hook, *args, **kwargs): | |||
""" | |||
Loop through the registered actions and fire all callbacks | |||
""" | |||
for logger in self._callbacks[hook]: | |||
# print(f"Running callbacks.{logger['callback'].__name__}()") | |||
logger['callback'](*args, **kwargs) | |||
def on_pretrain_routine_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of each pretraining routine | |||
""" | |||
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs) | |||
def on_pretrain_routine_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each pretraining routine | |||
""" | |||
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs) | |||
def on_train_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of each training | |||
""" | |||
self.run_callbacks('on_train_start', *args, **kwargs) | |||
def on_train_epoch_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of each training epoch | |||
""" | |||
self.run_callbacks('on_train_epoch_start', *args, **kwargs) | |||
def on_train_batch_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of each training batch | |||
""" | |||
self.run_callbacks('on_train_batch_start', *args, **kwargs) | |||
def optimizer_step(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks on each optimizer step | |||
""" | |||
self.run_callbacks('optimizer_step', *args, **kwargs) | |||
def on_before_zero_grad(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks before zero grad | |||
""" | |||
self.run_callbacks('on_before_zero_grad', *args, **kwargs) | |||
def on_train_batch_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each training batch | |||
""" | |||
self.run_callbacks('on_train_batch_end', *args, **kwargs) | |||
def on_train_epoch_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each training epoch | |||
""" | |||
self.run_callbacks('on_train_epoch_end', *args, **kwargs) | |||
def on_val_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of the validation | |||
""" | |||
self.run_callbacks('on_val_start', *args, **kwargs) | |||
def on_val_batch_start(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the start of each validation batch | |||
""" | |||
self.run_callbacks('on_val_batch_start', *args, **kwargs) | |||
def on_val_image_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each val image | |||
""" | |||
self.run_callbacks('on_val_image_end', *args, **kwargs) | |||
def on_val_batch_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each validation batch | |||
""" | |||
self.run_callbacks('on_val_batch_end', *args, **kwargs) | |||
def on_val_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of the validation | |||
""" | |||
self.run_callbacks('on_val_end', *args, **kwargs) | |||
def on_fit_epoch_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of each fit (train+val) epoch | |||
""" | |||
self.run_callbacks('on_fit_epoch_end', *args, **kwargs) | |||
def on_model_save(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks after each model save | |||
Args: | |||
hook The name of the hook to check, defaults to all | |||
args Arguments to receive from YOLOv5 | |||
kwargs Keyword Arguments to receive from YOLOv5 | |||
""" | |||
self.run_callbacks('on_model_save', *args, **kwargs) | |||
def on_train_end(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks at the end of training | |||
""" | |||
self.run_callbacks('on_train_end', *args, **kwargs) | |||
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" | |||
def teardown(self, *args, **kwargs): | |||
""" | |||
Fires all registered callbacks before teardown | |||
""" | |||
self.run_callbacks('teardown', *args, **kwargs) | |||
for logger in self._callbacks[hook]: | |||
logger['callback'](*args, **kwargs) |
@@ -216,7 +216,7 @@ def run(data, | |||
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) | |||
if save_json: | |||
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary | |||
callbacks.on_val_image_end(pred, predn, path, names, img[si]) | |||
callbacks.run('on_val_image_end', pred, predn, path, names, img[si]) | |||
# Plot images | |||
if plots and batch_i < 3: | |||
@@ -253,7 +253,7 @@ def run(data, | |||
# Plots | |||
if plots: | |||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) | |||
callbacks.on_val_end() | |||
callbacks.run('on_val_end') | |||
# Save JSON | |||
if save_json and len(jdict): |