@@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     """
     model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
     if isinstance(model, dict):
-        model = model['model']  # load model
+        model = model['ema' if model.get('ema') else 'model']  # load model

     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
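Worth spelling out for reviewers: the new indexing expression prefers the averaged `'ema'` weights and falls back to `'model'` when the checkpoint predates this change or was saved with `ema=None`. A minimal sketch of that selection logic, using a plain dict with placeholder strings in place of a real checkpoint:

```python
# Plain-dict stand-in for a loaded checkpoint; the string values are placeholders.
ckpt = {'model': 'raw weights', 'ema': 'ema weights'}
assert ckpt['ema' if ckpt.get('ema') else 'model'] == 'ema weights'

# Checkpoints without an EMA (key missing or set to None) fall back to 'model'.
legacy = {'model': 'raw weights', 'ema': None}
assert legacy['ema' if legacy.get('ema') else 'model'] == 'raw weights'
```

Using `.get('ema')` rather than `'ema' in model` covers both cases at once: a missing key and a key explicitly stored as `None`.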
@@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         attempt_download(w)
-        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+        ckpt = torch.load(w, map_location=map_location)  # load
+        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

     # Compatibility updates
     for m in model.modules():
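The same fallback now drives ensemble loading. Below is a self-contained sketch of the pattern under stated stand-ins: `fake_load` replaces `torch.load`, `nn.Linear` replaces the fused YOLOv5 model, and this `Ensemble` is a simplified version of the real class (`.fuse()` is omitted because toy modules don't define it):

```python
import torch
import torch.nn as nn


class Ensemble(nn.ModuleList):
    # Simplified stand-in for the real Ensemble: average the member outputs.
    def forward(self, x):
        return torch.stack([m(x) for m in self]).mean(0)


def fake_load(w):
    # Hypothetical stand-in for torch.load(w, map_location=...): returns a
    # checkpoint dict in the new layout, with both 'model' and 'ema' entries.
    return {'model': nn.Linear(4, 2), 'ema': nn.Linear(4, 2)}


weights = ['a.pt', 'b.pt']  # placeholder paths, never actually opened
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
    ckpt = fake_load(w)
    model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # no .fuse() on toy modules

print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
```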
@@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

         # EMA
         if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
-            ema.updates = ckpt['ema'][1]
+            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+            ema.updates = ckpt['updates']

         # Results
         if ckpt.get('training_results') is not None:
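On `--resume`, the EMA state is now restored from two flat keys instead of unpacking a `(state, updates)` tuple. A hedged sketch of that path, with a minimal `EMA` holder standing in for the real `ModelEMA` class:

```python
import copy

import torch.nn as nn


class EMA:
    # Minimal, hypothetical stand-in for utils.torch_utils.ModelEMA.
    def __init__(self, model):
        self.ema = copy.deepcopy(model).eval()  # shadow copy holding averaged weights
        self.updates = 0                        # number of EMA updates applied so far


model = nn.Linear(4, 2)
ema = EMA(model)

# Checkpoint in the new layout: module under 'ema', counter under 'updates'.
ckpt = {'ema': copy.deepcopy(model).half(), 'updates': 123}

if ema and ckpt.get('ema'):
    ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # FP16 -> FP32 before loading
    ema.updates = ckpt['updates']

assert ema.updates == 123
```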
@@ -383,9 +383,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema if final_epoch else deepcopy(
-                            model.module if is_parallel(model) else model).half(),
-                        'ema': (deepcopy(ema.ema).half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': deepcopy(ema.ema).half(),
+                        'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
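For clarity, the resulting checkpoint layout: `'model'`, `'ema'`, and `'updates'` are saved as separate keys every epoch, rather than swapping the EMA into `'model'` on the final epoch and packing `(ema, updates)` into a tuple. An illustrative construction with toy modules and made-up scalar values:

```python
import copy

import torch.nn as nn

model, ema_model = nn.Linear(4, 2), nn.Linear(4, 2)  # toy stand-ins for the YOLOv5 model and its EMA

ckpt = {'epoch': 10,                             # placeholder
        'best_fitness': 0.42,                    # placeholder
        'training_results': 'results text',      # placeholder
        'model': copy.deepcopy(model).half(),    # raw training weights, FP16
        'ema': copy.deepcopy(ema_model).half(),  # EMA weights, FP16
        'updates': 500,                          # EMA update counter (was ckpt['ema'][1])
        'optimizer': None,
        'wandb_id': None}

print(sorted(ckpt))  # ['best_fitness', 'ema', 'epoch', 'model', 'optimizer', 'training_results', 'updates', 'wandb_id']
```

Keeping raw and averaged weights side by side means downstream consumers (hub loading, `attempt_load`, `strip_optimizer`) choose explicitly instead of depending on which epoch the file happened to be written.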
@@ -481,10 +481,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
     return output


-def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
         x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
@@ -492,7 +494,7 @@ def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *;
         p.requires_grad = False
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1E6  # filesize
-    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
+    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
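A sketch of the updated stripping logic applied to an in-memory dict (toy modules, made-up values, no file I/O) so the key handling is easy to follow:

```python
import torch.nn as nn

# In-memory stand-in for torch.load(f, map_location='cpu') on a training checkpoint.
x = {'epoch': 10, 'best_fitness': 0.42, 'training_results': 'results text',
     'model': nn.Linear(4, 2), 'ema': nn.Linear(4, 2),
     'updates': 500, 'optimizer': {'state': {}}, 'wandb_id': None}

if x.get('ema'):
    x['model'] = x['ema']  # promote the EMA weights to 'model'
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys to null
    x[k] = None
x['epoch'] = -1
x['model'].half()  # to FP16
for p in x['model'].parameters():
    p.requires_grad = False

print({k: type(v).__name__ for k, v in x.items()})
```

After stripping, `'model'` holds the EMA weights in FP16 with gradients disabled and the bookkeeping keys are nulled, which is what produces the smaller deployable file reported by the final `print` in `strip_optimizer`.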