浏览代码

FP16 to FP32 ckpt load

5.0
Glenn Jocher 4 年前
父节点
当前提交
14523bb030
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. +2
    -2
      train.py

+ 2
- 2
train.py 查看文件

@@ -112,8 +112,8 @@ def train(hyp):

# load model
try:
ckpt['model'] = \
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
if model.state_dict()[k].shape == v.shape} # to FP32, filter
model.load_state_dict(ckpt['model'], strict=False)
except KeyError as e:
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \

正在加载...
取消
保存