diff --git a/models/yolo.py b/models/yolo.py index f9929aa..4b2606d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -160,7 +160,7 @@ class Model(nn.Module): def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers print('Fusing layers... ') for m in self.model.modules(): - if type(m) is Conv: + if type(m) is Conv and hasattr(Conv, 'bn'): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, 'bn') # remove batchnorm