|
|
@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
is_coco = opt.data.endswith('coco.yaml') |
|
|
|
|
|
|
|
# Logging - done before checking the dataset, because the W&B logger might update data_dict |
|
|
|
loggers = {'wandb': None} # loggers dict |
|
|
|
if rank in [-1, 0]: |
|
|
|
opt.hyp = hyp # add hyperparameters |
|
|
|
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None |
|
|
|
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict) |
|
|
|
loggers['wandb'] = wandb_logger.wandb |
|
|
|
data_dict = wandb_logger.data_dict |
|
|
|
if wandb_logger.wandb: |
|
|
|
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming |
|
|
|
loggers = {'wandb': wandb_logger.wandb} # loggers dict |
|
|
|
|
|
|
|
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes |
|
|
|
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names |
|
|
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check |
|
|
@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] |
|
|
|
if fi > best_fitness: |
|
|
|
best_fitness = fi |
|
|
|
wandb_logger.end_epoch(best_result=best_fitness == fi) |
|
|
|
|
|
|
|
# Save model |
|
|
|
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save |
|
|
@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
wandb_logger.log_model( |
|
|
|
last.parent, opt, epoch, fi, best_model=best_fitness == fi) |
|
|
|
del ckpt |
|
|
|
wandb_logger.end_epoch(best_result=best_fitness == fi) |
|
|
|
|
|
|
|
# end epoch ---------------------------------------------------------------------------------------------------- |
|
|
|
# end training |
|
|
@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
wandb_logger.wandb.log_artifact(str(final), type='model', |
|
|
|
name='run_' + wandb_logger.wandb_run.id + '_model', |
|
|
|
aliases=['last', 'best', 'stripped']) |
|
|
|
wandb_logger.finish_run() |
|
|
|
else: |
|
|
|
dist.destroy_process_group() |
|
|
|
torch.cuda.empty_cache() |
|
|
|
wandb_logger.finish_run() |
|
|
|
return results |
|
|
|
|
|
|
|
|