|
|
@@ -62,7 +62,6 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
init_seeds(2 + rank) |
|
|
|
with open(opt.data) as f: |
|
|
|
data_dict = yaml.safe_load(f) # data dict |
|
|
|
is_coco = opt.data.endswith('coco.yaml') |
|
|
|
|
|
|
|
# Logging- Doing this before checking the dataset. Might update data_dict |
|
|
|
loggers = {'wandb': None} # loggers dict |
|
|
@@ -78,6 +77,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes |
|
|
|
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names |
|
|
|
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check |
|
|
|
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset |
|
|
|
|
|
|
|
# Model |
|
|
|
pretrained = weights.endswith('.pt') |
|
|
@@ -358,6 +358,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
single_cls=opt.single_cls, |
|
|
|
dataloader=testloader, |
|
|
|
save_dir=save_dir, |
|
|
|
save_json=is_coco and final_epoch, |
|
|
|
verbose=nc < 50 and final_epoch, |
|
|
|
plots=plots and final_epoch, |
|
|
|
wandb_logger=wandb_logger, |
|
|
@@ -409,41 +410,38 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
# end epoch ---------------------------------------------------------------------------------------------------- |
|
|
|
# end training |
|
|
|
if rank in [-1, 0]: |
|
|
|
# Plots |
|
|
|
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') |
|
|
|
if plots: |
|
|
|
plot_results(save_dir=save_dir) # save as results.png |
|
|
|
if wandb_logger.wandb: |
|
|
|
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]] |
|
|
|
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files |
|
|
|
if (save_dir / f).exists()]}) |
|
|
|
# Test best.pt |
|
|
|
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) |
|
|
|
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO |
|
|
|
for m in [last, best] if best.exists() else [last]: # speed, mAP tests |
|
|
|
results, _, _ = test.test(opt.data, |
|
|
|
batch_size=batch_size * 2, |
|
|
|
imgsz=imgsz_test, |
|
|
|
conf_thres=0.001, |
|
|
|
iou_thres=0.7, |
|
|
|
model=attempt_load(m, device).half(), |
|
|
|
single_cls=opt.single_cls, |
|
|
|
dataloader=testloader, |
|
|
|
save_dir=save_dir, |
|
|
|
save_json=True, |
|
|
|
plots=False, |
|
|
|
is_coco=is_coco) |
|
|
|
|
|
|
|
# Strip optimizers |
|
|
|
final = best if best.exists() else last # final model |
|
|
|
for f in last, best: |
|
|
|
if f.exists(): |
|
|
|
strip_optimizer(f) # strip optimizers |
|
|
|
if opt.bucket: |
|
|
|
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload |
|
|
|
if wandb_logger.wandb and not opt.evolve: # Log the stripped model |
|
|
|
wandb_logger.wandb.log_artifact(str(final), type='model', |
|
|
|
name='run_' + wandb_logger.wandb_run.id + '_model', |
|
|
|
aliases=['latest', 'best', 'stripped']) |
|
|
|
|
|
|
|
if not opt.evolve: |
|
|
|
if is_coco: # COCO dataset |
|
|
|
for m in [last, best] if best.exists() else [last]: # speed, mAP tests |
|
|
|
results, _, _ = test.test(opt.data, |
|
|
|
batch_size=batch_size * 2, |
|
|
|
imgsz=imgsz_test, |
|
|
|
conf_thres=0.001, |
|
|
|
iou_thres=0.7, |
|
|
|
model=attempt_load(m, device).half(), |
|
|
|
single_cls=opt.single_cls, |
|
|
|
dataloader=testloader, |
|
|
|
save_dir=save_dir, |
|
|
|
save_json=True, |
|
|
|
plots=False, |
|
|
|
is_coco=is_coco) |
|
|
|
|
|
|
|
# Strip optimizers |
|
|
|
for f in last, best: |
|
|
|
if f.exists(): |
|
|
|
strip_optimizer(f) # strip optimizers |
|
|
|
if wandb_logger.wandb: # Log the stripped model |
|
|
|
wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model', |
|
|
|
name='run_' + wandb_logger.wandb_run.id + '_model', |
|
|
|
aliases=['latest', 'best', 'stripped']) |
|
|
|
wandb_logger.finish_run() |
|
|
|
else: |
|
|
|
dist.destroy_process_group() |