Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

685 lines
33KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import json
  6. import math
  7. import platform
  8. import warnings
  9. from collections import OrderedDict, namedtuple
  10. from copy import copy
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import requests
  16. import torch
  17. import torch.nn as nn
  18. import yaml
  19. from PIL import Image
  20. from torch.cuda import amp
  21. from utils.datasets import exif_transpose, letterbox
  22. from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
  23. make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
  24. from utils.plots import Annotator, colors, save_one_box
  25. from utils.torch_utils import copy_attr, time_sync
  26. def autopad(k, p=None): # kernel, padding
  27. # Pad to 'same'
  28. if p is None:
  29. p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
  30. return p
  31. class Conv(nn.Module):
  32. # Standard convolution
  33. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  34. super().__init__()
  35. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  36. self.bn = nn.BatchNorm2d(c2)
  37. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
  38. def forward(self, x):
  39. return self.act(self.bn(self.conv(x)))
  40. def forward_fuse(self, x):
  41. return self.act(self.conv(x))
  42. class DWConv(Conv):
  43. # Depth-wise convolution class
  44. def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  45. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  46. class TransformerLayer(nn.Module):
  47. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  48. def __init__(self, c, num_heads):
  49. super().__init__()
  50. self.q = nn.Linear(c, c, bias=False)
  51. self.k = nn.Linear(c, c, bias=False)
  52. self.v = nn.Linear(c, c, bias=False)
  53. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  54. self.fc1 = nn.Linear(c, c, bias=False)
  55. self.fc2 = nn.Linear(c, c, bias=False)
  56. def forward(self, x):
  57. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  58. x = self.fc2(self.fc1(x)) + x
  59. return x
  60. class TransformerBlock(nn.Module):
  61. # Vision Transformer https://arxiv.org/abs/2010.11929
  62. def __init__(self, c1, c2, num_heads, num_layers):
  63. super().__init__()
  64. self.conv = None
  65. if c1 != c2:
  66. self.conv = Conv(c1, c2)
  67. self.linear = nn.Linear(c2, c2) # learnable position embedding
  68. self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
  69. self.c2 = c2
  70. def forward(self, x):
  71. if self.conv is not None:
  72. x = self.conv(x)
  73. b, _, w, h = x.shape
  74. p = x.flatten(2).permute(2, 0, 1)
  75. return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
  76. class Bottleneck(nn.Module):
  77. # Standard bottleneck
  78. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  79. super().__init__()
  80. c_ = int(c2 * e) # hidden channels
  81. self.cv1 = Conv(c1, c_, 1, 1)
  82. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  83. self.add = shortcut and c1 == c2
  84. def forward(self, x):
  85. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  86. class BottleneckCSP(nn.Module):
  87. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  88. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  89. super().__init__()
  90. c_ = int(c2 * e) # hidden channels
  91. self.cv1 = Conv(c1, c_, 1, 1)
  92. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  93. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  94. self.cv4 = Conv(2 * c_, c2, 1, 1)
  95. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  96. self.act = nn.SiLU()
  97. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  98. def forward(self, x):
  99. y1 = self.cv3(self.m(self.cv1(x)))
  100. y2 = self.cv2(x)
  101. return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
  102. class C3(nn.Module):
  103. # CSP Bottleneck with 3 convolutions
  104. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  105. super().__init__()
  106. c_ = int(c2 * e) # hidden channels
  107. self.cv1 = Conv(c1, c_, 1, 1)
  108. self.cv2 = Conv(c1, c_, 1, 1)
  109. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  110. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  111. # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
  112. def forward(self, x):
  113. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  114. class C3TR(C3):
  115. # C3 module with TransformerBlock()
  116. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  117. super().__init__(c1, c2, n, shortcut, g, e)
  118. c_ = int(c2 * e)
  119. self.m = TransformerBlock(c_, c_, 4, n)
  120. class C3SPP(C3):
  121. # C3 module with SPP()
  122. def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
  123. super().__init__(c1, c2, n, shortcut, g, e)
  124. c_ = int(c2 * e)
  125. self.m = SPP(c_, c_, k)
  126. class C3Ghost(C3):
  127. # C3 module with GhostBottleneck()
  128. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  129. super().__init__(c1, c2, n, shortcut, g, e)
  130. c_ = int(c2 * e) # hidden channels
  131. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  132. class SPP(nn.Module):
  133. # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
  134. def __init__(self, c1, c2, k=(5, 9, 13)):
  135. super().__init__()
  136. c_ = c1 // 2 # hidden channels
  137. self.cv1 = Conv(c1, c_, 1, 1)
  138. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  139. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  140. def forward(self, x):
  141. x = self.cv1(x)
  142. with warnings.catch_warnings():
  143. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  144. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  145. class SPPF(nn.Module):
  146. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  147. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  148. super().__init__()
  149. c_ = c1 // 2 # hidden channels
  150. self.cv1 = Conv(c1, c_, 1, 1)
  151. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  152. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  153. def forward(self, x):
  154. x = self.cv1(x)
  155. with warnings.catch_warnings():
  156. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  157. y1 = self.m(x)
  158. y2 = self.m(y1)
  159. return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
  160. class Focus(nn.Module):
  161. # Focus wh information into c-space
  162. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  163. super().__init__()
  164. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  165. # self.contract = Contract(gain=2)
  166. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  167. return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
  168. # return self.conv(self.contract(x))
  169. class GhostConv(nn.Module):
  170. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  171. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  172. super().__init__()
  173. c_ = c2 // 2 # hidden channels
  174. self.cv1 = Conv(c1, c_, k, s, None, g, act)
  175. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
  176. def forward(self, x):
  177. y = self.cv1(x)
  178. return torch.cat((y, self.cv2(y)), 1)
  179. class GhostBottleneck(nn.Module):
  180. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  181. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  182. super().__init__()
  183. c_ = c2 // 2
  184. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  185. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  186. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  187. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  188. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  189. def forward(self, x):
  190. return self.conv(x) + self.shortcut(x)
  191. class Contract(nn.Module):
  192. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  193. def __init__(self, gain=2):
  194. super().__init__()
  195. self.gain = gain
  196. def forward(self, x):
  197. b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
  198. s = self.gain
  199. x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
  200. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  201. return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
  202. class Expand(nn.Module):
  203. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  204. def __init__(self, gain=2):
  205. super().__init__()
  206. self.gain = gain
  207. def forward(self, x):
  208. b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
  209. s = self.gain
  210. x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
  211. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  212. return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
  213. class Concat(nn.Module):
  214. # Concatenate a list of tensors along dimension
  215. def __init__(self, dimension=1):
  216. super().__init__()
  217. self.d = dimension
  218. def forward(self, x):
  219. return torch.cat(x, self.d)
  220. class DetectMultiBackend(nn.Module):
  221. # YOLOv5 MultiBackend class for python inference on various backends
  222. def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
  223. # Usage:
  224. # PyTorch: weights = *.pt
  225. # TorchScript: *.torchscript
  226. # ONNX Runtime: *.onnx
  227. # ONNX OpenCV DNN: *.onnx with --dnn
  228. # OpenVINO: *.xml
  229. # CoreML: *.mlmodel
  230. # TensorRT: *.engine
  231. # TensorFlow SavedModel: *_saved_model
  232. # TensorFlow GraphDef: *.pb
  233. # TensorFlow Lite: *.tflite
  234. # TensorFlow Edge TPU: *_edgetpu.tflite
  235. from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
  236. super().__init__()
  237. w = str(weights[0] if isinstance(weights, list) else weights)
  238. pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
  239. stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
  240. w = attempt_download(w) # download if not local
  241. fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
  242. if data: # data.yaml path (optional)
  243. with open(data, errors='ignore') as f:
  244. names = yaml.safe_load(f)['names'] # class names
  245. if pt: # PyTorch
  246. model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
  247. stride = max(int(model.stride.max()), 32) # model stride
  248. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  249. model.half() if fp16 else model.float()
  250. self.model = model # explicitly assign for to(), cpu(), cuda(), half()
  251. elif jit: # TorchScript
  252. LOGGER.info(f'Loading {w} for TorchScript inference...')
  253. extra_files = {'config.txt': ''} # model metadata
  254. model = torch.jit.load(w, _extra_files=extra_files)
  255. model.half() if fp16 else model.float()
  256. if extra_files['config.txt']:
  257. d = json.loads(extra_files['config.txt']) # extra_files dict
  258. stride, names = int(d['stride']), d['names']
  259. elif dnn: # ONNX OpenCV DNN
  260. LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
  261. check_requirements(('opencv-python>=4.5.4',))
  262. net = cv2.dnn.readNetFromONNX(w)
  263. elif onnx: # ONNX Runtime
  264. LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
  265. cuda = torch.cuda.is_available()
  266. check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
  267. import onnxruntime
  268. providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
  269. session = onnxruntime.InferenceSession(w, providers=providers)
  270. elif xml: # OpenVINO
  271. LOGGER.info(f'Loading {w} for OpenVINO inference...')
  272. check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
  273. import openvino.inference_engine as ie
  274. core = ie.IECore()
  275. if not Path(w).is_file(): # if not *.xml
  276. w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
  277. network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths
  278. executable_network = core.load_network(network, device_name='CPU', num_requests=1)
  279. elif engine: # TensorRT
  280. LOGGER.info(f'Loading {w} for TensorRT inference...')
  281. import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
  282. check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
  283. Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
  284. logger = trt.Logger(trt.Logger.INFO)
  285. with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
  286. model = runtime.deserialize_cuda_engine(f.read())
  287. bindings = OrderedDict()
  288. fp16 = False # default updated below
  289. for index in range(model.num_bindings):
  290. name = model.get_binding_name(index)
  291. dtype = trt.nptype(model.get_binding_dtype(index))
  292. shape = tuple(model.get_binding_shape(index))
  293. data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
  294. bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
  295. if model.binding_is_input(index) and dtype == np.float16:
  296. fp16 = True
  297. binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
  298. context = model.create_execution_context()
  299. batch_size = bindings['images'].shape[0]
  300. elif coreml: # CoreML
  301. LOGGER.info(f'Loading {w} for CoreML inference...')
  302. import coremltools as ct
  303. model = ct.models.MLModel(w)
  304. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  305. if saved_model: # SavedModel
  306. LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
  307. import tensorflow as tf
  308. keras = False # assume TF1 saved_model
  309. model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
  310. elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
  311. LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
  312. import tensorflow as tf
  313. def wrap_frozen_graph(gd, inputs, outputs):
  314. x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
  315. ge = x.graph.as_graph_element
  316. return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
  317. gd = tf.Graph().as_graph_def() # graph_def
  318. gd.ParseFromString(open(w, 'rb').read())
  319. frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
  320. elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
  321. try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
  322. from tflite_runtime.interpreter import Interpreter, load_delegate
  323. except ImportError:
  324. import tensorflow as tf
  325. Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
  326. if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
  327. LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
  328. delegate = {'Linux': 'libedgetpu.so.1',
  329. 'Darwin': 'libedgetpu.1.dylib',
  330. 'Windows': 'edgetpu.dll'}[platform.system()]
  331. interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
  332. else: # Lite
  333. LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
  334. interpreter = Interpreter(model_path=w) # load TFLite model
  335. interpreter.allocate_tensors() # allocate
  336. input_details = interpreter.get_input_details() # inputs
  337. output_details = interpreter.get_output_details() # outputs
  338. elif tfjs:
  339. raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
  340. self.__dict__.update(locals()) # assign all variables to self
  341. def forward(self, im, augment=False, visualize=False, val=False):
  342. # YOLOv5 MultiBackend inference
  343. b, ch, h, w = im.shape # batch, channel, height, width
  344. if self.pt or self.jit: # PyTorch
  345. y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
  346. return y if val else y[0]
  347. elif self.dnn: # ONNX OpenCV DNN
  348. im = im.cpu().numpy() # torch to numpy
  349. self.net.setInput(im)
  350. y = self.net.forward()
  351. elif self.onnx: # ONNX Runtime
  352. im = im.cpu().numpy() # torch to numpy
  353. y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
  354. elif self.xml: # OpenVINO
  355. im = im.cpu().numpy() # FP32
  356. desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW') # Tensor Description
  357. request = self.executable_network.requests[0] # inference request
  358. request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im)) # name=next(iter(request.input_blobs))
  359. request.infer()
  360. y = request.output_blobs['output'].buffer # name=next(iter(request.output_blobs))
  361. elif self.engine: # TensorRT
  362. assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
  363. self.binding_addrs['images'] = int(im.data_ptr())
  364. self.context.execute_v2(list(self.binding_addrs.values()))
  365. y = self.bindings['output'].data
  366. elif self.coreml: # CoreML
  367. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  368. im = Image.fromarray((im[0] * 255).astype('uint8'))
  369. # im = im.resize((192, 320), Image.ANTIALIAS)
  370. y = self.model.predict({'image': im}) # coordinates are xywh normalized
  371. if 'confidence' in y:
  372. box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
  373. conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
  374. y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
  375. else:
  376. k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1]) # output key
  377. y = y[k] # output
  378. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  379. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  380. if self.saved_model: # SavedModel
  381. y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
  382. elif self.pb: # GraphDef
  383. y = self.frozen_func(x=self.tf.constant(im)).numpy()
  384. else: # Lite or Edge TPU
  385. input, output = self.input_details[0], self.output_details[0]
  386. int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
  387. if int8:
  388. scale, zero_point = input['quantization']
  389. im = (im / scale + zero_point).astype(np.uint8) # de-scale
  390. self.interpreter.set_tensor(input['index'], im)
  391. self.interpreter.invoke()
  392. y = self.interpreter.get_tensor(output['index'])
  393. if int8:
  394. scale, zero_point = output['quantization']
  395. y = (y.astype(np.float32) - zero_point) * scale # re-scale
  396. y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
  397. if isinstance(y, np.ndarray):
  398. y = torch.tensor(y, device=self.device)
  399. return (y, []) if val else y
  400. def warmup(self, imgsz=(1, 3, 640, 640)):
  401. # Warmup model by running inference once
  402. if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)): # warmup types
  403. if self.device.type != 'cpu': # only warmup GPU models
  404. im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
  405. for _ in range(2 if self.jit else 1): #
  406. self.forward(im) # warmup
  407. @staticmethod
  408. def model_type(p='path/to/model.pt'):
  409. # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
  410. from export import export_formats
  411. suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
  412. check_suffix(p, suffixes) # checks
  413. p = Path(p).name # eliminate trailing separators
  414. pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
  415. xml |= xml2 # *_openvino_model or *.xml
  416. tflite &= not edgetpu # *.tflite
  417. return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
  418. class AutoShape(nn.Module):
  419. # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  420. conf = 0.25 # NMS confidence threshold
  421. iou = 0.45 # NMS IoU threshold
  422. agnostic = False # NMS class-agnostic
  423. multi_label = False # NMS multiple labels per box
  424. classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
  425. max_det = 1000 # maximum number of detections per image
  426. amp = False # Automatic Mixed Precision (AMP) inference
  427. def __init__(self, model):
  428. super().__init__()
  429. LOGGER.info('Adding AutoShape... ')
  430. copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
  431. self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
  432. self.pt = not self.dmb or model.pt # PyTorch model
  433. self.model = model.eval()
  434. def _apply(self, fn):
  435. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  436. self = super()._apply(fn)
  437. if self.pt:
  438. m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
  439. m.stride = fn(m.stride)
  440. m.grid = list(map(fn, m.grid))
  441. if isinstance(m.anchor_grid, list):
  442. m.anchor_grid = list(map(fn, m.anchor_grid))
  443. return self
  444. @torch.no_grad()
  445. def forward(self, imgs, size=640, augment=False, profile=False):
  446. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  447. # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
  448. # URI: = 'https://ultralytics.com/images/zidane.jpg'
  449. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  450. # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
  451. # numpy: = np.zeros((640,1280,3)) # HWC
  452. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  453. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  454. t = [time_sync()]
  455. p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
  456. autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
  457. if isinstance(imgs, torch.Tensor): # torch
  458. with amp.autocast(autocast):
  459. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  460. # Pre-process
  461. n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
  462. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  463. for i, im in enumerate(imgs):
  464. f = f'image{i}' # filename
  465. if isinstance(im, (str, Path)): # filename or uri
  466. im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
  467. im = np.asarray(exif_transpose(im))
  468. elif isinstance(im, Image.Image): # PIL Image
  469. im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
  470. files.append(Path(f).with_suffix('.jpg').name)
  471. if im.shape[0] < 5: # image in CHW
  472. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  473. im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
  474. s = im.shape[:2] # HWC
  475. shape0.append(s) # image shape
  476. g = (size / max(s)) # gain
  477. shape1.append([y * g for y in s])
  478. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  479. shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
  480. x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
  481. x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
  482. x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
  483. t.append(time_sync())
  484. with amp.autocast(autocast):
  485. # Inference
  486. y = self.model(x, augment, profile) # forward
  487. t.append(time_sync())
  488. # Post-process
  489. y = non_max_suppression(y if self.dmb else y[0], self.conf, self.iou, self.classes, self.agnostic,
  490. self.multi_label, max_det=self.max_det) # NMS
  491. for i in range(n):
  492. scale_coords(shape1, y[i][:, :4], shape0[i])
  493. t.append(time_sync())
  494. return Detections(imgs, y, files, t, self.names, x.shape)
class Detections:
    # YOLOv5 detections class for inference results.
    # Holds per-image predictions (as produced by AutoShape.forward) plus helpers to print, show, save,
    # crop, render, convert to pandas DataFrames, or split into per-image Detections objects.
    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
        """Store predictions and precompute box views.

        Args:
            imgs: list of images; im.shape[1] / im.shape[0] are used as width / height for normalization.
            pred: list of per-image tensors whose rows are (x1, y1, x2, y2, conf, cls).
            files: list of image filenames, one per image.
            times: four timestamps; consecutive differences become self.t.
            names: class names indexed by int(cls).
            shape: inference BCHW shape (stored as self.s).

        Raises IndexError if pred is empty (pred[0] is read for the device).
        """
        super().__init__()
        d = pred[0].device  # device
        # Per-image divisor [w, h, w, h, 1, 1]: scales box columns, leaves conf/cls unchanged.
        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.times = times  # profiling times
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        # Three stage durations (ms per image): pre-process, inference, NMS.
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
        """Shared worker for print/show/save/crop/render.

        For each image: builds a per-class count string and, if any of show/save/render/crop is set, draws
        boxes (or, with crop, collects crops instead of drawing). Then it optionally logs the summary string,
        shows the image, saves it under save_dir, and/or writes the annotated image back into self.imgs
        (render). Returns the list of crop dicts when crop is True, otherwise None.
        """
        crops = []
        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
            if pred.shape[0]:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render or crop:
                    annotator = Annotator(im, example=str(self.names))
                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        if crop:
                            # Crops are written to save_dir/crops/<class>/<file> only when save is True.
                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
                            # 'im' holds whatever save_one_box returns (presumably the cropped image).
                            crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
                                          'im': save_one_box(box, im, file=file, save=save)})
                        else:  # all others
                            annotator.box_label(box, label if labels else '', color=colors(cls))
                    im = annotator.im
            else:
                s += '(no detections)'

            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
            if pprint:
                LOGGER.info(s.rstrip(', '))
            if show:
                im.show(self.files[i])  # show
            if save:
                f = self.files[i]
                im.save(save_dir / f)  # save
                if i == self.n - 1:  # log once, after the last image
                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
            if render:
                self.imgs[i] = np.asarray(im)  # replace the stored image with the annotated one
        if crop:
            if save:
                LOGGER.info(f'Saved results to {save_dir}\n')
            return crops

    def print(self):
        """Log per-image detection summaries and per-image stage timings."""
        self.display(pprint=True)  # print results
        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
                    self.t)

    def show(self, labels=True):
        """Display annotated images."""
        self.display(show=True, labels=labels)  # show results

    def save(self, labels=True, save_dir='runs/detect/exp'):
        """Save annotated images to an incremented directory (reused as-is when a custom save_dir is given)."""
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
        self.display(save=True, labels=labels, save_dir=save_dir)  # save results

    def crop(self, save=True, save_dir='runs/detect/exp'):
        """Return a list of crop dicts; when save is True the crops are also written under save_dir/crops."""
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
        return self.display(crop=True, save=save, save_dir=save_dir)  # crop results

    def render(self, labels=True):
        """Draw boxes into self.imgs in place and return the updated list."""
        self.display(render=True, labels=labels)  # render results
        return self.imgs

    def pandas(self):
        """Return a shallow copy whose xyxy/xyxyn/xywh/xywhn are lists of pandas DataFrames.

        Example: print(results.pandas().xyxy[0]). Each row gets an int class and its name appended.
        """
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            # Outer x: per-image tensor; inner x (shadowing it): one row list of that tensor.
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        """Return one single-image Detections object per image, i.e. 'for result in results.tolist():'.

        NOTE(review): each child is built with the full `times` and n == 1, so its `t` reports the whole
        batch's stage durations rather than a per-image share.
        """
        r = range(self.n)  # iterable
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
        # for d in x:
        #    for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
        #        setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def __len__(self):
        """Number of images (batch size)."""
        return self.n
  584. class Classify(nn.Module):
  585. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  586. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  587. super().__init__()
  588. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  589. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  590. self.flat = nn.Flatten()
  591. def forward(self, x):
  592. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  593. return self.flat(self.conv(z)) # flatten to x(b,c2)