Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

120 lines
4.4KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. from models.common import Conv
  9. from utils.downloads import attempt_download
  10. class CrossConv(nn.Module):
  11. # Cross Convolution Downsample
  12. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
  13. # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
  14. super().__init__()
  15. c_ = int(c2 * e) # hidden channels
  16. self.cv1 = Conv(c1, c_, (1, k), (1, s))
  17. self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
  18. self.add = shortcut and c1 == c2
  19. def forward(self, x):
  20. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  21. class Sum(nn.Module):
  22. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  23. def __init__(self, n, weight=False): # n: number of inputs
  24. super().__init__()
  25. self.weight = weight # apply weights boolean
  26. self.iter = range(n - 1) # iter object
  27. if weight:
  28. self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
  29. def forward(self, x):
  30. y = x[0] # no weight
  31. if self.weight:
  32. w = torch.sigmoid(self.w) * 2
  33. for i in self.iter:
  34. y = y + x[i + 1] * w[i]
  35. else:
  36. for i in self.iter:
  37. y = y + x[i + 1]
  38. return y
  39. class MixConv2d(nn.Module):
  40. # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
  41. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
  42. super().__init__()
  43. groups = len(k)
  44. if equal_ch: # equal c_ per group
  45. i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
  46. c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
  47. else: # equal weight.numel() per group
  48. b = [c2] + [0] * groups
  49. a = np.eye(groups + 1, groups, k=-1)
  50. a -= np.roll(a, 1, axis=1)
  51. a *= np.array(k) ** 2
  52. a[0] = 1
  53. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  54. self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
  55. self.bn = nn.BatchNorm2d(c2)
  56. self.act = nn.LeakyReLU(0.1, inplace=True)
  57. def forward(self, x):
  58. return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  59. class Ensemble(nn.ModuleList):
  60. # Ensemble of models
  61. def __init__(self):
  62. super().__init__()
  63. def forward(self, x, augment=False, profile=False, visualize=False):
  64. y = []
  65. for module in self:
  66. y.append(module(x, augment, profile, visualize)[0])
  67. # y = torch.stack(y).max(0)[0] # max ensemble
  68. # y = torch.stack(y).mean(0) # mean ensemble
  69. y = torch.cat(y, 1) # nms ensemble
  70. return y, None # inference, train output
  71. def attempt_load(weights, map_location=None, inplace=True, fuse=True):
  72. from models.yolo import Detect, Model
  73. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  74. model = Ensemble()
  75. for w in weights if isinstance(weights, list) else [weights]:
  76. ckpt = torch.load(attempt_download(w), map_location=map_location) # load
  77. if fuse:
  78. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
  79. else:
  80. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
  81. # Compatibility updates
  82. for m in model.modules():
  83. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
  84. m.inplace = inplace # pytorch 1.7.0 compatibility
  85. if type(m) is Detect:
  86. if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
  87. delattr(m, 'anchor_grid')
  88. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  89. elif type(m) is Conv:
  90. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  91. if len(model) == 1:
  92. return model[-1] # return model
  93. else:
  94. print(f'Ensemble created with {weights}\n')
  95. for k in ['names']:
  96. setattr(model, k, getattr(model[-1], k))
  97. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  98. return model # return ensemble