You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

406 lines
17KB

  1. # YOLOv5 common modules
  2. import math
  3. from copy import copy
  4. from pathlib import Path
  5. import numpy as np
  6. import pandas as pd
  7. import requests
  8. import torch
  9. import torch.nn as nn
  10. from PIL import Image
  11. from torch.cuda import amp
  12. from utils.datasets import letterbox
  13. from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
  14. from utils.plots import color_list, plot_one_box
  15. from utils.torch_utils import time_synchronized
  16. import warnings
  17. class SPPF(nn.Module):
  18. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  19. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  20. super().__init__()
  21. c_ = c1 // 2 # hidden channels
  22. self.cv1 = Conv(c1, c_, 1, 1)
  23. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  24. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  25. def forward(self, x):
  26. x = self.cv1(x)
  27. with warnings.catch_warnings():
  28. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  29. y1 = self.m(x)
  30. y2 = self.m(y1)
  31. return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
  32. def autopad(k, p=None): # kernel, padding
  33. # Pad to 'same'
  34. if p is None:
  35. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  36. return p
  37. def DWConv(c1, c2, k=1, s=1, act=True):
  38. # Depthwise convolution
  39. return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  40. class Conv(nn.Module):
  41. # Standard convolution
  42. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  43. super(Conv, self).__init__()
  44. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  45. self.bn = nn.BatchNorm2d(c2)
  46. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
  47. def forward(self, x):
  48. return self.act(self.bn(self.conv(x)))
  49. def fuseforward(self, x):
  50. return self.act(self.conv(x))
  51. class TransformerLayer(nn.Module):
  52. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  53. def __init__(self, c, num_heads):
  54. super().__init__()
  55. self.q = nn.Linear(c, c, bias=False)
  56. self.k = nn.Linear(c, c, bias=False)
  57. self.v = nn.Linear(c, c, bias=False)
  58. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  59. self.fc1 = nn.Linear(c, c, bias=False)
  60. self.fc2 = nn.Linear(c, c, bias=False)
  61. def forward(self, x):
  62. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  63. x = self.fc2(self.fc1(x)) + x
  64. return x
  65. class TransformerBlock(nn.Module):
  66. # Vision Transformer https://arxiv.org/abs/2010.11929
  67. def __init__(self, c1, c2, num_heads, num_layers):
  68. super().__init__()
  69. self.conv = None
  70. if c1 != c2:
  71. self.conv = Conv(c1, c2)
  72. self.linear = nn.Linear(c2, c2) # learnable position embedding
  73. self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
  74. self.c2 = c2
  75. def forward(self, x):
  76. if self.conv is not None:
  77. x = self.conv(x)
  78. b, _, w, h = x.shape
  79. p = x.flatten(2)
  80. p = p.unsqueeze(0)
  81. p = p.transpose(0, 3)
  82. p = p.squeeze(3)
  83. e = self.linear(p)
  84. x = p + e
  85. x = self.tr(x)
  86. x = x.unsqueeze(3)
  87. x = x.transpose(0, 3)
  88. x = x.reshape(b, self.c2, w, h)
  89. return x
  90. class Bottleneck(nn.Module):
  91. # Standard bottleneck
  92. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  93. super(Bottleneck, self).__init__()
  94. c_ = int(c2 * e) # hidden channels
  95. self.cv1 = Conv(c1, c_, 1, 1)
  96. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  97. self.add = shortcut and c1 == c2
  98. def forward(self, x):
  99. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  100. class BottleneckCSP(nn.Module):
  101. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  102. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  103. super(BottleneckCSP, self).__init__()
  104. c_ = int(c2 * e) # hidden channels
  105. self.cv1 = Conv(c1, c_, 1, 1)
  106. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  107. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  108. self.cv4 = Conv(2 * c_, c2, 1, 1)
  109. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  110. self.act = nn.LeakyReLU(0.1, inplace=True)
  111. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  112. def forward(self, x):
  113. y1 = self.cv3(self.m(self.cv1(x)))
  114. y2 = self.cv2(x)
  115. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  116. class C3(nn.Module):
  117. # CSP Bottleneck with 3 convolutions
  118. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  119. super(C3, self).__init__()
  120. c_ = int(c2 * e) # hidden channels
  121. self.cv1 = Conv(c1, c_, 1, 1)
  122. self.cv2 = Conv(c1, c_, 1, 1)
  123. self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
  124. self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
  125. # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
  126. def forward(self, x):
  127. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
  128. class C3TR(C3):
  129. # C3 module with TransformerBlock()
  130. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  131. super().__init__(c1, c2, n, shortcut, g, e)
  132. c_ = int(c2 * e)
  133. self.m = TransformerBlock(c_, c_, 4, n)
  134. class SPP(nn.Module):
  135. # Spatial pyramid pooling layer used in YOLOv3-SPP
  136. def __init__(self, c1, c2, k=(5, 9, 13)):
  137. super(SPP, self).__init__()
  138. c_ = c1 // 2 # hidden channels
  139. self.cv1 = Conv(c1, c_, 1, 1)
  140. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  141. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  142. def forward(self, x):
  143. x = self.cv1(x)
  144. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  145. class Focus(nn.Module):
  146. # Focus wh information into c-space
  147. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  148. super(Focus, self).__init__()
  149. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  150. # self.contract = Contract(gain=2)
  151. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  152. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  153. # return self.conv(self.contract(x))
  154. class Contract(nn.Module):
  155. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  156. def __init__(self, gain=2):
  157. super().__init__()
  158. self.gain = gain
  159. def forward(self, x):
  160. N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
  161. s = self.gain
  162. x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
  163. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  164. return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
  165. class Expand(nn.Module):
  166. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  167. def __init__(self, gain=2):
  168. super().__init__()
  169. self.gain = gain
  170. def forward(self, x):
  171. N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
  172. s = self.gain
  173. x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
  174. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  175. return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
  176. class Concat(nn.Module):
  177. # Concatenate a list of tensors along dimension
  178. def __init__(self, dimension=1):
  179. super(Concat, self).__init__()
  180. self.d = dimension
  181. def forward(self, x):
  182. return torch.cat(x, self.d)
  183. class NMS(nn.Module):
  184. # Non-Maximum Suppression (NMS) module
  185. conf = 0.25 # confidence threshold
  186. iou = 0.45 # IoU threshold
  187. classes = None # (optional list) filter by class
  188. def __init__(self):
  189. super(NMS, self).__init__()
  190. def forward(self, x):
  191. return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
  192. class autoShape(nn.Module):
  193. # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  194. conf = 0.25 # NMS confidence threshold
  195. iou = 0.45 # NMS IoU threshold
  196. classes = None # (optional list) filter by class
  197. def __init__(self, model):
  198. super(autoShape, self).__init__()
  199. self.model = model.eval()
  200. def autoshape(self):
  201. print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
  202. return self
  203. @torch.no_grad()
  204. def forward(self, imgs, size=640, augment=False, profile=False):
  205. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  206. # filename: imgs = 'data/images/zidane.jpg'
  207. # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
  208. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  209. # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
  210. # numpy: = np.zeros((640,1280,3)) # HWC
  211. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  212. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  213. t = [time_synchronized()]
  214. p = next(self.model.parameters()) # for device and type
  215. if isinstance(imgs, torch.Tensor): # torch
  216. with amp.autocast(enabled=p.device.type != 'cpu'):
  217. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  218. # Pre-process
  219. n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
  220. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  221. for i, im in enumerate(imgs):
  222. f = f'image{i}' # filename
  223. if isinstance(im, str): # filename or uri
  224. im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
  225. elif isinstance(im, Image.Image): # PIL Image
  226. im, f = np.asarray(im), getattr(im, 'filename', f) or f
  227. files.append(Path(f).with_suffix('.jpg').name)
  228. if im.shape[0] < 5: # image in CHW
  229. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  230. im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
  231. s = im.shape[:2] # HWC
  232. shape0.append(s) # image shape
  233. g = (size / max(s)) # gain
  234. shape1.append([y * g for y in s])
  235. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  236. shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
  237. x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
  238. x = np.stack(x, 0) if n > 1 else x[0][None] # stack
  239. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  240. x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
  241. t.append(time_synchronized())
  242. with amp.autocast(enabled=p.device.type != 'cpu'):
  243. # Inference
  244. y = self.model(x, augment, profile)[0] # forward
  245. t.append(time_synchronized())
  246. # Post-process
  247. y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
  248. for i in range(n):
  249. scale_coords(shape1, y[i][:, :4], shape0[i])
  250. t.append(time_synchronized())
  251. return Detections(imgs, y, files, t, self.names, x.shape)
  252. class Detections:
  253. # detections class for YOLOv5 inference results
  254. def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
  255. super(Detections, self).__init__()
  256. d = pred[0].device # device
  257. gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
  258. self.imgs = imgs # list of images as numpy arrays
  259. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  260. self.names = names # class names
  261. self.files = files # image filenames
  262. self.xyxy = pred # xyxy pixels
  263. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  264. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  265. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  266. self.n = len(self.pred) # number of images (batch size)
  267. self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
  268. self.s = shape # inference BCHW shape
  269. def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
  270. colors = color_list()
  271. for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
  272. str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
  273. if pred is not None:
  274. for c in pred[:, -1].unique():
  275. n = (pred[:, -1] == c).sum() # detections per class
  276. str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
  277. if show or save or render:
  278. for *box, conf, cls in pred: # xyxy, confidence, class
  279. label = f'{self.names[int(cls)]} {conf:.2f}'
  280. plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
  281. img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
  282. if pprint:
  283. print(str.rstrip(', '))
  284. if show:
  285. img.show(self.files[i]) # show
  286. if save:
  287. f = self.files[i]
  288. img.save(Path(save_dir) / f) # save
  289. print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
  290. if render:
  291. self.imgs[i] = np.asarray(img)
  292. def print(self):
  293. self.display(pprint=True) # print results
  294. print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
  295. def show(self):
  296. self.display(show=True) # show results
  297. def save(self, save_dir='runs/hub/exp'):
  298. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp') # increment save_dir
  299. Path(save_dir).mkdir(parents=True, exist_ok=True)
  300. self.display(save=True, save_dir=save_dir) # save results
  301. def render(self):
  302. self.display(render=True) # render results
  303. return self.imgs
  304. def pandas(self):
  305. # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
  306. new = copy(self) # return copy
  307. ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
  308. cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
  309. for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
  310. a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
  311. setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
  312. return new
  313. def tolist(self):
  314. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  315. x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
  316. for d in x:
  317. for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  318. setattr(d, k, getattr(d, k)[0]) # pop out of list
  319. return x
  320. def __len__(self):
  321. return self.n
  322. class Classify(nn.Module):
  323. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  324. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  325. super(Classify, self).__init__()
  326. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  327. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  328. self.flat = nn.Flatten()
  329. def forward(self, x):
  330. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  331. return self.flat(self.conv(z)) # flatten to x(b,c2)