瀏覽代碼

model fuse

5.0
Glenn Jocher 4 年之前
父節點
當前提交
8fe299f179
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. +1
    -1
      utils/torch_utils.py

+ 1
- 1
utils/torch_utils.py 查看文件

@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn):
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros(conv.weight.size(0))
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)


Loading…
取消
儲存