@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn): | |||
if conv.bias is not None: | |||
b_conv = conv.bias | |||
else: | |||
b_conv = torch.zeros(conv.weight.size(0)) | |||
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) | |||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) | |||
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) | |||