您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

466 行
20KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import logging
  12. import sys
  13. from copy import deepcopy
  14. from pathlib import Path
  # Resolve this module's absolute path and make the repository root importable,
  # so that `models.*` and `utils.*` can be imported when the file is run as a script.
  FILE = Path(__file__).resolve()
  ROOT = FILE.parents[1]  # YOLOv5 root directory (two levels above this file)
  if not any(entry == str(ROOT) for entry in sys.path):
      sys.path.append(str(ROOT))  # add ROOT to PATH
  # ROOT = ROOT.relative_to(Path.cwd())  # relative
  20. import numpy as np
  21. import tensorflow as tf
  22. import torch
  23. import torch.nn as nn
  24. from tensorflow import keras
  25. from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
  26. from models.experimental import CrossConv, MixConv2d, attempt_load
  27. from models.yolo import Detect
  28. from utils.general import make_divisible, print_args, LOGGER
  29. from utils.activations import SiLU
  class TFBN(keras.layers.Layer):
      """TensorFlow BatchNormalization wrapper seeded from a source module `w`."""

      def __init__(self, w=None):
          """Build a Keras BatchNormalization whose state is copied from `w`.

          `w` is read for `.bias`, `.weight`, `.running_mean`, `.running_var` (each converted
          with `.numpy()`) and `.eps`. Despite the `None` default, `w` is required.
          NOTE(review): presumably `w` is a torch.nn.BatchNorm2d - confirm against callers.
          """
          super().__init__()
          const = keras.initializers.Constant
          self.bn = keras.layers.BatchNormalization(
              beta_initializer=const(w.bias.numpy()),
              gamma_initializer=const(w.weight.numpy()),
              moving_mean_initializer=const(w.running_mean.numpy()),
              moving_variance_initializer=const(w.running_var.numpy()),
              epsilon=w.eps,
          )

      def call(self, inputs):
          """Apply the wrapped batch-normalization layer to `inputs`."""
          normalized = self.bn(inputs)
          return normalized
  class TFPad(keras.layers.Layer):
      """Zero-pad the two middle axes of a rank-4 tensor by `pad` on each side."""

      def __init__(self, pad):
          """Store the padding spec: nothing on axes 0 and 3, `pad` before/after on axes 1 and 2."""
          super().__init__()
          untouched = [0, 0]
          both_sides = [pad, pad]
          self.pad = tf.constant([untouched, both_sides, both_sides, untouched])

      def call(self, inputs):
          """Return `inputs` padded with constant zeros according to `self.pad`."""
          return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  class TFConv(keras.layers.Layer):
      """Standard convolution: Conv2D -> optional BatchNorm -> activation, seeded from `w`."""

      def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
          """Build the layer from the source module `w`.

          Args (ch_in, ch_out, kernel, stride, padding, groups, activation flag, weights):
              c1: input channels (unused here; kept for signature parity with the source module).
              c2: number of output filters.
              k: integer kernel size.
              s: stride.
              p: explicit padding forwarded to `autopad` when `s != 1`.
              g: groups; only 1 is supported.
              act: when falsy the activation is the identity, regardless of `w.act`.
              w: source module; read for `.conv.weight`, `.conv.bias`, optional `.bn`, and `.act`.

          Raises:
              AssertionError: if `g != 1` or `k` is not an int.
              Exception: if `act` is truthy and `w.act` has no TensorFlow counterpart here.
          """
          super().__init__()
          assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
          assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
          # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
          # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
          has_bn = hasattr(w, 'bn')
          conv = keras.layers.Conv2D(
              c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=not has_bn,
              kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
              bias_initializer='zeros' if has_bn else keras.initializers.Constant(w.conv.bias.numpy()))
          # For strided convs, pad explicitly and run the conv with 'VALID' padding
          self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
          self.bn = TFBN(w.bn) if has_bn else tf.identity

          # YOLOv5 activations
          if not act:
              # Check the flag first: a module built without activation carries an activation
              # object that matches none of the branches below and must not raise.
              self.act = tf.identity
          elif isinstance(w.act, nn.LeakyReLU):
              self.act = lambda x: keras.activations.relu(x, alpha=0.1)
          elif isinstance(w.act, nn.Hardswish):
              self.act = lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
          elif isinstance(w.act, (nn.SiLU, SiLU)):
              self.act = lambda x: keras.activations.swish(x)
          else:
              raise Exception(f'no matching TensorFlow activation found for {w.act}')

      def call(self, inputs):
          """Apply convolution, then batch norm (or identity), then the activation."""
          return self.act(self.bn(self.conv(inputs)))
  class TFFocus(keras.layers.Layer):
      """Focus wh information into c-space, then convolve."""

      def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
          """Create the inner TFConv with 4x input channels, seeded from `w.conv`.

          Arguments are ch_in, ch_out, kernel, stride, padding, groups, activation flag, weights.
          """
          super().__init__()
          self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

      def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
          """Interleave-sample axes 1 and 2 with step 2, concatenate on axis 3 and convolve."""
          # inputs = inputs / 255  # normalize 0-255 to 0-1
          even_even = inputs[:, ::2, ::2, :]
          odd_even = inputs[:, 1::2, ::2, :]
          even_odd = inputs[:, ::2, 1::2, :]
          odd_odd = inputs[:, 1::2, 1::2, :]
          stacked = tf.concat([even_even, odd_even, even_odd, odd_odd], 3)
          return self.conv(stacked)
  class TFBottleneck(keras.layers.Layer):
      """Standard bottleneck: 1x1 conv, 3x3 conv, optional residual add."""

      def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):
          """Arguments are ch_in, ch_out, shortcut, groups, expansion, weights (`w.cv1`, `w.cv2`)."""
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(hidden, c2, 3, 1, g=g, w=w.cv2)
          # The residual connection is only used when requested and channel counts match
          self.add = shortcut and c1 == c2

      def call(self, inputs):
          """Return cv2(cv1(inputs)), plus `inputs` when the shortcut is active."""
          out = self.cv2(self.cv1(inputs))
          if self.add:
              return inputs + out
          return out
  class TFConv2d(keras.layers.Layer):
      """Substitution for PyTorch nn.Conv2D (plain convolution, 'VALID' padding, no BN/activation)."""

      def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
          """Build a Keras Conv2D from `w.weight` (permuted to axes 2, 3, 1, 0) and, if `bias`, `w.bias`.

          Raises:
              AssertionError: if `g != 1`.
          """
          super().__init__()
          assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
          kernel_init = keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy())
          if bias:
              bias_init = keras.initializers.Constant(w.bias.numpy())
          else:
              bias_init = None
          self.conv = keras.layers.Conv2D(
              c2, k, s, 'VALID', use_bias=bias,
              kernel_initializer=kernel_init,
              bias_initializer=bias_init)

      def call(self, inputs):
          """Apply the wrapped convolution."""
          return self.conv(inputs)
  class TFBottleneckCSP(keras.layers.Layer):
      """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks"""

      def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
          """Arguments are ch_in, ch_out, number of bottlenecks, shortcut, groups, expansion, weights.

          `w` is read for `.cv1` .. `.cv4`, `.bn` and the indexable `.m[0..n-1]`.
          """
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv2d(c1, hidden, 1, 1, bias=False, w=w.cv2)
          self.cv3 = TFConv2d(hidden, hidden, 1, 1, bias=False, w=w.cv3)
          self.cv4 = TFConv(2 * hidden, c2, 1, 1, w=w.cv4)
          self.bn = TFBN(w.bn)
          self.act = lambda x: keras.activations.relu(x, alpha=0.1)
          bottlenecks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]
          self.m = keras.Sequential(bottlenecks)

      def call(self, inputs):
          """Run both branches, concatenate on axis 3, then BN -> activation -> cv4."""
          main_branch = self.cv3(self.m(self.cv1(inputs)))
          side_branch = self.cv2(inputs)
          merged = tf.concat((main_branch, side_branch), axis=3)
          return self.cv4(self.act(self.bn(merged)))
  class TFC3(keras.layers.Layer):
      """CSP Bottleneck with 3 convolutions."""

      def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
          """Arguments are ch_in, ch_out, number of bottlenecks, shortcut, groups, expansion, weights.

          `w` is read for `.cv1`, `.cv2`, `.cv3` and the indexable `.m[0..n-1]`.
          """
          super().__init__()
          hidden = int(c2 * e)  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(c1, hidden, 1, 1, w=w.cv2)
          self.cv3 = TFConv(2 * hidden, c2, 1, 1, w=w.cv3)
          bottlenecks = [TFBottleneck(hidden, hidden, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]
          self.m = keras.Sequential(bottlenecks)

      def call(self, inputs):
          """Concatenate the bottleneck branch and the bypass branch on axis 3, then apply cv3."""
          deep = self.m(self.cv1(inputs))
          bypass = self.cv2(inputs)
          return self.cv3(tf.concat((deep, bypass), axis=3))
  class TFSPP(keras.layers.Layer):
      """Spatial pyramid pooling layer used in YOLOv3-SPP."""

      def __init__(self, c1, c2, k=(5, 9, 13), w=None):
          """Create cv1, one stride-1 'SAME' max-pool per kernel size in `k`, and cv2.

          `w` is read for `.cv1` and `.cv2`.
          """
          super().__init__()
          hidden = c1 // 2  # hidden channels
          branches = len(k) + 1  # the un-pooled tensor plus one pooled tensor per kernel
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(hidden * branches, c2, 1, 1, w=w.cv2)
          self.m = [keras.layers.MaxPool2D(pool_size=size, strides=1, padding='SAME') for size in k]

      def call(self, inputs):
          """Concatenate cv1's output with each pooled version of it on axis 3, then apply cv2."""
          x = self.cv1(inputs)
          pooled = [pool(x) for pool in self.m]
          return self.cv2(tf.concat([x] + pooled, 3))
  class TFSPPF(keras.layers.Layer):
      """Spatial pyramid pooling-Fast layer: one max-pool applied three times in sequence."""

      def __init__(self, c1, c2, k=5, w=None):
          """Create cv1, a single stride-1 'SAME' max-pool of size `k`, and cv2 (`w.cv1`, `w.cv2`)."""
          super().__init__()
          hidden = c1 // 2  # hidden channels
          self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
          self.cv2 = TFConv(hidden * 4, c2, 1, 1, w=w.cv2)
          self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

      def call(self, inputs):
          """Concatenate cv1's output with its 1x, 2x and 3x pooled versions on axis 3, then apply cv2."""
          stages = [self.cv1(inputs)]
          for _ in range(3):
              stages.append(self.m(stages[-1]))  # each stage pools the previous one
          return self.cv2(tf.concat(stages, 3))
  class TFDetect(keras.layers.Layer):
      """Detection head: per-level 1x1 convolutions plus box decoding at inference time."""

      def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
          """Build the head from the source detection module `w`.

          Args:
              nc: number of classes.
              anchors: per-level anchor lists; only their lengths are used here
                  (the anchor values themselves are taken from `w.anchors`).
              ch: input channel count for each detection level.
              imgsz: (height, width) the grids are precomputed for.
              w: source module; read for `.stride`, `.anchors` and the indexable `.m[i]`.
          """
          super().__init__()
          self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
          self.nc = nc  # number of classes
          self.no = nc + 5  # number of outputs per anchor
          self.nl = len(anchors)  # number of detection layers
          self.na = len(anchors[0]) // 2  # number of anchors
          self.grid = [tf.zeros(1)] * self.nl  # init grid
          self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
          # Anchors scaled by each level's stride, reshaped for broadcasting against predictions
          self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
                                        [self.nl, 1, -1, 1, 2])
          self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
          self.training = False  # set to False after building model
          self.imgsz = imgsz
          # The image size is fixed, so the per-level cell-offset grids are computed once here
          for i in range(self.nl):
              ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
              self.grid[i] = self._make_grid(nx, ny)

      def call(self, inputs):
          """Run the head on a sequence of per-level feature maps.

          Returns:
              The list of reshaped raw outputs when `self.training` is true; otherwise a tuple
              `(decoded, raw)` where `decoded` concatenates all levels along axis 1 with
              xywh divided by the image size.
          """
          z = []  # inference output
          x = []
          for i in range(self.nl):
              x.append(self.m[i](inputs[i]))
              # x(bs,20,20,255) to x(bs,3,20,20,85)
              ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
              x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])

              if not self.training:  # inference
                  y = tf.sigmoid(x[i])
                  xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                  wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
                  # Normalize xywh to 0-1 to reduce calibration error
                  xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                  wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                  y = tf.concat([xy, wh, y[..., 4:]], -1)
                  # Flatten anchors and cells together; use self.na rather than a literal 3 so
                  # heads with a different anchor count per level keep the batch dimension intact
                  z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

          return x if self.training else (tf.concat(z, 1), x)

      @staticmethod
      def _make_grid(nx=20, ny=20):
          """Return a float32 tensor of shape [1, 1, ny * nx, 2] holding each cell's (x, y) index."""
          # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
          # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
          xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
          return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  class TFUpsample(keras.layers.Layer):
      """2x upsampling of axes 1 and 2 via tf.image.resize."""

      def __init__(self, size, scale_factor, mode, w=None):  # warning: all arguments needed including 'w'
          """Store a resize callable that doubles axes 1 and 2 using interpolation `mode`.

          `size` and `w` are accepted for signature compatibility and are not used.

          Raises:
              AssertionError: if `scale_factor` is not 2.
          """
          super().__init__()
          assert scale_factor == 2, "scale_factor must be 2"
          # Alternatives that were considered and left disabled:
          #   keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
          #   tf.raw_ops.ResizeNearestNeighbor(images=x, size=(x.shape[1] * 2, x.shape[2] * 2))
          #   (the latter with default arguments: align_corners=False, half_pixel_centers=False)
          self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)

      def call(self, inputs):
          """Return `inputs` resized to twice its size along axes 1 and 2."""
          resized = self.upsample(inputs)
          return resized
  class TFConcat(keras.layers.Layer):
      """Concatenate a list of tensors along axis 3."""

      def __init__(self, dimension=1, w=None):
          """Accept only `dimension == 1` and translate it to axis 3 (NCHW -> NHWC); `w` is unused.

          Raises:
              AssertionError: if `dimension` is not 1.
          """
          super().__init__()
          assert dimension == 1, "convert only NCHW to NHWC concat"
          self.d = 3  # concatenation axis used in call()

      def call(self, inputs):
          """Concatenate `inputs` along `self.d`."""
          return tf.concat(inputs, self.d)
  def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
      """Build the Keras layer stack described by model dict `d`, seeded from `model`.

      Args:
          d: model dict with keys 'anchors', 'nc', 'depth_multiple', 'width_multiple',
              'backbone' and 'head'; each backbone/head row is (from, number, module, args).
          ch: list of channel counts; one entry is appended per parsed layer (mutated in place).
          model: source model; `model.model[i]` supplies the weights for layer `i`.
          imgsz: image size appended to the Detect layer's arguments.

      Returns:
          (keras.Sequential of the layers, sorted list of layer indices whose output must be kept).

      NOTE: the eval() calls below run directly in this function's scope so that string
      arguments such as 'nc' or 'anchors' resolve to the locals of the same name; do not
      rename those locals or move the eval() calls into a helper or comprehension.
      """
      LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
      anchors, nc = d['anchors'], d['nc']
      gd, gw = d['depth_multiple'], d['width_multiple']
      if isinstance(anchors, list):
          na = len(anchors[0]) // 2  # number of anchors
      else:
          na = anchors
      no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

      layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
      for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
          m_str = m
          if isinstance(m, str):
              m = eval(m)  # eval strings
          for j, a in enumerate(args):
              if not isinstance(a, str):
                  continue
              try:
                  args[j] = eval(a)  # eval strings
              except NameError:
                  pass  # leave plain string arguments untouched

          if n > 1:
              n = max(round(n * gd), 1)  # depth gain

          if m in (nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3):
              c1, c2 = ch[f], args[0]
              if c2 != no:
                  c2 = make_divisible(c2 * gw, 8)  # width gain
              args = [c1, c2, *args[1:]]
              if m in (BottleneckCSP, C3):
                  # These modules take the repeat count as an argument instead of being repeated
                  args.insert(2, n)
                  n = 1
          elif m is nn.BatchNorm2d:
              args = [ch[f]]
          elif m is Concat:
              c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
          elif m is Detect:
              args.append([ch[x + 1] for x in f])
              if isinstance(args[1], int):  # number of anchors
                  args[1] = [list(range(args[1] * 2))] * len(f)
              args.append(imgsz)
          else:
              c2 = ch[f]

          tf_m = eval('TF' + m_str.replace('nn.', ''))
          if n > 1:
              m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)])  # module
              torch_m_ = nn.Sequential(*(m(*args) for _ in range(n)))  # module
          else:
              m_ = tf_m(*args, w=model.model[i])  # module
              torch_m_ = m(*args)  # module
          t = str(m)[8:-2].replace('__main__.', '')  # module type
          n_params = sum(x.numel() for x in torch_m_.parameters())  # number params
          # attach index, 'from' index, type, number params
          m_.i = i
          m_.f = f
          m_.type = t
          m_.np = n_params
          LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{n_params:>10} {t:<40}{str(args):<30}')  # print
          sources = [f] if isinstance(f, int) else f
          save.extend(x % i for x in sources if x != -1)  # append to savelist
          layers.append(m_)
          ch.append(c2)
      return keras.Sequential(layers), sorted(save)
  class TFModel:
      """TensorFlow counterpart of a YOLOv5 model, built from a config dict or YAML file."""

      def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)):  # model, channels, classes
          """Load the model description and build the Keras layers.

          Args:
              cfg: model dict, or path to a *.yaml file describing the model.
              ch: number of input channels.
              nc: optional class count; overrides the value in the config when it differs.
              model: source model supplying the weights (forwarded to `parse_model`).
              imgsz: image size forwarded to `parse_model`.
          """
          super().__init__()
          if isinstance(cfg, dict):
              self.yaml = cfg  # model dict
          else:  # is *.yaml
              import yaml  # for torch hub
              self.yaml_file = Path(cfg).name
              # NOTE(review): FullLoader can construct arbitrary Python objects for some tags;
              # prefer yaml.safe_load if `cfg` may come from an untrusted source.
              with open(cfg) as f:
                  self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

          # Define model
          override_nc = nc and nc != self.yaml['nc']
          if override_nc:
              LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
              self.yaml['nc'] = nc  # override yaml value
          # parse_model mutates its dict argument, so hand it a private copy
          self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

      def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
                  conf_thres=0.25):
          """Run the layers on `inputs`, optionally followed by TensorFlow NMS.

          Returns:
              Without `tf_nms`: the first element of the final layer's output
              ([1,6300,85] = [xywh, conf, class0, class1, ...] per the original note).
              With `tf_nms`: a tuple `(nms, raw)` where `nms` comes from AgnosticNMS when
              `agnostic_nms` is set, otherwise from tf.image.combined_non_max_suppression.
          """
          saved = []  # per-layer outputs; None for layers nobody reads from
          x = inputs
          for layer in self.model.layers:
              if layer.f != -1:  # if not from previous layer
                  if isinstance(layer.f, int):
                      x = saved[layer.f]  # from one earlier layer
                  else:
                      x = [x if j == -1 else saved[j] for j in layer.f]  # from several layers
              x = layer(x)  # run
              saved.append(x if layer.i in self.savelist else None)  # save output

          if not tf_nms:
              return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]

          # Add TensorFlow NMS
          boxes = self._xywh2xyxy(x[0][..., :4])
          probs = x[0][:, :, 4:5]
          classes = x[0][:, :, 5:]
          scores = probs * classes
          if agnostic_nms:
              nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
          else:
              boxes = tf.expand_dims(boxes, 2)
              nms = tf.image.combined_non_max_suppression(
                  boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
          return nms, x[1]

      @staticmethod
      def _xywh2xyxy(xywh):
          """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2] (xy1=top-left, xy2=bottom-right)."""
          x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
          half_w = w / 2
          half_h = h / 2
          return tf.concat([x - half_w, y - half_h, x + half_w, y + half_h], axis=-1)
  class AgnosticNMS(keras.layers.Layer):
      """TF class-agnostic NMS applied independently to each batch element."""

      def call(self, input, topk_all, iou_thres, conf_thres):
          """Map `_nms` over the leading axis of the `(boxes, classes, scores)` tuple."""
          # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450

          def per_image(sample):
              return self._nms(sample, topk_all, iou_thres, conf_thres)

          return tf.map_fn(per_image, input,
                           fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
                           name='agnostic_nms')

      @staticmethod
      def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
          """Run NMS on one image and pad every output to `topk_all` rows.

          Returns:
              (padded_boxes, padded_scores, padded_classes, valid_detections); boxes are padded
              with 0.0, scores and classes with -1.0, and `valid_detections` is the number of
              rows actually selected.
          """
          boxes, classes, scores = x
          class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
          scores_inp = tf.reduce_max(scores, -1)  # best class score per box
          selected_inds = tf.image.non_max_suppression(
              boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)

          selected_boxes = tf.gather(boxes, selected_inds)
          selected_scores = tf.gather(scores_inp, selected_inds)
          selected_classes = tf.gather(class_inds, selected_inds)

          n_pad = topk_all - tf.shape(selected_boxes)[0]  # rows needed to reach topk_all
          padded_boxes = tf.pad(selected_boxes,
                                paddings=[[0, n_pad], [0, 0]],
                                mode="CONSTANT", constant_values=0.0)
          padded_scores = tf.pad(selected_scores,
                                 paddings=[[0, n_pad]],
                                 mode="CONSTANT", constant_values=-1.0)
          padded_classes = tf.pad(selected_classes,
                                  paddings=[[0, n_pad]],
                                  mode="CONSTANT", constant_values=-1.0)
          valid_detections = tf.shape(selected_inds)[0]
          return padded_boxes, padded_scores, padded_classes, valid_detections
  def representative_dataset_gen(dataset, ncalib=100):
      """Yield at most `ncalib` calibration samples for converter.representative_dataset.

      Args:
          dataset: iterable of 5-tuples `(path, img, im0s, vid_cap, string)`; only `img` is used.
              NOTE(review): `img` is assumed to be a 3-D channel-first array with values in
              0-255 - confirm against the dataset loader.
          ncalib: maximum number of samples to yield; nothing is yielded when it is <= 0.

      Yields:
          A one-element list holding a float32 array: `img` with axes reordered to [1, 2, 0],
          a leading axis of size 1 added, and values divided by 255.
      """
      if ncalib <= 0:
          return
      # Count from 1 so the loop stops after exactly `ncalib` samples (not ncalib + 1) and
      # without pulling an extra item from the dataset.
      for n, (path, img, im0s, vid_cap, string) in enumerate(dataset, start=1):
          im = np.transpose(img, [1, 2, 0])
          im = np.expand_dims(im, axis=0).astype(np.float32)
          im /= 255
          yield [im]
          if n >= ncalib:
              break
  def run(weights=ROOT / 'yolov5s.pt',  # weights path
          imgsz=(640, 640),  # inference size h,w
          batch_size=1,  # batch size
          dynamic=False,  # dynamic batch size
          ):
      """Load the PyTorch weights, rebuild the model in TensorFlow and Keras, and run each once."""
      # PyTorch model
      torch_im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
      model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
      _ = model(torch_im)  # inference
      model.info()

      # TensorFlow model
      tf_im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
      tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
      _ = tf_model.predict(tf_im)  # inference

      # Keras model
      keras_batch = None if dynamic else batch_size  # None leaves the batch dimension dynamic
      keras_in = keras.Input(shape=(*imgsz, 3), batch_size=keras_batch)
      keras_model = keras.Model(inputs=keras_in, outputs=tf_model.predict(keras_in))
      keras_model.summary()

      LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
  def parse_opt():
      """Parse command-line options, expand a single image size to (h, w), print and return them."""
      parser = argparse.ArgumentParser()
      add = parser.add_argument
      add('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
      add('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
      add('--batch-size', type=int, default=1, help='batch size')
      add('--dynamic', action='store_true', help='dynamic batch size')
      opt = parser.parse_args()
      if len(opt.imgsz) == 1:
          opt.imgsz *= 2  # expand a single value to [size, size]
      print_args(FILE.stem, opt)
      return opt
  def main(opt):
      """Call `run` with the attributes of the parsed options namespace as keyword arguments."""
      kwargs = vars(opt)
      run(**kwargs)
  if __name__ == "__main__":
      # Script entry point: parse the CLI options, then hand them to main()
      opt = parse_opt()
      main(opt)