You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
4.2KB

  1. # This file contains experimental modules
  2. from models.common import *
  3. class CrossConv(nn.Module):
  4. # Cross Convolution Downsample
  5. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
  6. # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
  7. super(CrossConv, self).__init__()
  8. c_ = int(c2 * e) # hidden channels
  9. self.cv1 = Conv(c1, c_, (1, k), (1, s))
  10. self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
  11. self.add = shortcut and c1 == c2
  12. def forward(self, x):
  13. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  14. class C3(nn.Module):
  15. # Cross Convolution CSP
  16. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  17. super(C3, self).__init__()
  18. c_ = int(c2 * e) # hidden channels
  19. self.cv1 = Conv(c1, c_, 1, 1)
  20. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  21. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  22. self.cv4 = Conv(2 * c_, c2, 1, 1)
  23. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  24. self.act = nn.LeakyReLU(0.1, inplace=True)
  25. self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
  26. def forward(self, x):
  27. y1 = self.cv3(self.m(self.cv1(x)))
  28. y2 = self.cv2(x)
  29. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  30. class Sum(nn.Module):
  31. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  32. def __init__(self, n, weight=False): # n: number of inputs
  33. super(Sum, self).__init__()
  34. self.weight = weight # apply weights boolean
  35. self.iter = range(n - 1) # iter object
  36. if weight:
  37. self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
  38. def forward(self, x):
  39. y = x[0] # no weight
  40. if self.weight:
  41. w = torch.sigmoid(self.w) * 2
  42. for i in self.iter:
  43. y = y + x[i + 1] * w[i]
  44. else:
  45. for i in self.iter:
  46. y = y + x[i + 1]
  47. return y
  48. class GhostConv(nn.Module):
  49. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  50. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  51. super(GhostConv, self).__init__()
  52. c_ = c2 // 2 # hidden channels
  53. self.cv1 = Conv(c1, c_, k, s, g, act)
  54. self.cv2 = Conv(c_, c_, 5, 1, c_, act)
  55. def forward(self, x):
  56. y = self.cv1(x)
  57. return torch.cat([y, self.cv2(y)], 1)
  58. class GhostBottleneck(nn.Module):
  59. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  60. def __init__(self, c1, c2, k, s):
  61. super(GhostBottleneck, self).__init__()
  62. c_ = c2 // 2
  63. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  64. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  65. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  66. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  67. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  68. def forward(self, x):
  69. return self.conv(x) + self.shortcut(x)
  70. class MixConv2d(nn.Module):
  71. # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
  72. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
  73. super(MixConv2d, self).__init__()
  74. groups = len(k)
  75. if equal_ch: # equal c_ per group
  76. i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
  77. c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
  78. else: # equal weight.numel() per group
  79. b = [c2] + [0] * groups
  80. a = np.eye(groups + 1, groups, k=-1)
  81. a -= np.roll(a, 1, axis=1)
  82. a *= np.array(k) ** 2
  83. a[0] = 1
  84. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  85. self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
  86. self.bn = nn.BatchNorm2d(c2)
  87. self.act = nn.LeakyReLU(0.1, inplace=True)
  88. def forward(self, x):
  89. return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))