Browse Source

updates

5.0
Glenn Jocher 4 years ago
parent
commit
ce36905358
2 changed files with 16 additions and 16 deletions
  1. +1
    -1
      test.py
  2. +15
    -15
      train.py

+ 1
- 1
test.py View File

@@ -256,7 +256,7 @@ if __name__ == '__main__':
opt.augment)

elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.p5', 'yolov5x.pt', 'yolov3-spp.pt']:
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
x = list(range(256, 1024, 32)) # x axis
y = [] # y axis

+ 15
- 15
train.py View File

@@ -108,30 +108,30 @@ def train(hyp):
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'): # pytorch format
chkpt = torch.load(weights, map_location=device)
ckpt = torch.load(weights, map_location=device) # load checkpoint

# load model
try:
chkpt['model'] = \
{k: v for k, v in chkpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(chkpt['model'], strict=False)
ckpt['model'] = \
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(ckpt['model'], strict=False)
except KeyError as e:
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
% (opt.weights, opt.cfg, opt.weights)
raise KeyError(s) from e

# load optimizer
if chkpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer'])
best_fitness = chkpt['best_fitness']
if ckpt['optimizer'] is not None:
optimizer.load_state_dict(ckpt['optimizer'])
best_fitness = ckpt['best_fitness']

# load results
if chkpt.get('training_results') is not None:
if ckpt.get('training_results') is not None:
with open(results_file, 'w') as file:
file.write(chkpt['training_results']) # write results.txt
file.write(ckpt['training_results']) # write results.txt

start_epoch = chkpt['epoch'] + 1
del chkpt
start_epoch = ckpt['epoch'] + 1
del ckpt

# Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision:
@@ -324,17 +324,17 @@ def train(hyp):
save = (not opt.nosave) or (final_epoch and not opt.evolve)
if save:
with open(results_file, 'r') as f: # create checkpoint
chkpt = {'epoch': epoch,
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict()}

# Save last, best and delete
torch.save(chkpt, last)
torch.save(ckpt, last)
if (best_fitness == fi) and not final_epoch:
torch.save(chkpt, best)
del chkpt
torch.save(ckpt, best)
del ckpt

# end epoch ----------------------------------------------------------------------------------------------------
# end training

Loading…
Cancel
Save