# This file contains modules common to various models
import math
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
from utils.plots import color_list, plot_one_box
from utils.torch_utils import time_synchronized

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
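
# Illustrative examples (not part of the original file):
#   autopad(3)      -> 1        # 3x3 kernel, 'same' padding for stride 1
#   autopad((3, 5)) -> [1, 2]   # per-dimension padding for a rectangular kernel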

def DWConv(c1, c2, k=1, s=1, act=True):
    # Depthwise convolution
    return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):  # used after conv and bn have been fused into a single layer
        return self.act(self.conv(x))

class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
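
# Illustrative note (not part of the original file): the residual add is only
# active when shortcut=True and the channel counts match, e.g.:
#   Bottleneck(64, 64)    # residual path active (c1 == c2)
#   Bottleneck(64, 128)   # add disabled: channel mismatch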

class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSP, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))

class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class SPP(nn.Module):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
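
# Illustrative sketch (not part of the original file): SPP concatenates the
# reduced input with len(k) max-pooled copies, so cv2 sees c_ * (len(k) + 1)
# channels while spatial size is preserved (stride 1, padding k//2), e.g.:
#   m = SPP(1024, 1024)                  # hidden c_ = 512, cat -> 512 * 4 = 2048 channels
#   y = m(torch.zeros(1, 1024, 20, 20))  # y.shape == torch.Size([1, 1024, 20, 20])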

class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
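
# Illustrative sketch (not part of the original file): the four strided slices
# perform a space-to-depth rearrangement before the convolution, e.g.:
#   m = Focus(3, 64, k=3)
#   y = m(torch.zeros(1, 3, 640, 640))  # y.shape == torch.Size([1, 64, 320, 320])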

class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert (H % s == 0) and (W % s == 0), 'Indivisible gain'
        s = self.gain
        x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)

class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert C % s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return x.view(N, C // s ** 2, H * s, W * s)  # x(1,16,160,160)
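
# Illustrative sketch (not part of the original file): at equal gain, Contract
# and Expand are inverse reshapes, e.g.:
#   x = torch.arange(64 * 16.).view(1, 64, 4, 4)
#   assert torch.equal(Expand(gain=2)(Contract(gain=2)(x)), x)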

class Concat(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)

class NMS(nn.Module):
    # Non-Maximum Suppression (NMS) module
    conf = 0.25  # confidence threshold
    iou = 0.45  # IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self):
        super(NMS, self).__init__()

    def forward(self, x):
        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
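
# Illustrative note (not part of the original file): NMS consumes the raw model
# output tuple, taking x[0] (the inference predictions), so it can be appended
# after a detection model, e.g.:
#   y = NMS()(model(img))  # list of (n,6) tensors: xyxy, conf, cls per image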

class autoShape(nn.Module):
    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        super(autoShape, self).__init__()
        self.model = model.eval()

    def autoshape(self):
        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
        #   filename:  imgs = 'data/samples/zidane.jpg'
        #   URI:            = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:         = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
        #   PIL:            = Image.open('image.jpg')  # HWC x(720,1280,3)
        #   numpy:          = np.zeros((720,1280,3))  # HWC
        #   torch:          = torch.zeros(16,3,720,1280)  # BCHW
        #   multiple:       = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        t = [time_synchronized()]
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            if isinstance(im, str):  # filename or uri
                im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im  # open
                im.filename = f  # for uri
            files.append(Path(im.filename).with_suffix('.jpg').name if isinstance(im, Image.Image) else f'image{i}.jpg')
            im = np.array(im)  # to numpy
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im  # update
        # self.stride and self.names are copied from the wrapped model by Model.autoshape()
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
        t.append(time_synchronized())

        # Inference
        with torch.no_grad():
            y = self.model(x, augment, profile)[0]  # forward
        t.append(time_synchronized())

        # Post-process
        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
        for i in range(n):
            scale_coords(shape1, y[i][:, :4], shape0[i])
        t.append(time_synchronized())

        return Detections(imgs, y, files, t, self.names, x.shape)
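
# Usage sketch (illustrative, not part of the original file); assumes an
# autoShape-wrapped model, e.g. as returned by torch.hub:
#   model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # autoshaped for inference
#   results = model('data/samples/zidane.jpg', size=640)     # path/URI/ndarray/PIL/tensor all accepted
#   results.print()  # per-class counts and speeds; see Detections below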

class Detections:
    # detections class for YOLOv5 inference results
    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
        super(Detections, self).__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timings (ms), as a reusable tuple
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
        colors = color_list()
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '  # 's' to avoid shadowing builtin str
            if pred is not None:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render:
                    for *box, conf, cls in pred:  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
            if pprint:
                print(s.rstrip(', '))
            if show:
                img.show(self.files[i])  # show
            if save:
                f = Path(save_dir) / self.files[i]
                img.save(f)  # save
                print(f"{'Saving' * (i == 0)} {f},", end='' if i < self.n - 1 else ' done.\n')
            if render:
                self.imgs[i] = np.asarray(img)

    def print(self):
        self.display(pprint=True)  # print results
        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
              self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='results/'):
        Path(save_dir).mkdir(exist_ok=True)
        self.display(save=True, save_dir=save_dir)  # save results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def __len__(self):
        return self.n

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], names=self.names, shape=self.s)
             for i in range(self.n)]
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

class Classify(nn.Module):
    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Classify, self).__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
        return self.flat(self.conv(z))  # flatten to x(b,c2)
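
# Illustrative sketch (not part of the original file): Classify pools any input
# resolution to 1x1 before the 1x1 conv, so spatial size does not matter, e.g.:
#   head = Classify(1280, 1000)
#   z = head(torch.zeros(4, 1280, 20, 20))  # z.shape == torch.Size([4, 1000])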