Browse Source

EMA bug fix 2 (#2330)

* EMA bug fix 2

* update
5.0
Glenn Jocher GitHub 3 years ago
parent
commit
fab5085674
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 13 additions and 10 deletions
  1. +1
    -1
      hubconf.py
  2. +2
    -1
      models/experimental.py
  3. +5
    -5
      train.py
  4. +5
    -3
      utils/general.py

+ 1
- 1
hubconf.py View File

""" """
model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
if isinstance(model, dict): if isinstance(model, dict):
model = model['model'] # load model
model = model['ema' if model.get('ema') else 'model'] # load model


hub_model = Model(model.yaml).to(next(model.parameters()).device) # create hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
hub_model.load_state_dict(model.float().state_dict()) # load state_dict hub_model.load_state_dict(model.float().state_dict()) # load state_dict

+ 2
- 1
models/experimental.py View File

model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w) attempt_download(w)
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
ckpt = torch.load(w, map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model


# Compatibility updates # Compatibility updates
for m in model.modules(): for m in model.modules():

+ 5
- 5
train.py View File



# EMA # EMA
if ema and ckpt.get('ema'): if ema and ckpt.get('ema'):
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
ema.updates = ckpt['ema'][1]
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
ema.updates = ckpt['updates']


# Results # Results
if ckpt.get('training_results') is not None: if ckpt.get('training_results') is not None:
ckpt = {'epoch': epoch, ckpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'training_results': results_file.read_text(), 'training_results': results_file.read_text(),
'model': ema.ema if final_epoch else deepcopy(
model.module if is_parallel(model) else model).half(),
'ema': (deepcopy(ema.ema).half(), ema.updates),
'model': deepcopy(model.module if is_parallel(model) else model).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None} 'wandb_id': wandb_run.id if wandb else None}



+ 5
- 3
utils/general.py View File

return output return output




def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's' # Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu')) x = torch.load(f, map_location=torch.device('cpu'))
for k in 'optimizer', 'training_results', 'wandb_id', 'ema': # keys
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
x[k] = None x[k] = None
x['epoch'] = -1 x['epoch'] = -1
x['model'].half() # to FP16 x['model'].half() # to FP16
p.requires_grad = False p.requires_grad = False
torch.save(x, s or f) torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize mb = os.path.getsize(s or f) / 1E6 # filesize
print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")




def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):

Loading…
Cancel
Save