Browse Source

Improved model+EMA checkpointing 2 (#2295)

5.0
Glenn Jocher GitHub 3 years ago
parent
commit
71dd2768f2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions
  1. +1
    -0
      test.py
  2. +3
    -4
      train.py

+ 1
- 0
test.py View File

print(f'pycocotools unable to run: {e}')


# Return results
model.float()  # for training
if not training:
    s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
    print(f"Results saved to {save_dir}{s}")

+ 3
- 4
train.py View File

import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread


ckpt = {'epoch': epoch,
        'best_fitness': best_fitness,
        'training_results': results_file.read_text(),
        'model': (model.module if is_parallel(model) else model).half(),
        'ema': (ema.ema.half(), ema.updates),
        'model': deepcopy(model.module if is_parallel(model) else model).half(),
        'ema': (deepcopy(ema.ema).half(), ema.updates),
        'optimizer': optimizer.state_dict(),
        'wandb_id': wandb_run.id if wandb else None}


torch.save(ckpt, best)
del ckpt


model.float(), ema.ema.float()

# end epoch ----------------------------------------------------------------------------------------------------
# end training



Loading…
Cancel
Save