Speed profiling improvements (#2648)
* Speed profiling improvements * Update torch_utils.py deepcopy() required to avoid adding elements to model. * Update torch_utils.py
This commit is contained in:
parent
1e8ab3f5f2
commit
866bc7d640
|
|
@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape):
|
||||||
fname = f'{name}.pt' # checkpoint filename
|
fname = f'{name}.pt' # checkpoint filename
|
||||||
attempt_download(fname) # download if not found locally
|
attempt_download(fname) # download if not found locally
|
||||||
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
||||||
state_dict = ckpt['model'].float().state_dict() # to FP32
|
msd = model.state_dict() # model state_dict
|
||||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
||||||
model.load_state_dict(state_dict, strict=False) # load
|
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
||||||
|
model.load_state_dict(csd, strict=False) # load
|
||||||
if len(ckpt['model'].names) == classes:
|
if len(ckpt['model'].names) == classes:
|
||||||
model.names = ckpt['model'].names # set class names attribute
|
model.names = ckpt['model'].names # set class names attribute
|
||||||
if autoshape:
|
if autoshape:
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ def fuse_conv_and_bn(conv, bn):
|
||||||
# prepare filters
|
# prepare filters
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
||||||
|
|
||||||
# prepare spatial bias
|
# prepare spatial bias
|
||||||
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue