FP16 to FP32 ckpt load
This commit is contained in:
parent
c5966abba8
commit
14523bb030
4
train.py
4
train.py
|
|
@ -112,8 +112,8 @@ def train(hyp):
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
try:
|
try:
|
||||||
ckpt['model'] = \
|
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
||||||
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
|
if model.state_dict()[k].shape == v.shape} # to FP32, filter
|
||||||
model.load_state_dict(ckpt['model'], strict=False)
|
model.load_state_dict(ckpt['model'], strict=False)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue