Browse Source

Add `DWConvClass()` (#4274)

* Add `DWConvClass()`

* Cleanup

* Cleanup2
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
587c4b4b81
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 5 deletions
  1. +9
    -2
      models/common.py
  2. +1
    -1
      models/experimental.py
  3. +2
    -2
      models/yolo.py

+ 9
- 2
models/common.py View File

@@ -30,7 +30,7 @@ def autopad(k, p=None): # kernel, padding


def DWConv(c1, c2, k=1, s=1, act=True):
# Depth-wise convolution
# Depth-wise convolution function
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


@@ -45,10 +45,17 @@ class Conv(nn.Module):
def forward(self, x):
return self.act(self.bn(self.conv(x)))

def fuseforward(self, x):
def forward_fuse(self, x):
return self.act(self.conv(x))


class DWConvClass(Conv):
# Depth-wise convolution class
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__(c1, c2, k, s, act)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)


class TransformerLayer(nn.Module):
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
def __init__(self, c, num_heads):

+ 1
- 1
models/experimental.py View File

@@ -72,7 +72,7 @@ class GhostBottleneck(nn.Module):


class MixConv2d(nn.Module):
# Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
super().__init__()
groups = len(k)

+ 2
- 2
models/yolo.py View File

@@ -202,10 +202,10 @@ class Model(nn.Module):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
LOGGER.info('Fusing layers... ')
for m in self.model.modules():
if type(m) is Conv and hasattr(m, 'bn'):
if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.fuseforward # update forward
m.forward = m.forward_fuse # update forward
self.info()
return self


Loading…
Cancel
Save