""" | """ | ||||
model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint | model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint | ||||
if isinstance(model, dict): | if isinstance(model, dict): | ||||
model = model['model'] # load model | |||||
model = model['ema' if model.get('ema') else 'model'] # load model | |||||
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create | hub_model = Model(model.yaml).to(next(model.parameters()).device) # create | ||||
hub_model.load_state_dict(model.float().state_dict()) # load state_dict | hub_model.load_state_dict(model.float().state_dict()) # load state_dict |
model = Ensemble() | model = Ensemble() | ||||
for w in weights if isinstance(weights, list) else [weights]: | for w in weights if isinstance(weights, list) else [weights]: | ||||
attempt_download(w) | attempt_download(w) | ||||
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model | |||||
ckpt = torch.load(w, map_location=map_location) # load | |||||
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model | |||||
# Compatibility updates | # Compatibility updates | ||||
for m in model.modules(): | for m in model.modules(): |
# EMA | # EMA | ||||
if ema and ckpt.get('ema'): | if ema and ckpt.get('ema'): | ||||
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict()) | |||||
ema.updates = ckpt['ema'][1] | |||||
ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) | |||||
ema.updates = ckpt['updates'] | |||||
# Results | # Results | ||||
if ckpt.get('training_results') is not None: | if ckpt.get('training_results') is not None: | ||||
ckpt = {'epoch': epoch, | ckpt = {'epoch': epoch, | ||||
'best_fitness': best_fitness, | 'best_fitness': best_fitness, | ||||
'training_results': results_file.read_text(), | 'training_results': results_file.read_text(), | ||||
'model': ema.ema if final_epoch else deepcopy( | |||||
model.module if is_parallel(model) else model).half(), | |||||
'ema': (deepcopy(ema.ema).half(), ema.updates), | |||||
'model': deepcopy(model.module if is_parallel(model) else model).half(), | |||||
'ema': deepcopy(ema.ema).half(), | |||||
'updates': ema.updates, | |||||
'optimizer': optimizer.state_dict(), | 'optimizer': optimizer.state_dict(), | ||||
'wandb_id': wandb_run.id if wandb else None} | 'wandb_id': wandb_run.id if wandb else None} | ||||
return output | return output | ||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip training-only state from checkpoint ``f`` to finalize training.

    Loads the checkpoint on CPU, replaces the model with its EMA weights when
    present (EMA weights generally evaluate better), nulls out optimizer and
    bookkeeping keys, converts the model to FP16 and freezes its parameters,
    then saves the result to ``s`` if given, else overwrites ``f`` in place.

    Args:
        f (str): path to the checkpoint to strip.
        s (str): optional path to save the stripped checkpoint to; when empty,
            ``f`` is overwritten.
    """
    x = torch.load(f, map_location=torch.device('cpu'))  # load on CPU so no GPU is required
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with EMA weights
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # training-only keys
        x[k] = None
    x['epoch'] = -1  # mark training as finished
    x['model'].half()  # to FP16, roughly halves checkpoint size
    for p in x['model'].parameters():
        p.requires_grad = False  # freeze for inference
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize in MB
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): | def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): |