|
|
@@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape): |
|
|
|
fname = f'{name}.pt' # checkpoint filename |
|
|
|
attempt_download(fname) # download if not found locally |
|
|
|
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load |
|
|
|
state_dict = ckpt['model'].float().state_dict() # to FP32 |
|
|
|
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter |
|
|
|
model.load_state_dict(state_dict, strict=False) # load |
|
|
|
msd = model.state_dict() # model state_dict |
|
|
|
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 |
|
|
|
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter |
|
|
|
model.load_state_dict(csd, strict=False) # load |
|
|
|
if len(ckpt['model'].names) == classes: |
|
|
|
model.names = ckpt['model'].names # set class names attribute |
|
|
|
if autoshape: |