无人机视角的行人小目标检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
4.5KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from models.common import Conv
  10. from utils.downloads import attempt_download
  11. class CrossConv(nn.Module):
  12. # Cross Convolution Downsample
  13. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
  14. # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
  15. super().__init__()
  16. c_ = int(c2 * e) # hidden channels
  17. self.cv1 = Conv(c1, c_, (1, k), (1, s))
  18. self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
  19. self.add = shortcut and c1 == c2
  20. def forward(self, x):
  21. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  22. class Sum(nn.Module):
  23. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  24. def __init__(self, n, weight=False): # n: number of inputs
  25. super().__init__()
  26. self.weight = weight # apply weights boolean
  27. self.iter = range(n - 1) # iter object
  28. if weight:
  29. self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
  30. def forward(self, x):
  31. y = x[0] # no weight
  32. if self.weight:
  33. w = torch.sigmoid(self.w) * 2
  34. for i in self.iter:
  35. y = y + x[i + 1] * w[i]
  36. else:
  37. for i in self.iter:
  38. y = y + x[i + 1]
  39. return y
  40. class MixConv2d(nn.Module):
  41. # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
  42. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
  43. super().__init__()
  44. n = len(k) # number of convolutions
  45. if equal_ch: # equal c_ per group
  46. i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
  47. c_ = [(i == g).sum() for g in range(n)] # intermediate channels
  48. else: # equal weight.numel() per group
  49. b = [c2] + [0] * n
  50. a = np.eye(n + 1, n, k=-1)
  51. a -= np.roll(a, 1, axis=1)
  52. a *= np.array(k) ** 2
  53. a[0] = 1
  54. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  55. self.m = nn.ModuleList(
  56. [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
  57. self.bn = nn.BatchNorm2d(c2)
  58. self.act = nn.SiLU()
  59. def forward(self, x):
  60. return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  61. class Ensemble(nn.ModuleList):
  62. # Ensemble of models
  63. def __init__(self):
  64. super().__init__()
  65. def forward(self, x, augment=False, profile=False, visualize=False):
  66. y = []
  67. for module in self:
  68. y.append(module(x, augment, profile, visualize)[0])
  69. # y = torch.stack(y).max(0)[0] # max ensemble
  70. # y = torch.stack(y).mean(0) # mean ensemble
  71. y = torch.cat(y, 1) # nms ensemble
  72. return y, None # inference, train output
  73. def attempt_load(weights, map_location=None, inplace=True, fuse=True):
  74. from models.yolo import Detect, Model
  75. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  76. model = Ensemble()
  77. for w in weights if isinstance(weights, list) else [weights]:
  78. ckpt = torch.load(attempt_download(w), map_location=map_location) # load
  79. if fuse:
  80. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
  81. else:
  82. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
  83. # Compatibility updates
  84. for m in model.modules():
  85. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
  86. m.inplace = inplace # pytorch 1.7.0 compatibility
  87. if type(m) is Detect:
  88. if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
  89. delattr(m, 'anchor_grid')
  90. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  91. elif type(m) is Conv:
  92. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  93. if len(model) == 1:
  94. return model[-1] # return model
  95. else:
  96. print(f'Ensemble created with {weights}\n')
  97. for k in ['names']:
  98. setattr(model, k, getattr(model[-1], k))
  99. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  100. return model # return ensemble