backbone as FP16, save default to FP32
This commit is contained in:
parent
d9b64c27c2
commit
cce95e744d
2
train.py
2
train.py
|
|
@ -332,7 +332,7 @@ def train(hyp):
|
||||||
ckpt = {'epoch': epoch,
|
ckpt = {'epoch': epoch,
|
||||||
'best_fitness': best_fitness,
|
'best_fitness': best_fitness,
|
||||||
'training_results': f.read(),
|
'training_results': f.read(),
|
||||||
'model': ema.ema.module.half() if hasattr(model, 'module') else ema.ema.half(),
|
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
||||||
'optimizer': None if final_epoch else optimizer.state_dict()}
|
'optimizer': None if final_epoch else optimizer.state_dict()}
|
||||||
|
|
||||||
# Save last, best and delete
|
# Save last, best and delete
|
||||||
|
|
|
||||||
|
|
@ -627,13 +627,12 @@ def strip_optimizer(f='weights/best.pt'): # from utils.utils import *; strip_op
|
||||||
def create_backbone(f='weights/best.pt', s='weights/backbone.pt'): # from utils.utils import *; create_backbone()
|
def create_backbone(f='weights/best.pt', s='weights/backbone.pt'): # from utils.utils import *; create_backbone()
|
||||||
# create backbone 's' from 'f'
|
# create backbone 's' from 'f'
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
x = torch.load(f, map_location=device)
|
|
||||||
torch.save(x, s) # update model if SourceChangeWarning
|
|
||||||
x = torch.load(s, map_location=device)
|
x = torch.load(s, map_location=device)
|
||||||
|
|
||||||
x['optimizer'] = None
|
x['optimizer'] = None
|
||||||
x['training_results'] = None
|
x['training_results'] = None
|
||||||
x['epoch'] = -1
|
x['epoch'] = -1
|
||||||
|
x['model'].half() # to FP16
|
||||||
for p in x['model'].parameters():
|
for p in x['model'].parameters():
|
||||||
p.requires_grad = True
|
p.requires_grad = True
|
||||||
torch.save(x, s)
|
torch.save(x, s)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue