@@ -46,8 +46,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, weights, single_cls = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+        opt.resume, opt.notest, opt.nosave, opt.workers
 
     # Directories
     save_dir = Path(save_dir)
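The hunk above is the core of this change: frequently used options are unpacked from the argparse namespace into local variables once, at the top of train(), so the rest of the body can drop its repeated opt. lookups. A minimal standalone sketch of the pattern, using a hypothetical three-field namespace rather than the real YOLOv5 option set:

    import argparse

    # Stand-in for the namespace returned by parse_opt()
    opt = argparse.Namespace(epochs=300, batch_size=16, weights='yolov5s.pt')

    # One tuple assignment replaces repeated attribute access in the function body
    epochs, batch_size, weights = opt.epochs, opt.batch_size, opt.weights
    assert (epochs, batch_size, weights) == (300, 16, 'yolov5s.pt')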
@@ -70,34 +71,34 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         yaml.safe_dump(vars(opt), f, sort_keys=False)
 
     # Configure
-    plots = not opt.evolve  # create plots
+    plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(2 + RANK)
-    with open(opt.data) as f:
+    with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict
 
     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
     if RANK in [-1, 0]:
         # TensorBoard
-        if not opt.evolve:
+        if not evolve:
             prefix = colorstr('tensorboard: ')
             logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(opt.save_dir)
+            loggers['tb'] = SummaryWriter(str(save_dir))
 
         # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
         loggers['wandb'] = wandb_logger.wandb
-        data_dict = wandb_logger.data_dict
-        if wandb_logger.wandb:
+        if loggers['wandb']:
+            data_dict = wandb_logger.data_dict
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update weights, epochs if resuming
 
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
-    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
+    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data)  # check
+    is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
     # Model
     pretrained = weights.endswith('.pt')
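Most of the remaining hunks apply one mechanical substitution on top of this: code that reached through wandb_logger.wandb now tests and calls loggers['wandb'], the dict populated above, so every logging backend is gated the same way. A standalone sketch of that gate, using a hypothetical FakeWandb stand-in rather than the real wandb module:

    loggers = {'wandb': None, 'tb': None}  # backends default to disabled (None is falsy)

    class FakeWandb:
        # Illustrative stand-in for an enabled W&B backend
        def log(self, metrics):
            print('logged:', metrics)

    loggers['wandb'] = FakeWandb()  # set only when the backend is actually available

    if loggers['wandb']:  # the truthiness check used throughout the diff
        loggers['wandb'].log({'metrics/mAP_0.5': 0.5})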
@@ -105,14 +106,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
+        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
-        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
@@ -182,7 +183,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
         # Epochs
         start_epoch = ckpt['epoch'] + 1
-        if opt.resume:
+        if resume:
             assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
@@ -210,20 +211,20 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
-                                            workers=opt.workers,
+                                            workers=workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
-    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
+    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
 
     # Process 0
     if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       workers=opt.workers,
+                                       hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+                                       workers=workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
 
-        if not opt.resume:
+        if not resume:
             labels = np.concatenate(dataset.labels, 0)
             c = torch.tensor(labels[:, 0])  # classes
             # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
@@ -356,8 +357,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
                         loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-            elif plots and ni == 10 and wandb_logger.wandb:
-                wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+            elif plots and ni == 10 and loggers['wandb']:
+                wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
                                               save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
@@ -371,7 +372,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
-            if not opt.notest or final_epoch:  # Calculate mAP
+            if not notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
                                              batch_size=batch_size // WORLD_SIZE * 2,
@@ -398,7 +399,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if loggers['tb']:
                     loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     wandb_logger.log({tag: x})  # W&B
 
             # Update best mAP
@@ -408,7 +409,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             wandb_logger.end_epoch(best_result=best_fitness == fi)
 
             # Save model
-            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
+            if (not nosave) or (final_epoch and not evolve):  # if save
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
@@ -416,13 +417,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if wandb_logger.wandb:
+                if loggers['wandb']:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
                         wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
@@ -433,15 +434,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb_logger.wandb:
+            if loggers['wandb']:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})
 
-        if not opt.evolve:
+        if not evolve:
             if is_coco:  # COCO dataset
                 for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                    results, _, _ = test.test(opt.data,
+                    results, _, _ = test.test(data,
                                               batch_size=batch_size // WORLD_SIZE * 2,
                                               imgsz=imgsz_test,
                                               conf_thres=0.001,
@@ -457,17 +458,17 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for f in last, best:
                 if f.exists():
                     strip_optimizer(f)  # strip optimizers
-            if wandb_logger.wandb:  # Log the stripped model
-                wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
-                                                name='run_' + wandb_logger.wandb_run.id + '_model',
-                                                aliases=['latest', 'best', 'stripped'])
+            if loggers['wandb']:  # Log the stripped model
+                loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
+                                              name='run_' + wandb_logger.wandb_run.id + '_model',
+                                              aliases=['latest', 'best', 'stripped'])
             wandb_logger.finish_run()
 
     torch.cuda.empty_cache()
     return results
 
 
-def parse_opt():
+def parse_opt(known=False):
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
@@ -503,7 +504,7 @@ def parse_opt():
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
     parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    opt = parser.parse_args()
+    opt = parser.parse_known_args()[0] if known else parser.parse_args()
     return opt
 
 
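When known=True, parse_opt() now uses parse_known_args(), which tolerates unrecognized options instead of exiting, so callers can inject settings programmatically; the run() helper added in the next hunk depends on this. A standalone illustration of the standard-library behavior, with hypothetical flags:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt')

    # parse_known_args() returns (namespace, unknown_args) rather than raising
    # SystemExit on unknown options; [0] keeps only the namespace, as in the diff
    opt, unknown = parser.parse_known_args(['--weights', 'best.pt', '--mystery', '1'])
    print(opt.weights)  # best.pt
    print(unknown)      # ['--mystery', '1']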
@@ -633,6 +634,14 @@ def main(opt):
               f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
 
 
+def run(**kwargs):
+    # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+    opt = parse_opt(True)
+    for k, v in kwargs.items():
+        setattr(opt, k, v)
+    main(opt)
+
+
 if __name__ == "__main__":
     opt = parse_opt()
     main(opt)
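Together, parse_opt(known=True) and run() give train.py a programmatic entry point: defaults are parsed leniently, keyword arguments are applied with setattr, and then main(opt) proceeds as if the options came from the command line. Usage, echoing the comment in the diff itself (this assumes the interpreter is started in the repository root so that train.py is importable):

    import train

    train.run(imgsz=320, weights='yolov5m.pt')  # any parse_opt() default can be overridden by name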