You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 satır
4.0KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Experimental modules
  4. """
  5. import math
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from models.common import Conv
  10. from utils.downloads import attempt_download
  11. class Sum(nn.Module):
  12. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  13. def __init__(self, n, weight=False): # n: number of inputs
  14. super().__init__()
  15. self.weight = weight # apply weights boolean
  16. self.iter = range(n - 1) # iter object
  17. if weight:
  18. self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
  19. def forward(self, x):
  20. y = x[0] # no weight
  21. if self.weight:
  22. w = torch.sigmoid(self.w) * 2
  23. for i in self.iter:
  24. y = y + x[i + 1] * w[i]
  25. else:
  26. for i in self.iter:
  27. y = y + x[i + 1]
  28. return y
  29. class MixConv2d(nn.Module):
  30. # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
  31. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
  32. super().__init__()
  33. n = len(k) # number of convolutions
  34. if equal_ch: # equal c_ per group
  35. i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
  36. c_ = [(i == g).sum() for g in range(n)] # intermediate channels
  37. else: # equal weight.numel() per group
  38. b = [c2] + [0] * n
  39. a = np.eye(n + 1, n, k=-1)
  40. a -= np.roll(a, 1, axis=1)
  41. a *= np.array(k) ** 2
  42. a[0] = 1
  43. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  44. self.m = nn.ModuleList([
  45. nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
  46. self.bn = nn.BatchNorm2d(c2)
  47. self.act = nn.SiLU()
  48. def forward(self, x):
  49. return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  50. class Ensemble(nn.ModuleList):
  51. # Ensemble of models
  52. def __init__(self):
  53. super().__init__()
  54. def forward(self, x, augment=False, profile=False, visualize=False):
  55. y = [module(x, augment, profile, visualize)[0] for module in self]
  56. # y = torch.stack(y).max(0)[0] # max ensemble
  57. # y = torch.stack(y).mean(0) # mean ensemble
  58. y = torch.cat(y, 1) # nms ensemble
  59. return y, None # inference, train output
  60. def attempt_load(weights, device=None, inplace=True, fuse=True):
  61. from models.yolo import Detect, Model
  62. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  63. model = Ensemble()
  64. for w in weights if isinstance(weights, list) else [weights]:
  65. ckpt = torch.load(attempt_download(w))
  66. ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
  67. model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
  68. # Compatibility updates
  69. for m in model.modules():
  70. t = type(m)
  71. if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
  72. m.inplace = inplace # torch 1.7.0 compatibility
  73. if t is Detect and not isinstance(m.anchor_grid, list):
  74. delattr(m, 'anchor_grid')
  75. setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
  76. elif t is Conv:
  77. m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility
  78. elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
  79. m.recompute_scale_factor = None # torch 1.11.0 compatibility
  80. if len(model) == 1:
  81. return model[-1] # return model
  82. print(f'Ensemble created with {weights}\n')
  83. for k in 'names', 'nc', 'yaml':
  84. setattr(model, k, getattr(model[0], k))
  85. model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
  86. assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
  87. return model # return ensemble