浏览代码

Improved model+EMA checkpointing 2 (#2295)

5.0
Glenn Jocher GitHub 3 年前
父节点
当前提交
71dd2768f2
找不到此签名对应的密钥 GPG 密钥 ID: 4AEE18F83AFDEB23
共有 2 个文件被更改,包括 4 次插入4 次删除
  1. +1
    -0
      test.py
  2. +3
    -4
      train.py

+ 1
- 0
test.py 查看文件

@@ -269,6 +269,7 @@ def test(data,
print(f'pycocotools unable to run: {e}')

# Return results
model.float() # for training
if not training:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")

+ 3
- 4
train.py 查看文件

@@ -4,6 +4,7 @@ import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

@@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': (model.module if is_parallel(model) else model).half(),
'ema': (ema.ema.half(), ema.updates),
'model': deepcopy(model.module if is_parallel(model) else model).half(),
'ema': (deepcopy(ema.ema).half(), ema.updates),
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}

@@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
torch.save(ckpt, best)
del ckpt

model.float(), ema.ema.float()

# end epoch ----------------------------------------------------------------------------------------------------
# end training


正在加载...
取消
保存