Add `train.py` and `val.py` callbacks (#4220)
* added callbacks * Update callbacks.py * Update train.py * Update val.py * Fix CamlCase add staticmethod * Refactor logger into callbacks * Cleanup * New callback on_val_image_end() * Add curves and results images to TensorBoard Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d8f18834a2
commit
b74929c910
29
train.py
29
train.py
|
|
@ -34,7 +34,7 @@ from utils.autoanchor import check_anchors
|
||||||
from utils.datasets import create_dataloader
|
from utils.datasets import create_dataloader
|
||||||
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
||||||
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
|
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
|
||||||
check_requirements, print_mutation, set_logging, one_cycle, colorstr
|
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
|
||||||
from utils.downloads import attempt_download
|
from utils.downloads import attempt_download
|
||||||
from utils.loss import ComputeLoss
|
from utils.loss import ComputeLoss
|
||||||
from utils.plots import plot_labels, plot_evolution
|
from utils.plots import plot_labels, plot_evolution
|
||||||
|
|
@ -42,6 +42,7 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
|
||||||
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
from utils.loggers import Loggers
|
from utils.loggers import Loggers
|
||||||
|
from utils.callbacks import Callbacks
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
|
|
@ -52,6 +53,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||||
def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
opt,
|
opt,
|
||||||
device,
|
device,
|
||||||
|
callbacks=Callbacks()
|
||||||
):
|
):
|
||||||
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
|
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
|
||||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
||||||
|
|
@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
|
|
||||||
# Loggers
|
# Loggers
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
|
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
|
||||||
if loggers.wandb:
|
if loggers.wandb:
|
||||||
data_dict = loggers.wandb.data_dict
|
data_dict = loggers.wandb.data_dict
|
||||||
if resume:
|
if resume:
|
||||||
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
|
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
|
||||||
|
|
||||||
|
# Register actions
|
||||||
|
for k in methods(loggers):
|
||||||
|
callbacks.register_action(k, callback=getattr(loggers, k))
|
||||||
|
|
||||||
# Config
|
# Config
|
||||||
plots = not evolve # create plots
|
plots = not evolve # create plots
|
||||||
cuda = device.type != 'cpu'
|
cuda = device.type != 'cpu'
|
||||||
|
|
@ -215,13 +221,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
||||||
# model._initialize_biases(cf.to(device))
|
# model._initialize_biases(cf.to(device))
|
||||||
if plots:
|
if plots:
|
||||||
plot_labels(labels, names, save_dir, loggers)
|
plot_labels(labels, names, save_dir)
|
||||||
|
|
||||||
# Anchors
|
# Anchors
|
||||||
if not opt.noautoanchor:
|
if not opt.noautoanchor:
|
||||||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||||
model.half().float() # pre-reduce anchor precision
|
model.half().float() # pre-reduce anchor precision
|
||||||
|
|
||||||
|
callbacks.on_pretrain_routine_end()
|
||||||
|
|
||||||
# DDP mode
|
# DDP mode
|
||||||
if cuda and RANK != -1:
|
if cuda and RANK != -1:
|
||||||
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||||
|
|
@ -329,8 +337,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||||
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
|
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
|
||||||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
|
||||||
loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
|
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
|
||||||
|
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
|
|
@ -339,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
|
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
# mAP
|
# mAP
|
||||||
loggers.on_train_epoch_end(epoch)
|
callbacks.on_train_epoch_end(epoch=epoch)
|
||||||
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
|
||||||
final_epoch = epoch + 1 == epochs
|
final_epoch = epoch + 1 == epochs
|
||||||
if not noval or final_epoch: # Calculate mAP
|
if not noval or final_epoch: # Calculate mAP
|
||||||
|
|
@ -353,14 +360,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
save_json=is_coco and final_epoch,
|
save_json=is_coco and final_epoch,
|
||||||
verbose=nc < 50 and final_epoch,
|
verbose=nc < 50 and final_epoch,
|
||||||
plots=plots and final_epoch,
|
plots=plots and final_epoch,
|
||||||
loggers=loggers,
|
callbacks=callbacks,
|
||||||
compute_loss=compute_loss)
|
compute_loss=compute_loss)
|
||||||
|
|
||||||
# Update best mAP
|
# Update best mAP
|
||||||
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
|
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
|
||||||
if fi > best_fitness:
|
if fi > best_fitness:
|
||||||
best_fitness = fi
|
best_fitness = fi
|
||||||
loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
|
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
|
||||||
|
|
||||||
# Save model
|
# Save model
|
||||||
if (not nosave) or (final_epoch and not evolve): # if save
|
if (not nosave) or (final_epoch and not evolve): # if save
|
||||||
|
|
@ -377,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
if best_fitness == fi:
|
if best_fitness == fi:
|
||||||
torch.save(ckpt, best)
|
torch.save(ckpt, best)
|
||||||
del ckpt
|
del ckpt
|
||||||
loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
|
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
|
||||||
|
|
||||||
# end epoch ----------------------------------------------------------------------------------------------------
|
# end epoch ----------------------------------------------------------------------------------------------------
|
||||||
# end training -----------------------------------------------------------------------------------------------------
|
# end training -----------------------------------------------------------------------------------------------------
|
||||||
|
|
@ -400,7 +407,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
for f in last, best:
|
for f in last, best:
|
||||||
if f.exists():
|
if f.exists():
|
||||||
strip_optimizer(f) # strip optimizers
|
strip_optimizer(f) # strip optimizers
|
||||||
loggers.on_train_end(last, best, plots)
|
callbacks.on_train_end(last, best, plots, epoch)
|
||||||
|
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return results
|
return results
|
||||||
|
|
@ -448,6 +456,7 @@ def parse_opt(known=False):
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
|
# Checks
|
||||||
set_logging(RANK)
|
set_logging(RANK)
|
||||||
if RANK in [-1, 0]:
|
if RANK in [-1, 0]:
|
||||||
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
|
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
class Callbacks:
|
||||||
|
""""
|
||||||
|
Handles all registered callbacks for YOLOv5 Hooks
|
||||||
|
"""
|
||||||
|
|
||||||
|
_callbacks = {
|
||||||
|
'on_pretrain_routine_start': [],
|
||||||
|
'on_pretrain_routine_end': [],
|
||||||
|
|
||||||
|
'on_train_start': [],
|
||||||
|
'on_train_epoch_start': [],
|
||||||
|
'on_train_batch_start': [],
|
||||||
|
'optimizer_step': [],
|
||||||
|
'on_before_zero_grad': [],
|
||||||
|
'on_train_batch_end': [],
|
||||||
|
'on_train_epoch_end': [],
|
||||||
|
|
||||||
|
'on_val_start': [],
|
||||||
|
'on_val_batch_start': [],
|
||||||
|
'on_val_image_end': [],
|
||||||
|
'on_val_batch_end': [],
|
||||||
|
'on_val_end': [],
|
||||||
|
|
||||||
|
'on_fit_epoch_end': [], # fit = train + val
|
||||||
|
'on_model_save': [],
|
||||||
|
'on_train_end': [],
|
||||||
|
|
||||||
|
'teardown': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def register_action(self, hook, name='', callback=None):
|
||||||
|
"""
|
||||||
|
Register a new action to a callback hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook The callback hook name to register the action to
|
||||||
|
name The name of the action
|
||||||
|
callback The callback to fire
|
||||||
|
"""
|
||||||
|
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
||||||
|
assert callable(callback), f"callback '{callback}' is not callable"
|
||||||
|
self._callbacks[hook].append({'name': name, 'callback': callback})
|
||||||
|
|
||||||
|
def get_registered_actions(self, hook=None):
|
||||||
|
""""
|
||||||
|
Returns all the registered actions by callback hook
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook The name of the hook to check, defaults to all
|
||||||
|
"""
|
||||||
|
if hook:
|
||||||
|
return self._callbacks[hook]
|
||||||
|
else:
|
||||||
|
return self._callbacks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run_callbacks(register, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Loop through the registered actions and fire all callbacks
|
||||||
|
"""
|
||||||
|
for logger in register:
|
||||||
|
# print(f"Running callbacks.{logger['callback'].__name__}()")
|
||||||
|
logger['callback'](*args, **kwargs)
|
||||||
|
|
||||||
|
def on_pretrain_routine_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of each pretraining routine
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_pretrain_routine_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each pretraining routine
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of each training
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_epoch_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of each training epoch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_batch_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of each training batch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def optimizer_step(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks on each optimizer step
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_before_zero_grad(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks before zero grad
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each training batch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each training epoch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_val_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of the validation
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_val_batch_start(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the start of each validation batch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_val_image_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each val image
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_val_batch_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each validation batch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_val_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of the validation
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_fit_epoch_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of each fit (train+val) epoch
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_model_save(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks after each model save
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
|
||||||
|
|
||||||
|
def on_train_end(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks at the end of training
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
|
||||||
|
|
||||||
|
def teardown(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Fires all registered callbacks before teardown
|
||||||
|
"""
|
||||||
|
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
|
||||||
|
|
@ -67,6 +67,11 @@ def try_except(func):
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def methods(instance):
|
||||||
|
# Get class/instance methods
|
||||||
|
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
||||||
|
|
||||||
|
|
||||||
def set_logging(rank=-1, verbose=True):
|
def set_logging(rank=-1, verbose=True):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(message)s",
|
format="%(message)s",
|
||||||
|
|
|
||||||
|
|
@ -29,10 +29,12 @@ class Loggers():
|
||||||
self.hyp = hyp
|
self.hyp = hyp
|
||||||
self.logger = logger # for printing results to console
|
self.logger = logger # for printing results to console
|
||||||
self.include = include
|
self.include = include
|
||||||
|
self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
|
||||||
|
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
|
||||||
|
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
|
||||||
|
'x/lr0', 'x/lr1', 'x/lr2'] # params
|
||||||
for k in LOGGERS:
|
for k in LOGGERS:
|
||||||
setattr(self, k, None) # init empty logger dictionary
|
setattr(self, k, None) # init empty logger dictionary
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self.csv = True # always log to csv
|
self.csv = True # always log to csv
|
||||||
|
|
||||||
# Message
|
# Message
|
||||||
|
|
@ -57,7 +59,11 @@ class Loggers():
|
||||||
else:
|
else:
|
||||||
self.wandb = None
|
self.wandb = None
|
||||||
|
|
||||||
return self
|
def on_pretrain_routine_end(self):
|
||||||
|
# Callback runs on pre-train routine end
|
||||||
|
paths = self.save_dir.glob('*labels*.jpg') # training labels
|
||||||
|
if self.wandb:
|
||||||
|
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
||||||
|
|
||||||
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
|
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
|
||||||
# Callback runs on train batch end
|
# Callback runs on train batch end
|
||||||
|
|
@ -78,8 +84,8 @@ class Loggers():
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.current_epoch = epoch + 1
|
self.wandb.current_epoch = epoch + 1
|
||||||
|
|
||||||
def on_val_batch_end(self, pred, predn, path, names, im):
|
def on_val_image_end(self, pred, predn, path, names, im):
|
||||||
# Callback runs on train batch end
|
# Callback runs on val image end
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.val_one_image(pred, predn, path, names, im)
|
self.wandb.val_one_image(pred, predn, path, names, im)
|
||||||
|
|
||||||
|
|
@ -89,25 +95,20 @@ class Loggers():
|
||||||
files = sorted(self.save_dir.glob('val*.jpg'))
|
files = sorted(self.save_dir.glob('val*.jpg'))
|
||||||
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
|
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
|
||||||
|
|
||||||
def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
|
def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
|
||||||
# Callback runs on val end during training
|
# Callback runs at the end of each fit (train+val) epoch
|
||||||
vals = list(mloss) + list(results) + lr
|
vals = list(mloss) + list(results) + lr
|
||||||
keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
|
x = {k: v for k, v in zip(self.keys, vals)} # dict
|
||||||
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
|
|
||||||
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
|
|
||||||
'x/lr0', 'x/lr1', 'x/lr2'] # params
|
|
||||||
x = {k: v for k, v in zip(keys, vals)} # dict
|
|
||||||
|
|
||||||
if self.csv:
|
if self.csv:
|
||||||
file = self.save_dir / 'results.csv'
|
file = self.save_dir / 'results.csv'
|
||||||
n = len(x) + 1 # number of cols
|
n = len(x) + 1 # number of cols
|
||||||
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header
|
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
|
||||||
with open(file, 'a') as f:
|
with open(file, 'a') as f:
|
||||||
f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
|
f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
|
||||||
|
|
||||||
if self.tb:
|
if self.tb:
|
||||||
for k, v in x.items():
|
for k, v in x.items():
|
||||||
self.tb.add_scalar(k, v, epoch) # TensorBoard
|
self.tb.add_scalar(k, v, epoch)
|
||||||
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.log(x)
|
self.wandb.log(x)
|
||||||
|
|
@ -119,20 +120,22 @@ class Loggers():
|
||||||
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
|
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
|
||||||
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
|
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
|
||||||
|
|
||||||
def on_train_end(self, last, best, plots):
|
def on_train_end(self, last, best, plots, epoch):
|
||||||
# Callback runs on training end
|
# Callback runs on training end
|
||||||
if plots:
|
if plots:
|
||||||
plot_results(dir=self.save_dir) # save results.png
|
plot_results(dir=self.save_dir) # save results.png
|
||||||
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
||||||
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
|
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
|
||||||
|
|
||||||
|
if self.tb:
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
for f in files:
|
||||||
|
self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
|
||||||
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
||||||
wandb.log_artifact(str(best if best.exists() else last), type='model',
|
wandb.log_artifact(str(best if best.exists() else last), type='model',
|
||||||
name='run_' + self.wandb.wandb_run.id + '_model',
|
name='run_' + self.wandb.wandb_run.id + '_model',
|
||||||
aliases=['latest', 'best', 'stripped'])
|
aliases=['latest', 'best', 'stripped'])
|
||||||
self.wandb.finish_run()
|
self.wandb.finish_run()
|
||||||
|
|
||||||
def log_images(self, paths):
|
|
||||||
# Log images
|
|
||||||
if self.wandb:
|
|
||||||
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
|
||||||
plt.savefig(str(Path(path).name) + '.png', dpi=300)
|
plt.savefig(str(Path(path).name) + '.png', dpi=300)
|
||||||
|
|
||||||
|
|
||||||
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
|
def plot_labels(labels, names=(), save_dir=Path('')):
|
||||||
# plot dataset labels
|
# plot dataset labels
|
||||||
print('Plotting labels... ')
|
print('Plotting labels... ')
|
||||||
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
||||||
|
|
@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
# loggers
|
|
||||||
if loggers:
|
|
||||||
loggers.log_images(save_dir.glob('*labels*.jpg'))
|
|
||||||
|
|
||||||
|
|
||||||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
|
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
|
||||||
# Plot hyperparameter evolution results in evolve.txt
|
# Plot hyperparameter evolution results in evolve.txt
|
||||||
|
|
|
||||||
10
val.py
10
val.py
|
|
@ -25,7 +25,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
|
||||||
from utils.metrics import ap_per_class, ConfusionMatrix
|
from utils.metrics import ap_per_class, ConfusionMatrix
|
||||||
from utils.plots import plot_images, output_to_target, plot_study_txt
|
from utils.plots import plot_images, output_to_target, plot_study_txt
|
||||||
from utils.torch_utils import select_device, time_sync
|
from utils.torch_utils import select_device, time_sync
|
||||||
from utils.loggers import Loggers
|
from utils.callbacks import Callbacks
|
||||||
|
|
||||||
|
|
||||||
def save_one_txt(predn, save_conf, shape, file):
|
def save_one_txt(predn, save_conf, shape, file):
|
||||||
|
|
@ -97,7 +97,7 @@ def run(data,
|
||||||
dataloader=None,
|
dataloader=None,
|
||||||
save_dir=Path(''),
|
save_dir=Path(''),
|
||||||
plots=True,
|
plots=True,
|
||||||
loggers=Loggers(),
|
callbacks=Callbacks(),
|
||||||
compute_loss=None,
|
compute_loss=None,
|
||||||
):
|
):
|
||||||
# Initialize/load model and set device
|
# Initialize/load model and set device
|
||||||
|
|
@ -213,7 +213,7 @@ def run(data,
|
||||||
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
|
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
|
||||||
if save_json:
|
if save_json:
|
||||||
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
|
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
|
||||||
loggers.on_val_batch_end(pred, predn, path, names, img[si])
|
callbacks.on_val_image_end(pred, predn, path, names, img[si])
|
||||||
|
|
||||||
# Plot images
|
# Plot images
|
||||||
if plots and batch_i < 3:
|
if plots and batch_i < 3:
|
||||||
|
|
@ -250,7 +250,7 @@ def run(data,
|
||||||
# Plots
|
# Plots
|
||||||
if plots:
|
if plots:
|
||||||
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
|
||||||
loggers.on_val_end()
|
callbacks.on_val_end()
|
||||||
|
|
||||||
# Save JSON
|
# Save JSON
|
||||||
if save_json and len(jdict):
|
if save_json and len(jdict):
|
||||||
|
|
@ -282,7 +282,7 @@ def run(data,
|
||||||
model.float() # for training
|
model.float() # for training
|
||||||
if not training:
|
if not training:
|
||||||
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
|
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
|
||||||
print(f"Results saved to {save_dir}{s}")
|
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
|
||||||
maps = np.zeros(nc) + map
|
maps = np.zeros(nc) + map
|
||||||
for i, c in enumerate(ap_class):
|
for i, c in enumerate(ap_class):
|
||||||
maps[c] = ap[i]
|
maps[c] = ap[i]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue