You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
7.1KB

  1. # This file contains modules common to various models
  2. import math
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from utils.datasets import letterbox
  7. from utils.general import non_max_suppression, make_divisible, scale_coords
  8. def autopad(k, p=None): # kernel, padding
  9. # Pad to 'same'
  10. if p is None:
  11. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  12. return p
  13. def DWConv(c1, c2, k=1, s=1, act=True):
  14. # Depthwise convolution
  15. return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  16. class Conv(nn.Module):
  17. # Standard convolution
  18. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  19. super(Conv, self).__init__()
  20. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  21. self.bn = nn.BatchNorm2d(c2)
  22. self.act = nn.Hardswish() if act else nn.Identity()
  23. def forward(self, x):
  24. return self.act(self.bn(self.conv(x)))
  25. def fuseforward(self, x):
  26. return self.act(self.conv(x))
  27. class Bottleneck(nn.Module):
  28. # Standard bottleneck
  29. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  30. super(Bottleneck, self).__init__()
  31. c_ = int(c2 * e) # hidden channels
  32. self.cv1 = Conv(c1, c_, 1, 1)
  33. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  34. self.add = shortcut and c1 == c2
  35. def forward(self, x):
  36. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  37. class BottleneckCSP(nn.Module):
  38. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  39. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  40. super(BottleneckCSP, self).__init__()
  41. c_ = int(c2 * e) # hidden channels
  42. self.cv1 = Conv(c1, c_, 1, 1)
  43. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  44. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  45. self.cv4 = Conv(2 * c_, c2, 1, 1)
  46. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  47. self.act = nn.LeakyReLU(0.1, inplace=True)
  48. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  49. def forward(self, x):
  50. y1 = self.cv3(self.m(self.cv1(x)))
  51. y2 = self.cv2(x)
  52. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  53. class SPP(nn.Module):
  54. # Spatial pyramid pooling layer used in YOLOv3-SPP
  55. def __init__(self, c1, c2, k=(5, 9, 13)):
  56. super(SPP, self).__init__()
  57. c_ = c1 // 2 # hidden channels
  58. self.cv1 = Conv(c1, c_, 1, 1)
  59. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  60. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  61. def forward(self, x):
  62. x = self.cv1(x)
  63. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  64. class Focus(nn.Module):
  65. # Focus wh information into c-space
  66. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  67. super(Focus, self).__init__()
  68. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  69. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  70. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  71. class Concat(nn.Module):
  72. # Concatenate a list of tensors along dimension
  73. def __init__(self, dimension=1):
  74. super(Concat, self).__init__()
  75. self.d = dimension
  76. def forward(self, x):
  77. return torch.cat(x, self.d)
  78. class NMS(nn.Module):
  79. # Non-Maximum Suppression (NMS) module
  80. conf = 0.25 # confidence threshold
  81. iou = 0.45 # IoU threshold
  82. classes = None # (optional list) filter by class
  83. def __init__(self):
  84. super(NMS, self).__init__()
  85. def forward(self, x):
  86. return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
  87. class autoShape(nn.Module):
  88. # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  89. img_size = 640 # inference size (pixels)
  90. conf = 0.25 # NMS confidence threshold
  91. iou = 0.45 # NMS IoU threshold
  92. classes = None # (optional list) filter by class
  93. def __init__(self, model):
  94. super(autoShape, self).__init__()
  95. self.model = model
  96. def forward(self, x, size=640, augment=False, profile=False):
  97. # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
  98. # opencv: x = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
  99. # PIL: x = Image.open('image.jpg') # HWC x(720,1280,3)
  100. # numpy: x = np.zeros((720,1280,3)) # HWC
  101. # torch: x = torch.zeros(16,3,720,1280) # BCHW
  102. # multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  103. p = next(self.model.parameters()) # for device and type
  104. if isinstance(x, torch.Tensor): # torch
  105. return self.model(x.to(p.device).type_as(p), augment, profile) # inference
  106. # Pre-process
  107. if not isinstance(x, list):
  108. x = [x]
  109. shape0, shape1 = [], [] # image and inference shapes
  110. batch = range(len(x)) # batch size
  111. for i in batch:
  112. x[i] = np.array(x[i]) # to numpy
  113. x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3) # enforce 3ch input
  114. s = x[i].shape[:2] # HWC
  115. shape0.append(s) # image shape
  116. g = (size / max(s)) # gain
  117. shape1.append([y * g for y in s])
  118. shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
  119. x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
  120. x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
  121. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  122. x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
  123. # Inference
  124. x = self.model(x, augment, profile) # forward
  125. x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
  126. # Post-process
  127. for i in batch:
  128. if x[i] is not None:
  129. x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
  130. return x
  131. class Flatten(nn.Module):
  132. # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
  133. @staticmethod
  134. def forward(x):
  135. return x.view(x.size(0), -1)
  136. class Classify(nn.Module):
  137. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  138. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  139. super(Classify, self).__init__()
  140. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  141. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1)
  142. self.flat = Flatten()
  143. def forward(self, x):
  144. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  145. return self.flat(self.conv(z)) # flatten to x(b,c2)