您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

575 行
25KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import sys
  12. from copy import deepcopy
  13. from pathlib import Path
  14. FILE = Path(__file__).resolve()
  15. ROOT = FILE.parents[1] # YOLOv5 root directory
  16. if str(ROOT) not in sys.path:
  17. sys.path.append(str(ROOT)) # add ROOT to PATH
  18. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  19. import numpy as np
  20. import tensorflow as tf
  21. import torch
  22. import torch.nn as nn
  23. from tensorflow import keras
  24. from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
  25. DWConvTranspose2d, Focus, autopad)
  26. from models.experimental import MixConv2d, attempt_load
  27. from models.yolo import Detect
  28. from utils.activations import SiLU
  29. from utils.general import LOGGER, make_divisible, print_args
  30. class TFBN(keras.layers.Layer):
  31. # TensorFlow BatchNormalization wrapper
  32. def __init__(self, w=None):
  33. super().__init__()
  34. self.bn = keras.layers.BatchNormalization(
  35. beta_initializer=keras.initializers.Constant(w.bias.numpy()),
  36. gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
  37. moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
  38. moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
  39. epsilon=w.eps)
  40. def call(self, inputs):
  41. return self.bn(inputs)
  42. class TFPad(keras.layers.Layer):
  43. # Pad inputs in spatial dimensions 1 and 2
  44. def __init__(self, pad):
  45. super().__init__()
  46. if isinstance(pad, int):
  47. self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
  48. else: # tuple/list
  49. self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])
  50. def call(self, inputs):
  51. return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  52. class TFConv(keras.layers.Layer):
  53. # Standard convolution
  54. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  55. # ch_in, ch_out, weights, kernel, stride, padding, groups
  56. super().__init__()
  57. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  58. # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
  59. # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
  60. conv = keras.layers.Conv2D(
  61. filters=c2,
  62. kernel_size=k,
  63. strides=s,
  64. padding='SAME' if s == 1 else 'VALID',
  65. use_bias=not hasattr(w, 'bn'),
  66. kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
  67. bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
  68. self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
  69. self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
  70. self.act = activations(w.act) if act else tf.identity
  71. def call(self, inputs):
  72. return self.act(self.bn(self.conv(inputs)))
  73. class TFDWConv(keras.layers.Layer):
  74. # Depthwise convolution
  75. def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
  76. # ch_in, ch_out, weights, kernel, stride, padding, groups
  77. super().__init__()
  78. assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
  79. conv = keras.layers.DepthwiseConv2D(
  80. kernel_size=k,
  81. depth_multiplier=c2 // c1,
  82. strides=s,
  83. padding='SAME' if s == 1 else 'VALID',
  84. use_bias=not hasattr(w, 'bn'),
  85. depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
  86. bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
  87. self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
  88. self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
  89. self.act = activations(w.act) if act else tf.identity
  90. def call(self, inputs):
  91. return self.act(self.bn(self.conv(inputs)))
  92. class TFDWConvTranspose2d(keras.layers.Layer):
  93. # Depthwise ConvTranspose2d
  94. def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
  95. # ch_in, ch_out, weights, kernel, stride, padding, groups
  96. super().__init__()
  97. assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
  98. assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
  99. weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
  100. self.c1 = c1
  101. self.conv = [
  102. keras.layers.Conv2DTranspose(filters=1,
  103. kernel_size=k,
  104. strides=s,
  105. padding='VALID',
  106. output_padding=p2,
  107. use_bias=True,
  108. kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
  109. bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]
  110. def call(self, inputs):
  111. return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
  112. class TFFocus(keras.layers.Layer):
  113. # Focus wh information into c-space
  114. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  115. # ch_in, ch_out, kernel, stride, padding, groups
  116. super().__init__()
  117. self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
  118. def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
  119. # inputs = inputs / 255 # normalize 0-255 to 0-1
  120. inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
  121. return self.conv(tf.concat(inputs, 3))
  122. class TFBottleneck(keras.layers.Layer):
  123. # Standard bottleneck
  124. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
  125. super().__init__()
  126. c_ = int(c2 * e) # hidden channels
  127. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  128. self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
  129. self.add = shortcut and c1 == c2
  130. def call(self, inputs):
  131. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  132. class TFCrossConv(keras.layers.Layer):
  133. # Cross Convolution
  134. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
  135. super().__init__()
  136. c_ = int(c2 * e) # hidden channels
  137. self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
  138. self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
  139. self.add = shortcut and c1 == c2
  140. def call(self, inputs):
  141. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  142. class TFConv2d(keras.layers.Layer):
  143. # Substitution for PyTorch nn.Conv2D
  144. def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
  145. super().__init__()
  146. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  147. self.conv = keras.layers.Conv2D(filters=c2,
  148. kernel_size=k,
  149. strides=s,
  150. padding='VALID',
  151. use_bias=bias,
  152. kernel_initializer=keras.initializers.Constant(
  153. w.weight.permute(2, 3, 1, 0).numpy()),
  154. bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
  155. def call(self, inputs):
  156. return self.conv(inputs)
  157. class TFBottleneckCSP(keras.layers.Layer):
  158. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  159. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  160. # ch_in, ch_out, number, shortcut, groups, expansion
  161. super().__init__()
  162. c_ = int(c2 * e) # hidden channels
  163. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  164. self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
  165. self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
  166. self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
  167. self.bn = TFBN(w.bn)
  168. self.act = lambda x: keras.activations.swish(x)
  169. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  170. def call(self, inputs):
  171. y1 = self.cv3(self.m(self.cv1(inputs)))
  172. y2 = self.cv2(inputs)
  173. return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
  174. class TFC3(keras.layers.Layer):
  175. # CSP Bottleneck with 3 convolutions
  176. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  177. # ch_in, ch_out, number, shortcut, groups, expansion
  178. super().__init__()
  179. c_ = int(c2 * e) # hidden channels
  180. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  181. self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
  182. self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
  183. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  184. def call(self, inputs):
  185. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  186. class TFC3x(keras.layers.Layer):
  187. # 3 module with cross-convolutions
  188. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  189. # ch_in, ch_out, number, shortcut, groups, expansion
  190. super().__init__()
  191. c_ = int(c2 * e) # hidden channels
  192. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  193. self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
  194. self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
  195. self.m = keras.Sequential([
  196. TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])
  197. def call(self, inputs):
  198. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  199. class TFSPP(keras.layers.Layer):
  200. # Spatial pyramid pooling layer used in YOLOv3-SPP
  201. def __init__(self, c1, c2, k=(5, 9, 13), w=None):
  202. super().__init__()
  203. c_ = c1 // 2 # hidden channels
  204. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  205. self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
  206. self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
  207. def call(self, inputs):
  208. x = self.cv1(inputs)
  209. return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
  210. class TFSPPF(keras.layers.Layer):
  211. # Spatial pyramid pooling-Fast layer
  212. def __init__(self, c1, c2, k=5, w=None):
  213. super().__init__()
  214. c_ = c1 // 2 # hidden channels
  215. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  216. self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
  217. self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
  218. def call(self, inputs):
  219. x = self.cv1(inputs)
  220. y1 = self.m(x)
  221. y2 = self.m(y1)
  222. return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
  223. class TFDetect(keras.layers.Layer):
  224. # TF YOLOv5 Detect layer
  225. def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
  226. super().__init__()
  227. self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
  228. self.nc = nc # number of classes
  229. self.no = nc + 5 # number of outputs per anchor
  230. self.nl = len(anchors) # number of detection layers
  231. self.na = len(anchors[0]) // 2 # number of anchors
  232. self.grid = [tf.zeros(1)] * self.nl # init grid
  233. self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
  234. self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
  235. self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
  236. self.training = False # set to False after building model
  237. self.imgsz = imgsz
  238. for i in range(self.nl):
  239. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  240. self.grid[i] = self._make_grid(nx, ny)
  241. def call(self, inputs):
  242. z = [] # inference output
  243. x = []
  244. for i in range(self.nl):
  245. x.append(self.m[i](inputs[i]))
  246. # x(bs,20,20,255) to x(bs,3,20,20,85)
  247. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  248. x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
  249. if not self.training: # inference
  250. y = tf.sigmoid(x[i])
  251. grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
  252. anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
  253. xy = (y[..., 0:2] * 2 + grid) * self.stride[i] # xy
  254. wh = y[..., 2:4] ** 2 * anchor_grid
  255. # Normalize xywh to 0-1 to reduce calibration error
  256. xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  257. wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  258. y = tf.concat([xy, wh, y[..., 4:]], -1)
  259. z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
  260. return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x)
  261. @staticmethod
  262. def _make_grid(nx=20, ny=20):
  263. # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  264. # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  265. xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
  266. return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  267. class TFUpsample(keras.layers.Layer):
  268. # TF version of torch.nn.Upsample()
  269. def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
  270. super().__init__()
  271. assert scale_factor == 2, "scale_factor must be 2"
  272. self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
  273. # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
  274. # with default arguments: align_corners=False, half_pixel_centers=False
  275. # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
  276. # size=(x.shape[1] * 2, x.shape[2] * 2))
  277. def call(self, inputs):
  278. return self.upsample(inputs)
  279. class TFConcat(keras.layers.Layer):
  280. # TF version of torch.concat()
  281. def __init__(self, dimension=1, w=None):
  282. super().__init__()
  283. assert dimension == 1, "convert only NCHW to NHWC concat"
  284. self.d = 3
  285. def call(self, inputs):
  286. return tf.concat(inputs, self.d)
  287. def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
  288. LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
  289. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  290. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  291. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  292. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  293. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  294. m_str = m
  295. m = eval(m) if isinstance(m, str) else m # eval strings
  296. for j, a in enumerate(args):
  297. try:
  298. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  299. except NameError:
  300. pass
  301. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  302. if m in [
  303. nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
  304. BottleneckCSP, C3, C3x]:
  305. c1, c2 = ch[f], args[0]
  306. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  307. args = [c1, c2, *args[1:]]
  308. if m in [BottleneckCSP, C3, C3x]:
  309. args.insert(2, n)
  310. n = 1
  311. elif m is nn.BatchNorm2d:
  312. args = [ch[f]]
  313. elif m is Concat:
  314. c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
  315. elif m is Detect:
  316. args.append([ch[x + 1] for x in f])
  317. if isinstance(args[1], int): # number of anchors
  318. args[1] = [list(range(args[1] * 2))] * len(f)
  319. args.append(imgsz)
  320. else:
  321. c2 = ch[f]
  322. tf_m = eval('TF' + m_str.replace('nn.', ''))
  323. m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
  324. else tf_m(*args, w=model.model[i]) # module
  325. torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
  326. t = str(m)[8:-2].replace('__main__.', '') # module type
  327. np = sum(x.numel() for x in torch_m_.parameters()) # number params
  328. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  329. LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
  330. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  331. layers.append(m_)
  332. ch.append(c2)
  333. return keras.Sequential(layers), sorted(save)
  334. class TFModel:
  335. # TF YOLOv5 model
  336. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
  337. super().__init__()
  338. if isinstance(cfg, dict):
  339. self.yaml = cfg # model dict
  340. else: # is *.yaml
  341. import yaml # for torch hub
  342. self.yaml_file = Path(cfg).name
  343. with open(cfg) as f:
  344. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  345. # Define model
  346. if nc and nc != self.yaml['nc']:
  347. LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
  348. self.yaml['nc'] = nc # override yaml value
  349. self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
  350. def predict(self,
  351. inputs,
  352. tf_nms=False,
  353. agnostic_nms=False,
  354. topk_per_class=100,
  355. topk_all=100,
  356. iou_thres=0.45,
  357. conf_thres=0.25):
  358. y = [] # outputs
  359. x = inputs
  360. for m in self.model.layers:
  361. if m.f != -1: # if not from previous layer
  362. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  363. x = m(x) # run
  364. y.append(x if m.i in self.savelist else None) # save output
  365. # Add TensorFlow NMS
  366. if tf_nms:
  367. boxes = self._xywh2xyxy(x[0][..., :4])
  368. probs = x[0][:, :, 4:5]
  369. classes = x[0][:, :, 5:]
  370. scores = probs * classes
  371. if agnostic_nms:
  372. nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
  373. else:
  374. boxes = tf.expand_dims(boxes, 2)
  375. nms = tf.image.combined_non_max_suppression(boxes,
  376. scores,
  377. topk_per_class,
  378. topk_all,
  379. iou_thres,
  380. conf_thres,
  381. clip_boxes=False)
  382. return nms, x[1]
  383. return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
  384. # x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
  385. # xywh = x[..., :4] # x(6300,4) boxes
  386. # conf = x[..., 4:5] # x(6300,1) confidences
  387. # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
  388. # return tf.concat([conf, cls, xywh], 1)
  389. @staticmethod
  390. def _xywh2xyxy(xywh):
  391. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  392. x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
  393. return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  394. class AgnosticNMS(keras.layers.Layer):
  395. # TF Agnostic NMS
  396. def call(self, input, topk_all, iou_thres, conf_thres):
  397. # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
  398. return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
  399. input,
  400. fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
  401. name='agnostic_nms')
  402. @staticmethod
  403. def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
  404. boxes, classes, scores = x
  405. class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  406. scores_inp = tf.reduce_max(scores, -1)
  407. selected_inds = tf.image.non_max_suppression(boxes,
  408. scores_inp,
  409. max_output_size=topk_all,
  410. iou_threshold=iou_thres,
  411. score_threshold=conf_thres)
  412. selected_boxes = tf.gather(boxes, selected_inds)
  413. padded_boxes = tf.pad(selected_boxes,
  414. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
  415. mode="CONSTANT",
  416. constant_values=0.0)
  417. selected_scores = tf.gather(scores_inp, selected_inds)
  418. padded_scores = tf.pad(selected_scores,
  419. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  420. mode="CONSTANT",
  421. constant_values=-1.0)
  422. selected_classes = tf.gather(class_inds, selected_inds)
  423. padded_classes = tf.pad(selected_classes,
  424. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  425. mode="CONSTANT",
  426. constant_values=-1.0)
  427. valid_detections = tf.shape(selected_inds)[0]
  428. return padded_boxes, padded_scores, padded_classes, valid_detections
  429. def activations(act=nn.SiLU):
  430. # Returns TF activation from input PyTorch activation
  431. if isinstance(act, nn.LeakyReLU):
  432. return lambda x: keras.activations.relu(x, alpha=0.1)
  433. elif isinstance(act, nn.Hardswish):
  434. return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
  435. elif isinstance(act, (nn.SiLU, SiLU)):
  436. return lambda x: keras.activations.swish(x)
  437. else:
  438. raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
  439. def representative_dataset_gen(dataset, ncalib=100):
  440. # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
  441. for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
  442. im = np.transpose(img, [1, 2, 0])
  443. im = np.expand_dims(im, axis=0).astype(np.float32)
  444. im /= 255
  445. yield [im]
  446. if n >= ncalib:
  447. break
  448. def run(
  449. weights=ROOT / 'yolov5s.pt', # weights path
  450. imgsz=(640, 640), # inference size h,w
  451. batch_size=1, # batch size
  452. dynamic=False, # dynamic batch size
  453. ):
  454. # PyTorch model
  455. im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
  456. model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
  457. _ = model(im) # inference
  458. model.info()
  459. # TensorFlow model
  460. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
  461. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  462. _ = tf_model.predict(im) # inference
  463. # Keras model
  464. im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  465. keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
  466. keras_model.summary()
  467. LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
  468. def parse_opt():
  469. parser = argparse.ArgumentParser()
  470. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  471. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
  472. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  473. parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
  474. opt = parser.parse_args()
  475. opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
  476. print_args(vars(opt))
  477. return opt
  478. def main(opt):
  479. run(**vars(opt))
  480. if __name__ == "__main__":
  481. opt = parse_opt()
  482. main(opt)