Browse Source

Fix MixConv2d() remove shortcut + apply depthwise (#5410)

modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
5d4258fac5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions
  1. +1
    -1
      models/common.py
  2. +11
    -10
      models/experimental.py
  3. +1
    -1
      utils/torch_utils.py

+ 1
- 1
models/common.py View File

@@ -113,7 +113,7 @@ class BottleneckCSP(nn.Module):
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.act = nn.SiLU()
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

def forward(self, x):

+ 11
- 10
models/experimental.py View File

@@ -2,7 +2,7 @@
"""
Experimental modules
"""
import math
import numpy as np
import torch
import torch.nn as nn
@@ -48,26 +48,27 @@ class Sum(nn.Module):

class MixConv2d(nn.Module):
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
super().__init__()
groups = len(k)
n = len(k) # number of convolutions
if equal_ch: # equal c_ per group
i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
c_ = [(i == g).sum() for g in range(n)] # intermediate channels
else: # equal weight.numel() per group
b = [c2] + [0] * groups
a = np.eye(groups + 1, groups, k=-1)
b = [c2] + [0] * n
a = np.eye(n + 1, n, k=-1)
a -= np.roll(a, 1, axis=1)
a *= np.array(k) ** 2
a[0] = 1
c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b

self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
self.m = nn.ModuleList(
[nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.act = nn.SiLU()

def forward(self, x):
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))


class Ensemble(nn.ModuleList):

+ 1
- 1
utils/torch_utils.py View File

@@ -166,7 +166,7 @@ def initialize_weights(model):
elif t is nn.BatchNorm2d:
m.eps = 1e-3
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True



Loading…
Cancel
Save