# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Common modules
"""

import json
import math
import platform
import warnings
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from PIL import Image
from torch.cuda import amp

from utils.datasets import exif_transpose, letterbox
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
                           non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import time_sync


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
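
# For example, autopad(3) -> 1, autopad(5) -> 2 and autopad((3, 5)) -> [1, 2], so a stride-1 convolution
# built with this padding keeps its input's spatial size ('same' padding).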


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))
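
# Usage sketch (illustration only): Conv applies Conv2d -> BatchNorm2d -> SiLU, with 'same' padding from autopad().
#   m = Conv(3, 64, k=3, s=2)
#   y = m(torch.rand(1, 3, 640, 640))  # -> torch.Size([1, 64, 320, 320])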


class DWConv(Conv):
    # Depth-wise convolution class
    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class TransformerLayer(nn.Module):
    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x


class TransformerBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
        return self.tr(p + self.linear(p)).unsqueeze(3).transpose(0, 3).reshape(b, self.c2, w, h)
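
# Note: the flatten/transpose in forward() reshapes BCHW features into the (sequence, batch, embed) layout that
# nn.MultiheadAttention expects by default, and the final reshape restores BCHW, so the block stays drop-in
# compatible with the surrounding convolutional modules.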


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.SiLU()
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
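
# Note: C3 performs the same cross-stage partial split as BottleneckCSP, but both branches use full Conv
# blocks and a single Conv (cv3) follows the concatenation, removing BottleneckCSP's separate
# BatchNorm2d + SiLU on the concatenated features.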


class C3TR(C3):
    # C3 module with TransformerBlock()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = TransformerBlock(c_, c_, 4, n)


class C3SPP(C3):
    # C3 module with SPP()
    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = SPP(c_, c_, k)


class C3Ghost(C3):
    # C3 module with GhostBottleneck()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))


class SPP(nn.Module):
    # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
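
# Applying the same k=5 max-pool three times in sequence gives effective receptive fields of 5, 9 and 13,
# which is why SPPF(k=5) matches SPP(k=(5, 9, 13)) (see the comment above) while reusing intermediate results.
# Equivalence sketch (illustration only; the convolutions must share weights for outputs to match):
#   m1, m2 = SPP(64, 64), SPPF(64, 64)
#   m2.cv1, m2.cv2 = m1.cv1, m1.cv2
#   x = torch.rand(1, 64, 32, 32)
#   assert torch.allclose(m1(x), m2(x))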


class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
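
# The four strided slices above pick the even/odd rows and columns, i.e. a space-to-depth rearrangement:
# every 2x2 spatial patch is stacked into the channel dimension before the convolution (hence the c1 * 4 inputs).
# Shape sketch (illustration only):
#   y = Focus(3, 64, k=3)(torch.rand(1, 3, 640, 640))  # -> torch.Size([1, 64, 320, 320])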


class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)


class GhostBottleneck(nn.Module):
    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super().__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)


class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        b, c, h, w = x.size()  # assert (h % s == 0) and (w % s == 0), 'Indivisible gain'
        s = self.gain
        x = x.view(b, c, h // s, s, w // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(b, c * s * s, h // s, w // s)  # x(1,256,40,40)


class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        b, c, h, w = x.size()  # assert c % s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(b, s, s, c // s ** 2, h, w)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return x.view(b, c // s ** 2, h * s, w * s)  # x(1,16,160,160)
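
# Contract and Expand are mutual inverses (analogous to pixel-unshuffle / pixel-shuffle): with the default
# gain=2, Contract maps (b, c, h, w) -> (b, 4c, h/2, w/2) and Expand maps it back. Round-trip sketch
# (illustration only):
#   x = torch.rand(1, 64, 80, 80)
#   assert Expand()(Contract()(x)).shape == x.shape  # (1,64,80,80) -> (1,256,40,40) -> (1,64,80,80)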


class Concat(nn.Module):
    # Concatenate a list of tensors along dimension
    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)


class DetectMultiBackend(nn.Module):
    # YOLOv5 MultiBackend class for python inference on various backends
    def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
        # Usage:
        #   PyTorch:          weights = *.pt
        #   TorchScript:                *.torchscript.pt
        #   CoreML:                     *.mlmodel
        #   TensorFlow:                 *_saved_model
        #   TensorFlow:                 *.pb
        #   TensorFlow Lite:            *.tflite
        #   ONNX Runtime:               *.onnx
        #   OpenCV DNN:                 *.onnx with dnn=True
        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
        check_suffix(w, suffixes)  # check weights have acceptable suffix
        pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
        jit = pt and 'torchscript' in w.lower()
        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults

        if jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files)
            if extra_files['config.txt']:
                d = json.loads(extra_files['config.txt'])  # extra_files dict
                stride, names = int(d['stride']), d['names']
        elif pt:  # PyTorch
            from models.experimental import attempt_load  # scoped to avoid circular import
            model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
            stride = int(model.stride.max())  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
        elif coreml:  # CoreML *.mlmodel
            import coremltools as ct
            model = ct.models.MLModel(w)
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements(('opencv-python>=4.5.4',))
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
            import onnxruntime
            session = onnxruntime.InferenceSession(w, None)
        else:  # TensorFlow model (TFLite, pb, saved_model)
            import tensorflow as tf
            if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
                def wrap_frozen_graph(gd, inputs, outputs):
                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                    return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
                                   tf.nest.map_structure(x.graph.as_graph_element, outputs))

                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
                graph_def = tf.Graph().as_graph_def()
                graph_def.ParseFromString(open(w, 'rb').read())
                frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
            elif saved_model:
                LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
                model = tf.keras.models.load_model(w)
            elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
                if 'edgetpu' in w.lower():  # Edge TPU
                    LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
                    import tflite_runtime.interpreter as tfli
                    delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
                                'Darwin': 'libedgetpu.1.dylib',
                                'Windows': 'edgetpu.dll'}[platform.system()]
                    interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
                else:  # TFLite
                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                    interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
                interpreter.allocate_tensors()  # allocate
                input_details = interpreter.get_input_details()  # inputs
                output_details = interpreter.get_output_details()  # outputs
        self.__dict__.update(locals())  # assign all variables to self

    def forward(self, im, augment=False, visualize=False, val=False):
        # YOLOv5 MultiBackend inference
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.pt:  # PyTorch
            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
            return y if val else y[0]
        elif self.coreml:  # CoreML *.mlmodel
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            im = Image.fromarray((im[0] * 255).astype('uint8'))
            # im = im.resize((192, 320), Image.ANTIALIAS)
            y = self.model.predict({'image': im})  # coordinates are xywh normalized
            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float)
            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
        elif self.onnx:  # ONNX
            im = im.cpu().numpy()  # torch to numpy
            if self.dnn:  # ONNX OpenCV DNN
                self.net.setInput(im)
                y = self.net.forward()
            else:  # ONNX Runtime
                y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
        else:  # TensorFlow model (TFLite, pb, saved_model)
            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
            if self.pb:
                y = self.frozen_func(x=self.tf.constant(im)).numpy()
            elif self.saved_model:
                y = self.model(im, training=False).numpy()
            elif self.tflite:
                input, output = self.input_details[0], self.output_details[0]
                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
                if int8:
                    scale, zero_point = input['quantization']
                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
                self.interpreter.set_tensor(input['index'], im)
                self.interpreter.invoke()
                y = self.interpreter.get_tensor(output['index'])
                if int8:
                    scale, zero_point = output['quantization']
                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
            y[..., 0] *= w  # x
            y[..., 1] *= h  # y
            y[..., 2] *= w  # w
            y[..., 3] *= h  # h
        y = torch.tensor(y)
        return (y, []) if val else y
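
# Usage sketch (illustration only; the weights path is a placeholder and NMS is applied separately):
#   model = DetectMultiBackend('yolov5s.pt', device=torch.device('cpu'))
#   im = torch.zeros(1, 3, 640, 640)                                     # BCHW float input, 0-1 values
#   pred = model(im)                                                     # raw predictions
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]  # (n, 6) tensor of xyxy, conf, cls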


class AutoShape(nn.Module):
    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
    multi_label = False  # NMS multiple labels per box
    max_det = 1000  # maximum number of detections per image

    def __init__(self, model):
        super().__init__()
        self.model = model.eval()

    def autoshape(self):
        LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model.model[-1]  # Detect()
        m.stride = fn(m.stride)
        m.grid = list(map(fn, m.grid))
        if isinstance(m.anchor_grid, list):
            m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        t = [time_sync()]
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            with amp.autocast(enabled=p.device.type != 'cpu'):
                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            f = f'image{i}'  # filename
            if isinstance(im, (str, Path)):  # filename or uri
                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
                im = np.asarray(exif_transpose(im))
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
            files.append(Path(f).with_suffix('.jpg').name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
        t.append(time_sync())

        with amp.autocast(enabled=p.device.type != 'cpu'):
            # Inference
            y = self.model(x, augment, profile)[0]  # forward
            t.append(time_sync())

            # Post-process
            y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
                                    multi_label=self.multi_label, max_det=self.max_det)  # NMS
            for i in range(n):
                scale_coords(shape1, y[i][:, :4], shape0[i])

            t.append(time_sync())
            return Detections(imgs, y, files, t, self.names, x.shape)
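
# Usage sketch (illustration only; mirrors the hub workflow that wraps a model in AutoShape):
#   model = torch.hub.load('ultralytics/yolov5', 'yolov5s')        # AutoShape-wrapped model
#   results = model('https://ultralytics.com/images/zidane.jpg')   # Detections object
#   results.print()  # or results.show(), results.save(), results.pandas().xyxy[0]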


class Detections:
    # YOLOv5 detections class for inference results
    def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
        super().__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
        crops = []
        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
            if pred.shape[0]:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render or crop:
                    annotator = Annotator(im, example=str(self.names))
                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        if crop:
                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
                            crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
                                          'im': save_one_box(box, im, file=file, save=save)})
                        else:  # all others
                            annotator.box_label(box, label, color=colors(cls))
                    im = annotator.im
            else:
                s += '(no detections)'

            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
            if pprint:
                LOGGER.info(s.rstrip(', '))
            if show:
                im.show(self.files[i])  # show
            if save:
                f = self.files[i]
                im.save(save_dir / f)  # save
                if i == self.n - 1:
                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
            if render:
                self.imgs[i] = np.asarray(im)
        if crop:
            if save:
                LOGGER.info(f'Saved results to {save_dir}\n')
            return crops

    def print(self):
        self.display(pprint=True)  # print results
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
                    self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='runs/detect/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
        self.display(save=True, save_dir=save_dir)  # save results

    def crop(self, save=True, save_dir='runs/detect/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
        return self.display(crop=True, save=save, save_dir=save_dir)  # crop results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def __len__(self):
        return self.n


class Classify(nn.Module):
    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
        return self.flat(self.conv(z))  # flatten to x(b,c2)
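
# Usage sketch (illustration only): the head global-average-pools the backbone features, a 1x1 conv maps
# c1 -> c2 class logits, and Flatten drops the trailing 1x1 spatial dims:
#   head = Classify(1280, 1000)
#   logits = head(torch.rand(2, 1280, 20, 20))  # -> torch.Size([2, 1000])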