You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

experimental.py 3.3KB

4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
4 vuotta sitten
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from models.common import *
  2. class Sum(nn.Module):
  3. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  4. def __init__(self, n, weight=False): # n: number of inputs
  5. super(Sum, self).__init__()
  6. self.weight = weight # apply weights boolean
  7. self.iter = range(n - 1) # iter object
  8. if weight:
  9. self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
  10. def forward(self, x):
  11. y = x[0] # no weight
  12. if self.weight:
  13. w = torch.sigmoid(self.w) * 2
  14. for i in self.iter:
  15. y = y + x[i + 1] * w[i]
  16. else:
  17. for i in self.iter:
  18. y = y + x[i + 1]
  19. return y
  20. class GhostConv(nn.Module):
  21. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  22. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  23. super(GhostConv, self).__init__()
  24. c_ = c2 // 2 # hidden channels
  25. self.cv1 = Conv(c1, c_, k, s, g, act)
  26. self.cv2 = Conv(c_, c_, 5, 1, c_, act)
  27. def forward(self, x):
  28. y = self.cv1(x)
  29. return torch.cat([y, self.cv2(y)], 1)
  30. class GhostBottleneck(nn.Module):
  31. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  32. def __init__(self, c1, c2, k, s):
  33. super(GhostBottleneck, self).__init__()
  34. c_ = c2 // 2
  35. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  36. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  37. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  38. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  39. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  40. def forward(self, x):
  41. return self.conv(x) + self.shortcut(x)
  42. class ConvPlus(nn.Module):
  43. # Plus-shaped convolution
  44. def __init__(self, c1, c2, k=3, s=1, g=1, bias=True): # ch_in, ch_out, kernel, stride, groups
  45. super(ConvPlus, self).__init__()
  46. self.cv1 = nn.Conv2d(c1, c2, (k, 1), s, (k // 2, 0), groups=g, bias=bias)
  47. self.cv2 = nn.Conv2d(c1, c2, (1, k), s, (0, k // 2), groups=g, bias=bias)
  48. def forward(self, x):
  49. return self.cv1(x) + self.cv2(x)
  50. class MixConv2d(nn.Module):
  51. # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
  52. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
  53. super(MixConv2d, self).__init__()
  54. groups = len(k)
  55. if equal_ch: # equal c_ per group
  56. i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
  57. c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
  58. else: # equal weight.numel() per group
  59. b = [c2] + [0] * groups
  60. a = np.eye(groups + 1, groups, k=-1)
  61. a -= np.roll(a, 1, axis=1)
  62. a *= np.array(k) ** 2
  63. a[0] = 1
  64. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  65. self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
  66. self.bn = nn.BatchNorm2d(c2)
  67. self.act = nn.LeakyReLU(0.1, inplace=True)
  68. def forward(self, x):
  69. return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))