# This file contains modules common to various models
import math
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
from utils.plots import color_list, plot_one_box


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
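
# Example (illustrative, added for clarity): 'same' padding values.
# autopad(3) -> 1, autopad(1) -> 0, autopad((3, 5)) -> [1, 2]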


def DWConv(c1, c2, k=1, s=1, act=True):
    # Depthwise convolution
    return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
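
# Example (illustrative): DWConv(32, 48) builds a grouped Conv with
# groups = gcd(32, 48) = 16, i.e. each group maps 2 input channels to 3 output channels.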


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))
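
# Example (illustrative): a 3x3 Conv with autopad preserves spatial size.
# >>> Conv(64, 32, k=3, s=1)(torch.zeros(1, 64, 80, 80)).shape
# torch.Size([1, 32, 80, 80])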


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
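
# Example (illustrative): the residual add is only active when c1 == c2.
# >>> Bottleneck(64, 64)(torch.zeros(1, 64, 80, 80)).shape   # shortcut used
# torch.Size([1, 64, 80, 80])
# >>> Bottleneck(64, 128)(torch.zeros(1, 64, 80, 80)).shape  # plain 1x1 -> 3x3 path
# torch.Size([1, 128, 80, 80])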


class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSP, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
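
# Example (illustrative): channels and spatial size are preserved for c1 == c2.
# >>> BottleneckCSP(128, 128, n=3)(torch.zeros(1, 128, 40, 40)).shape
# torch.Size([1, 128, 40, 40])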


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
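
# Example (illustrative): C3 is the lighter CSP variant; cv3 fuses the two branches.
# >>> C3(128, 256, n=2)(torch.zeros(1, 128, 40, 40)).shape
# torch.Size([1, 256, 40, 40])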


class SPP(nn.Module):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
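
# Example (illustrative): stride-1 max-pools with k//2 padding keep the spatial size.
# >>> SPP(1024, 512)(torch.zeros(1, 1024, 20, 20)).shape
# torch.Size([1, 512, 20, 20])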


class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
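
# Example (illustrative): pixel-unshuffle slicing trades resolution for channels before the Conv.
# >>> Focus(3, 64, k=3)(torch.zeros(1, 3, 640, 640)).shape
# torch.Size([1, 64, 320, 320])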


class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert H % s == 0 and W % s == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)


class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert C % s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return x.view(N, C // s ** 2, H * s, W * s)  # x(1,16,160,160)
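
# Example (illustrative): Expand reverses Contract, so shapes round-trip.
# >>> Expand(gain=2)(Contract(gain=2)(torch.zeros(1, 64, 80, 80))).shape
# torch.Size([1, 64, 80, 80])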


class Concat(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)
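
# Example (illustrative): concatenation along the channel dimension.
# >>> Concat(1)([torch.zeros(1, 64, 40, 40), torch.zeros(1, 128, 40, 40)]).shape
# torch.Size([1, 192, 40, 40])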


class NMS(nn.Module):
    # Non-Maximum Suppression (NMS) module
    conf = 0.25  # confidence threshold
    iou = 0.45  # IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self):
        super(NMS, self).__init__()

    def forward(self, x):
        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)


class autoShape(nn.Module):
    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    img_size = 640  # inference size (pixels)
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        super(autoShape, self).__init__()
        self.model = model.eval()  # Model.autoshape() also copies 'names' and 'stride' onto this wrapper

    def autoshape(self):
        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
        #   filename:  imgs = 'data/samples/zidane.jpg'
        #   URI:            = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:         = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
        #   PIL:            = Image.open('image.jpg')  # HWC x(720,1280,3)
        #   numpy:          = np.zeros((720,1280,3))  # HWC
        #   torch:          = torch.zeros(16,3,720,1280)  # BCHW
        #   multiple:       = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            if isinstance(im, str):  # filename or uri
                im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im  # open
                im.filename = f  # for uri
            files.append(Path(im.filename).with_suffix('.jpg').name if isinstance(im, Image.Image) else f'image{i}.jpg')
            im = np.array(im)  # to numpy
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im  # update
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32

        # Inference
        with torch.no_grad():
            y = self.model(x, augment, profile)[0]  # forward
        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS

        # Post-process
        for i in range(n):
            scale_coords(shape1, y[i][:, :4], shape0[i])

        return Detections(imgs, y, files, self.names)
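
# Example (illustrative; assumes the ultralytics/yolov5 torch.hub entry point, which
# returns an autoShape-wrapped model):
# >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
# >>> results = model('https://ultralytics.com/images/zidane.jpg')  # file/URI/PIL/OpenCV/numpy/torch input
# >>> results.print()  # or results.show(), results.save()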


class Detections:
    # detections class for YOLOv5 inference results
    def __init__(self, imgs, pred, files, names=None):
        super(Detections, self).__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)

    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
        colors = color_list()
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '  # 's' avoids shadowing builtin str
            if pred is not None:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render:
                    for *box, conf, cls in pred:  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
            if pprint:
                print(s.rstrip(', '))
            if show:
                img.show(self.files[i])  # show
            if save:
                f = Path(save_dir) / self.files[i]
                img.save(f)  # save
                print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
            if render:
                self.imgs[i] = np.asarray(img)

    def print(self):
        self.display(pprint=True)  # print results

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='results/'):
        Path(save_dir).mkdir(exist_ok=True)
        self.display(save=True, save_dir=save_dir)  # save results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def __len__(self):
        return self.n

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.names) for i in range(self.n)]  # pass files, then names
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x


class Classify(nn.Module):
    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Classify, self).__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
        return self.flat(self.conv(z))  # flatten to x(b,c2)
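
# Example (illustrative): global average pool, 1x1 conv, flatten.
# >>> Classify(1024, 1000)(torch.zeros(1, 1024, 20, 20)).shape
# torch.Size([1, 1000])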