|
|
@@ -108,30 +108,30 @@ def train(hyp): |
|
|
|
google_utils.attempt_download(weights) |
|
|
|
start_epoch, best_fitness = 0, 0.0 |
|
|
|
if weights.endswith('.pt'): # pytorch format |
|
|
|
chkpt = torch.load(weights, map_location=device) |
|
|
|
ckpt = torch.load(weights, map_location=device) # load checkpoint |
|
|
|
|
|
|
|
# load model |
|
|
|
try: |
|
|
|
chkpt['model'] = \ |
|
|
|
{k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} |
|
|
|
model.load_state_dict(chkpt['model'], strict=False) |
|
|
|
ckpt['model'] = \ |
|
|
|
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()} |
|
|
|
model.load_state_dict(ckpt['model'], strict=False) |
|
|
|
except KeyError as e: |
|
|
|
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \ |
|
|
|
% (opt.weights, opt.cfg, opt.weights) |
|
|
|
raise KeyError(s) from e |
|
|
|
|
|
|
|
# load optimizer |
|
|
|
if chkpt['optimizer'] is not None: |
|
|
|
optimizer.load_state_dict(chkpt['optimizer']) |
|
|
|
best_fitness = chkpt['best_fitness'] |
|
|
|
if ckpt['optimizer'] is not None: |
|
|
|
optimizer.load_state_dict(ckpt['optimizer']) |
|
|
|
best_fitness = ckpt['best_fitness'] |
|
|
|
|
|
|
|
# load results |
|
|
|
if chkpt.get('training_results') is not None: |
|
|
|
if ckpt.get('training_results') is not None: |
|
|
|
with open(results_file, 'w') as file: |
|
|
|
file.write(chkpt['training_results']) # write results.txt |
|
|
|
file.write(ckpt['training_results']) # write results.txt |
|
|
|
|
|
|
|
start_epoch = chkpt['epoch'] + 1 |
|
|
|
del chkpt |
|
|
|
start_epoch = ckpt['epoch'] + 1 |
|
|
|
del ckpt |
|
|
|
|
|
|
|
# Mixed precision training https://github.com/NVIDIA/apex |
|
|
|
if mixed_precision: |
|
|
@@ -324,17 +324,17 @@ def train(hyp): |
|
|
|
save = (not opt.nosave) or (final_epoch and not opt.evolve) |
|
|
|
if save: |
|
|
|
with open(results_file, 'r') as f: # create checkpoint |
|
|
|
chkpt = {'epoch': epoch, |
|
|
|
ckpt = {'epoch': epoch, |
|
|
|
'best_fitness': best_fitness, |
|
|
|
'training_results': f.read(), |
|
|
|
'model': ema.ema.module if hasattr(model, 'module') else ema.ema, |
|
|
|
'optimizer': None if final_epoch else optimizer.state_dict()} |
|
|
|
|
|
|
|
# Save last, best and delete |
|
|
|
torch.save(chkpt, last) |
|
|
|
torch.save(ckpt, last) |
|
|
|
if (best_fitness == fi) and not final_epoch: |
|
|
|
torch.save(chkpt, best) |
|
|
|
del chkpt |
|
|
|
torch.save(ckpt, best) |
|
|
|
del ckpt |
|
|
|
|
|
|
|
# end epoch ---------------------------------------------------------------------------------------------------- |
|
|
|
# end training |