updates
This commit is contained in:
parent
1e84a23f38
commit
ce36905358
2
test.py
2
test.py
|
|
@ -256,7 +256,7 @@ if __name__ == '__main__':
|
||||||
opt.augment)
|
opt.augment)
|
||||||
|
|
||||||
elif opt.task == 'study': # run over a range of settings and save/plot
|
elif opt.task == 'study': # run over a range of settings and save/plot
|
||||||
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.p5', 'yolov5x.pt', 'yolov3-spp.pt']:
|
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
|
||||||
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
|
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
|
||||||
x = list(range(256, 1024, 32)) # x axis
|
x = list(range(256, 1024, 32)) # x axis
|
||||||
y = [] # y axis
|
y = [] # y axis
|
||||||
|
|
|
||||||
30
train.py
30
train.py
|
|
@ -108,30 +108,30 @@ def train(hyp):
|
||||||
google_utils.attempt_download(weights)
|
google_utils.attempt_download(weights)
|
||||||
start_epoch, best_fitness = 0, 0.0
|
start_epoch, best_fitness = 0, 0.0
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
chkpt = torch.load(weights, map_location=device)
|
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
try:
|
try:
|
||||||
chkpt['model'] = \
|
ckpt['model'] = \
|
||||||
{k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
|
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
|
||||||
model.load_state_dict(chkpt['model'], strict=False)
|
model.load_state_dict(ckpt['model'], strict=False)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
||||||
% (opt.weights, opt.cfg, opt.weights)
|
% (opt.weights, opt.cfg, opt.weights)
|
||||||
raise KeyError(s) from e
|
raise KeyError(s) from e
|
||||||
|
|
||||||
# load optimizer
|
# load optimizer
|
||||||
if chkpt['optimizer'] is not None:
|
if ckpt['optimizer'] is not None:
|
||||||
optimizer.load_state_dict(chkpt['optimizer'])
|
optimizer.load_state_dict(ckpt['optimizer'])
|
||||||
best_fitness = chkpt['best_fitness']
|
best_fitness = ckpt['best_fitness']
|
||||||
|
|
||||||
# load results
|
# load results
|
||||||
if chkpt.get('training_results') is not None:
|
if ckpt.get('training_results') is not None:
|
||||||
with open(results_file, 'w') as file:
|
with open(results_file, 'w') as file:
|
||||||
file.write(chkpt['training_results']) # write results.txt
|
file.write(ckpt['training_results']) # write results.txt
|
||||||
|
|
||||||
start_epoch = chkpt['epoch'] + 1
|
start_epoch = ckpt['epoch'] + 1
|
||||||
del chkpt
|
del ckpt
|
||||||
|
|
||||||
# Mixed precision training https://github.com/NVIDIA/apex
|
# Mixed precision training https://github.com/NVIDIA/apex
|
||||||
if mixed_precision:
|
if mixed_precision:
|
||||||
|
|
@ -324,17 +324,17 @@ def train(hyp):
|
||||||
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
||||||
if save:
|
if save:
|
||||||
with open(results_file, 'r') as f: # create checkpoint
|
with open(results_file, 'r') as f: # create checkpoint
|
||||||
chkpt = {'epoch': epoch,
|
ckpt = {'epoch': epoch,
|
||||||
'best_fitness': best_fitness,
|
'best_fitness': best_fitness,
|
||||||
'training_results': f.read(),
|
'training_results': f.read(),
|
||||||
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
||||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||||
|
|
||||||
# Save last, best and delete
|
# Save last, best and delete
|
||||||
torch.save(chkpt, last)
|
torch.save(ckpt, last)
|
||||||
if (best_fitness == fi) and not final_epoch:
|
if (best_fitness == fi) and not final_epoch:
|
||||||
torch.save(chkpt, best)
|
torch.save(ckpt, best)
|
||||||
del chkpt
|
del ckpt
|
||||||
|
|
||||||
# end epoch ----------------------------------------------------------------------------------------------------
|
# end epoch ----------------------------------------------------------------------------------------------------
|
||||||
# end training
|
# end training
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue