Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

253 lines
10.0KB

  1. # This file contains modules common to various models
  2. import math
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image, ImageDraw
  7. from utils.datasets import letterbox
  8. from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
  9. from utils.plots import color_list
  10. def autopad(k, p=None): # kernel, padding
  11. # Pad to 'same'
  12. if p is None:
  13. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  14. return p
  15. def DWConv(c1, c2, k=1, s=1, act=True):
  16. # Depthwise convolution
  17. return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  18. class Conv(nn.Module):
  19. # Standard convolution
  20. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  21. super(Conv, self).__init__()
  22. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  23. self.bn = nn.BatchNorm2d(c2)
  24. self.act = nn.Hardswish() if act else nn.Identity()
  25. def forward(self, x):
  26. return self.act(self.bn(self.conv(x)))
  27. def fuseforward(self, x):
  28. return self.act(self.conv(x))
  29. class Bottleneck(nn.Module):
  30. # Standard bottleneck
  31. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  32. super(Bottleneck, self).__init__()
  33. c_ = int(c2 * e) # hidden channels
  34. self.cv1 = Conv(c1, c_, 1, 1)
  35. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  36. self.add = shortcut and c1 == c2
  37. def forward(self, x):
  38. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  39. class BottleneckCSP(nn.Module):
  40. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  41. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  42. super(BottleneckCSP, self).__init__()
  43. c_ = int(c2 * e) # hidden channels
  44. self.cv1 = Conv(c1, c_, 1, 1)
  45. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  46. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  47. self.cv4 = Conv(2 * c_, c2, 1, 1)
  48. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  49. self.act = nn.LeakyReLU(0.1, inplace=True)
  50. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  51. def forward(self, x):
  52. y1 = self.cv3(self.m(self.cv1(x)))
  53. y2 = self.cv2(x)
  54. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  55. class SPP(nn.Module):
  56. # Spatial pyramid pooling layer used in YOLOv3-SPP
  57. def __init__(self, c1, c2, k=(5, 9, 13)):
  58. super(SPP, self).__init__()
  59. c_ = c1 // 2 # hidden channels
  60. self.cv1 = Conv(c1, c_, 1, 1)
  61. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  62. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  63. def forward(self, x):
  64. x = self.cv1(x)
  65. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  66. class Focus(nn.Module):
  67. # Focus wh information into c-space
  68. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  69. super(Focus, self).__init__()
  70. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  71. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  72. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  73. class Concat(nn.Module):
  74. # Concatenate a list of tensors along dimension
  75. def __init__(self, dimension=1):
  76. super(Concat, self).__init__()
  77. self.d = dimension
  78. def forward(self, x):
  79. return torch.cat(x, self.d)
  80. class NMS(nn.Module):
  81. # Non-Maximum Suppression (NMS) module
  82. conf = 0.25 # confidence threshold
  83. iou = 0.45 # IoU threshold
  84. classes = None # (optional list) filter by class
  85. def __init__(self):
  86. super(NMS, self).__init__()
  87. def forward(self, x):
  88. return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
  89. class autoShape(nn.Module):
  90. # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  91. img_size = 640 # inference size (pixels)
  92. conf = 0.25 # NMS confidence threshold
  93. iou = 0.45 # NMS IoU threshold
  94. classes = None # (optional list) filter by class
  95. def __init__(self, model):
  96. super(autoShape, self).__init__()
  97. self.model = model.eval()
  98. def forward(self, imgs, size=640, augment=False, profile=False):
  99. # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
  100. # opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
  101. # PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
  102. # numpy: imgs = np.zeros((720,1280,3)) # HWC
  103. # torch: imgs = torch.zeros(16,3,720,1280) # BCHW
  104. # multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  105. p = next(self.model.parameters()) # for device and type
  106. if isinstance(imgs, torch.Tensor): # torch
  107. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  108. # Pre-process
  109. if not isinstance(imgs, list):
  110. imgs = [imgs]
  111. shape0, shape1 = [], [] # image and inference shapes
  112. batch = range(len(imgs)) # batch size
  113. for i in batch:
  114. imgs[i] = np.array(imgs[i]) # to numpy
  115. if imgs[i].shape[0] < 5: # image in CHW
  116. imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  117. imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
  118. s = imgs[i].shape[:2] # HWC
  119. shape0.append(s) # image shape
  120. g = (size / max(s)) # gain
  121. shape1.append([y * g for y in s])
  122. shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
  123. x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
  124. x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
  125. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  126. x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
  127. # Inference
  128. with torch.no_grad():
  129. y = self.model(x, augment, profile)[0] # forward
  130. y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
  131. # Post-process
  132. for i in batch:
  133. if y[i] is not None:
  134. y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
  135. return Detections(imgs, y, self.names)
  136. class Detections:
  137. # detections class for YOLOv5 inference results
  138. def __init__(self, imgs, pred, names=None):
  139. super(Detections, self).__init__()
  140. self.imgs = imgs # list of images as numpy arrays
  141. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  142. self.names = names # class names
  143. self.xyxy = pred # xyxy pixels
  144. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  145. d = pred[0].device # device
  146. gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
  147. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  148. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  149. self.n = len(self.pred)
  150. def display(self, pprint=False, show=False, save=False):
  151. colors = color_list()
  152. for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
  153. str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
  154. if pred is not None:
  155. for c in pred[:, -1].unique():
  156. n = (pred[:, -1] == c).sum() # detections per class
  157. str += f'{n} {self.names[int(c)]}s, ' # add to string
  158. if show or save:
  159. img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
  160. for *box, conf, cls in pred: # xyxy, confidence, class
  161. # str += '%s %.2f, ' % (names[int(cls)], conf) # label
  162. ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
  163. if save:
  164. f = f'results{i}.jpg'
  165. str += f"saved to '{f}'"
  166. img.save(f) # save
  167. if show:
  168. img.show(f'Image {i}') # show
  169. if pprint:
  170. print(str)
  171. def print(self):
  172. self.display(pprint=True) # print results
  173. def show(self):
  174. self.display(show=True) # show results
  175. def save(self):
  176. self.display(save=True) # save results
  177. def __len__(self):
  178. return self.n
  179. def tolist(self):
  180. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  181. x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
  182. for d in x:
  183. for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  184. setattr(d, k, getattr(d, k)[0]) # pop out of list
  185. return x
  186. class Flatten(nn.Module):
  187. # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
  188. @staticmethod
  189. def forward(x):
  190. return x.view(x.size(0), -1)
  191. class Classify(nn.Module):
  192. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  193. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  194. super(Classify, self).__init__()
  195. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  196. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) # to x(b,c2,1,1)
  197. self.flat = Flatten()
  198. def forward(self, x):
  199. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  200. return self.flat(self.conv(z)) # flatten to x(b,c2)