Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

624 lines
29KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import json
  6. import math
  7. import platform
  8. import warnings
  9. from collections import OrderedDict, namedtuple
  10. from copy import copy
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import requests
  16. import torch
  17. import torch.nn as nn
  18. from PIL import Image
  19. from torch.cuda import amp
  20. from utils.datasets import exif_transpose, letterbox
  21. from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
  22. non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
  23. from utils.plots import Annotator, colors, save_one_box
  24. from utils.torch_utils import copy_attr, time_sync
  25. def autopad(k, p=None): # kernel, padding
  26. # Pad to 'same'
  27. if p is None:
  28. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  29. return p
  30. class Conv(nn.Module):
  31. # Standard convolution
  32. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  33. super().__init__()
  34. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  35. self.bn = nn.BatchNorm2d(c2)
  36. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
  37. def forward(self, x):
  38. return self.act(self.bn(self.conv(x)))
  39. def forward_fuse(self, x):
  40. return self.act(self.conv(x))
  41. class DWConv(Conv):
  42. # Depth-wise convolution class
  43. def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  44. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  45. class TransformerLayer(nn.Module):
  46. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  47. def __init__(self, c, num_heads):
  48. super().__init__()
  49. self.q = nn.Linear(c, c, bias=False)
  50. self.k = nn.Linear(c, c, bias=False)
  51. self.v = nn.Linear(c, c, bias=False)
  52. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  53. self.fc1 = nn.Linear(c, c, bias=False)
  54. self.fc2 = nn.Linear(c, c, bias=False)
  55. def forward(self, x):
  56. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  57. x = self.fc2(self.fc1(x)) + x
  58. return x
  59. class TransformerBlock(nn.Module):
  60. # Vision Transformer https://arxiv.org/abs/2010.11929
  61. def __init__(self, c1, c2, num_heads, num_layers):
  62. super().__init__()
  63. self.conv = None
  64. if c1 != c2:
  65. self.conv = Conv(c1, c2)
  66. self.linear = nn.Linear(c2, c2) # learnable position embedding
  67. self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
  68. self.c2 = c2
  69. def forward(self, x):
  70. if self.conv is not None:
  71. x = self.conv(x)
  72. b, _, w, h = x.shape
  73. p = x.flatten(2).permute(2, 0, 1)
  74. return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
  75. class Bottleneck(nn.Module):
  76. # Standard bottleneck
  77. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  78. super().__init__()
  79. c_ = int(c2 * e) # hidden channels
  80. self.cv1 = Conv(c1, c_, 1, 1)
  81. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  82. self.add = shortcut and c1 == c2
  83. def forward(self, x):
  84. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  85. class BottleneckCSP(nn.Module):
  86. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  87. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  88. super().__init__()
  89. c_ = int(c2 * e) # hidden channels
  90. self.cv1 = Conv(c1, c_, 1, 1)
  91. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  92. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  93. self.cv4 = Conv(2 * c_, c2, 1, 1)
  94. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  95. self.act = nn.SiLU()
  96. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  97. def forward(self, x):
  98. y1 = self.cv3(self.m(self.cv1(x)))
  99. y2 = self.cv2(x)
  100. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  101. class C3(nn.Module):
  102. # CSP Bottleneck with 3 convolutions
  103. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  104. super().__init__()
  105. c_ = int(c2 * e) # hidden channels
  106. self.cv1 = Conv(c1, c_, 1, 1)
  107. self.cv2 = Conv(c1, c_, 1, 1)
  108. self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
  109. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  110. # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
  111. def forward(self, x):
  112. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
  113. class C3TR(C3):
  114. # C3 module with TransformerBlock()
  115. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  116. super().__init__(c1, c2, n, shortcut, g, e)
  117. c_ = int(c2 * e)
  118. self.m = TransformerBlock(c_, c_, 4, n)
  119. class C3SPP(C3):
  120. # C3 module with SPP()
  121. def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
  122. super().__init__(c1, c2, n, shortcut, g, e)
  123. c_ = int(c2 * e)
  124. self.m = SPP(c_, c_, k)
  125. class C3Ghost(C3):
  126. # C3 module with GhostBottleneck()
  127. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  128. super().__init__(c1, c2, n, shortcut, g, e)
  129. c_ = int(c2 * e) # hidden channels
  130. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  131. class SPP(nn.Module):
  132. # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
  133. def __init__(self, c1, c2, k=(5, 9, 13)):
  134. super().__init__()
  135. c_ = c1 // 2 # hidden channels
  136. self.cv1 = Conv(c1, c_, 1, 1)
  137. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  138. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  139. def forward(self, x):
  140. x = self.cv1(x)
  141. with warnings.catch_warnings():
  142. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  143. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  144. class SPPF(nn.Module):
  145. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  146. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  147. super().__init__()
  148. c_ = c1 // 2 # hidden channels
  149. self.cv1 = Conv(c1, c_, 1, 1)
  150. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  151. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  152. def forward(self, x):
  153. x = self.cv1(x)
  154. with warnings.catch_warnings():
  155. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  156. y1 = self.m(x)
  157. y2 = self.m(y1)
  158. return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
  159. class Focus(nn.Module):
  160. # Focus wh information into c-space
  161. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  162. super().__init__()
  163. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  164. # self.contract = Contract(gain=2)
  165. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  166. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  167. # return self.conv(self.contract(x))
  168. class GhostConv(nn.Module):
  169. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  170. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  171. super().__init__()
  172. c_ = c2 // 2 # hidden channels
  173. self.cv1 = Conv(c1, c_, k, s, None, g, act)
  174. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
  175. def forward(self, x):
  176. y = self.cv1(x)
  177. return torch.cat([y, self.cv2(y)], 1)
  178. class GhostBottleneck(nn.Module):
  179. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  180. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  181. super().__init__()
  182. c_ = c2 // 2
  183. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  184. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  185. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  186. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  187. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  188. def forward(self, x):
  189. return self.conv(x) + self.shortcut(x)
  190. class Contract(nn.Module):
  191. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  192. def __init__(self, gain=2):
  193. super().__init__()
  194. self.gain = gain
  195. def forward(self, x):
  196. b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
  197. s = self.gain
  198. x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
  199. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  200. return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
  201. class Expand(nn.Module):
  202. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  203. def __init__(self, gain=2):
  204. super().__init__()
  205. self.gain = gain
  206. def forward(self, x):
  207. b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
  208. s = self.gain
  209. x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
  210. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  211. return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
  212. class Concat(nn.Module):
  213. # Concatenate a list of tensors along dimension
  214. def __init__(self, dimension=1):
  215. super().__init__()
  216. self.d = dimension
  217. def forward(self, x):
  218. return torch.cat(x, self.d)
  219. class DetectMultiBackend(nn.Module):
  220. # YOLOv5 MultiBackend class for python inference on various backends
  221. def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
  222. # Usage:
  223. # PyTorch: weights = *.pt
  224. # TorchScript: *.torchscript
  225. # CoreML: *.mlmodel
  226. # TensorFlow: *_saved_model
  227. # TensorFlow: *.pb
  228. # TensorFlow Lite: *.tflite
  229. # ONNX Runtime: *.onnx
  230. # OpenCV DNN: *.onnx with dnn=True
  231. # TensorRT: *.engine
  232. super().__init__()
  233. w = str(weights[0] if isinstance(weights, list) else weights)
  234. suffix = Path(w).suffix.lower()
  235. suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
  236. check_suffix(w, suffixes) # check weights have acceptable suffix
  237. pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
  238. stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
  239. if jit: # TorchScript
  240. LOGGER.info(f'Loading {w} for TorchScript inference...')
  241. extra_files = {'config.txt': ''} # model metadata
  242. model = torch.jit.load(w, _extra_files=extra_files)
  243. if extra_files['config.txt']:
  244. d = json.loads(extra_files['config.txt']) # extra_files dict
  245. stride, names = int(d['stride']), d['names']
  246. elif pt: # PyTorch
  247. from models.experimental import attempt_load # scoped to avoid circular import
  248. model = attempt_load(weights, map_location=device)
  249. stride = int(model.stride.max()) # model stride
  250. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  251. elif coreml: # CoreML
  252. import coremltools as ct
  253. model = ct.models.MLModel(w)
  254. elif dnn: # ONNX OpenCV DNN
  255. LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
  256. check_requirements(('opencv-python>=4.5.4',))
  257. net = cv2.dnn.readNetFromONNX(w)
  258. elif onnx: # ONNX Runtime
  259. LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
  260. check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
  261. import onnxruntime
  262. session = onnxruntime.InferenceSession(w, None)
  263. elif engine: # TensorRT
  264. LOGGER.info(f'Loading {w} for TensorRT inference...')
  265. import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
  266. Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
  267. logger = trt.Logger(trt.Logger.INFO)
  268. with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
  269. model = runtime.deserialize_cuda_engine(f.read())
  270. bindings = OrderedDict()
  271. for index in range(model.num_bindings):
  272. name = model.get_binding_name(index)
  273. dtype = trt.nptype(model.get_binding_dtype(index))
  274. shape = tuple(model.get_binding_shape(index))
  275. data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
  276. bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
  277. binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
  278. context = model.create_execution_context()
  279. batch_size = bindings['images'].shape[0]
  280. else: # TensorFlow model (TFLite, pb, saved_model)
  281. if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
  282. LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
  283. import tensorflow as tf
  284. def wrap_frozen_graph(gd, inputs, outputs):
  285. x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
  286. return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
  287. tf.nest.map_structure(x.graph.as_graph_element, outputs))
  288. graph_def = tf.Graph().as_graph_def()
  289. graph_def.ParseFromString(open(w, 'rb').read())
  290. frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
  291. elif saved_model:
  292. LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
  293. import tensorflow as tf
  294. model = tf.keras.models.load_model(w)
  295. elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
  296. if 'edgetpu' in w.lower():
  297. LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
  298. import tflite_runtime.interpreter as tfli
  299. delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
  300. 'Darwin': 'libedgetpu.1.dylib',
  301. 'Windows': 'edgetpu.dll'}[platform.system()]
  302. interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
  303. else:
  304. LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
  305. import tensorflow as tf
  306. interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
  307. interpreter.allocate_tensors() # allocate
  308. input_details = interpreter.get_input_details() # inputs
  309. output_details = interpreter.get_output_details() # outputs
  310. self.__dict__.update(locals()) # assign all variables to self
  311. def forward(self, im, augment=False, visualize=False, val=False):
  312. # YOLOv5 MultiBackend inference
  313. b, ch, h, w = im.shape # batch, channel, height, width
  314. if self.pt: # PyTorch
  315. y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
  316. return y if val else y[0]
  317. elif self.coreml: # CoreML *.mlmodel
  318. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  319. im = Image.fromarray((im[0] * 255).astype('uint8'))
  320. # im = im.resize((192, 320), Image.ANTIALIAS)
  321. y = self.model.predict({'image': im}) # coordinates are xywh normalized
  322. box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
  323. conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
  324. y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
  325. elif self.onnx: # ONNX
  326. im = im.cpu().numpy() # torch to numpy
  327. if self.dnn: # ONNX OpenCV DNN
  328. self.net.setInput(im)
  329. y = self.net.forward()
  330. else: # ONNX Runtime
  331. y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
  332. elif self.engine: # TensorRT
  333. assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
  334. self.binding_addrs['images'] = int(im.data_ptr())
  335. self.context.execute_v2(list(self.binding_addrs.values()))
  336. y = self.bindings['output'].data
  337. else: # TensorFlow model (TFLite, pb, saved_model)
  338. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  339. if self.pb:
  340. y = self.frozen_func(x=self.tf.constant(im)).numpy()
  341. elif self.saved_model:
  342. y = self.model(im, training=False).numpy()
  343. elif self.tflite:
  344. input, output = self.input_details[0], self.output_details[0]
  345. int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
  346. if int8:
  347. scale, zero_point = input['quantization']
  348. im = (im / scale + zero_point).astype(np.uint8) # de-scale
  349. self.interpreter.set_tensor(input['index'], im)
  350. self.interpreter.invoke()
  351. y = self.interpreter.get_tensor(output['index'])
  352. if int8:
  353. scale, zero_point = output['quantization']
  354. y = (y.astype(np.float32) - zero_point) * scale # re-scale
  355. y[..., 0] *= w # x
  356. y[..., 1] *= h # y
  357. y[..., 2] *= w # w
  358. y[..., 3] *= h # h
  359. y = torch.tensor(y) if isinstance(y, np.ndarray) else y
  360. return (y, []) if val else y
  361. def warmup(self, imgsz=(1, 3, 640, 640), half=False):
  362. # Warmup model by running inference once
  363. if self.pt or self.engine or self.onnx: # warmup types
  364. if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
  365. im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
  366. self.forward(im) # warmup
  367. class AutoShape(nn.Module):
  368. # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  369. conf = 0.25 # NMS confidence threshold
  370. iou = 0.45 # NMS IoU threshold
  371. classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
  372. multi_label = False # NMS multiple labels per box
  373. max_det = 1000 # maximum number of detections per image
  374. def __init__(self, model):
  375. super().__init__()
  376. LOGGER.info('Adding AutoShape... ')
  377. copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
  378. self.model = model.eval()
  379. def _apply(self, fn):
  380. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  381. self = super()._apply(fn)
  382. m = self.model.model[-1] # Detect()
  383. m.stride = fn(m.stride)
  384. m.grid = list(map(fn, m.grid))
  385. if isinstance(m.anchor_grid, list):
  386. m.anchor_grid = list(map(fn, m.anchor_grid))
  387. return self
  388. @torch.no_grad()
  389. def forward(self, imgs, size=640, augment=False, profile=False):
  390. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  391. # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
  392. # URI: = 'https://ultralytics.com/images/zidane.jpg'
  393. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  394. # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
  395. # numpy: = np.zeros((640,1280,3)) # HWC
  396. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  397. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  398. t = [time_sync()]
  399. p = next(self.model.parameters()) # for device and type
  400. if isinstance(imgs, torch.Tensor): # torch
  401. with amp.autocast(enabled=p.device.type != 'cpu'):
  402. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  403. # Pre-process
  404. n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
  405. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  406. for i, im in enumerate(imgs):
  407. f = f'image{i}' # filename
  408. if isinstance(im, (str, Path)): # filename or uri
  409. im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
  410. im = np.asarray(exif_transpose(im))
  411. elif isinstance(im, Image.Image): # PIL Image
  412. im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
  413. files.append(Path(f).with_suffix('.jpg').name)
  414. if im.shape[0] < 5: # image in CHW
  415. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  416. im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
  417. s = im.shape[:2] # HWC
  418. shape0.append(s) # image shape
  419. g = (size / max(s)) # gain
  420. shape1.append([y * g for y in s])
  421. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  422. shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
  423. x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
  424. x = np.stack(x, 0) if n > 1 else x[0][None] # stack
  425. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  426. x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
  427. t.append(time_sync())
  428. with amp.autocast(enabled=p.device.type != 'cpu'):
  429. # Inference
  430. y = self.model(x, augment, profile)[0] # forward
  431. t.append(time_sync())
  432. # Post-process
  433. y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
  434. multi_label=self.multi_label, max_det=self.max_det) # NMS
  435. for i in range(n):
  436. scale_coords(shape1, y[i][:, :4], shape0[i])
  437. t.append(time_sync())
  438. return Detections(imgs, y, files, t, self.names, x.shape)
  439. class Detections:
  440. # YOLOv5 detections class for inference results
  441. def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
  442. super().__init__()
  443. d = pred[0].device # device
  444. gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
  445. self.imgs = imgs # list of images as numpy arrays
  446. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  447. self.names = names # class names
  448. self.files = files # image filenames
  449. self.xyxy = pred # xyxy pixels
  450. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  451. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  452. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  453. self.n = len(self.pred) # number of images (batch size)
  454. self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
  455. self.s = shape # inference BCHW shape
  456. def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
  457. crops = []
  458. for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
  459. s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
  460. if pred.shape[0]:
  461. for c in pred[:, -1].unique():
  462. n = (pred[:, -1] == c).sum() # detections per class
  463. s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
  464. if show or save or render or crop:
  465. annotator = Annotator(im, example=str(self.names))
  466. for *box, conf, cls in reversed(pred): # xyxy, confidence, class
  467. label = f'{self.names[int(cls)]} {conf:.2f}'
  468. if crop:
  469. file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
  470. crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
  471. 'im': save_one_box(box, im, file=file, save=save)})
  472. else: # all others
  473. annotator.box_label(box, label, color=colors(cls))
  474. im = annotator.im
  475. else:
  476. s += '(no detections)'
  477. im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
  478. if pprint:
  479. LOGGER.info(s.rstrip(', '))
  480. if show:
  481. im.show(self.files[i]) # show
  482. if save:
  483. f = self.files[i]
  484. im.save(save_dir / f) # save
  485. if i == self.n - 1:
  486. LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
  487. if render:
  488. self.imgs[i] = np.asarray(im)
  489. if crop:
  490. if save:
  491. LOGGER.info(f'Saved results to {save_dir}\n')
  492. return crops
  493. def print(self):
  494. self.display(pprint=True) # print results
  495. LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
  496. self.t)
  497. def show(self):
  498. self.display(show=True) # show results
  499. def save(self, save_dir='runs/detect/exp'):
  500. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
  501. self.display(save=True, save_dir=save_dir) # save results
  502. def crop(self, save=True, save_dir='runs/detect/exp'):
  503. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
  504. return self.display(crop=True, save=save, save_dir=save_dir) # crop results
  505. def render(self):
  506. self.display(render=True) # render results
  507. return self.imgs
  508. def pandas(self):
  509. # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
  510. new = copy(self) # return copy
  511. ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
  512. cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
  513. for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
  514. a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
  515. setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
  516. return new
  517. def tolist(self):
  518. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  519. x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
  520. for d in x:
  521. for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  522. setattr(d, k, getattr(d, k)[0]) # pop out of list
  523. return x
  524. def __len__(self):
  525. return self.n
  526. class Classify(nn.Module):
  527. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  528. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  529. super().__init__()
  530. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  531. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  532. self.flat = nn.Flatten()
  533. def forward(self, x):
  534. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  535. return self.flat(self.conv(z)) # flatten to x(b,c2)