Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

451 lignes
20KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow, Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt
  7. Export:
  8. $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
  9. """
  10. import argparse
  11. import logging
  12. import sys
  13. from copy import deepcopy
  14. from pathlib import Path
  15. FILE = Path(__file__).resolve()
  16. ROOT = FILE.parents[1] # YOLOv5 root directory
  17. if str(ROOT) not in sys.path:
  18. sys.path.append(str(ROOT)) # add ROOT to PATH
  19. # ROOT = ROOT.relative_to(Path.cwd()) # relative
  20. import numpy as np
  21. import tensorflow as tf
  22. import torch
  23. import torch.nn as nn
  24. from tensorflow import keras
  25. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
  26. from models.experimental import CrossConv, MixConv2d, attempt_load
  27. from models.yolo import Detect
  28. from utils.general import make_divisible, print_args, set_logging
  29. from utils.activations import SiLU
  30. LOGGER = logging.getLogger(__name__)
  31. class TFBN(keras.layers.Layer):
  32. # TensorFlow BatchNormalization wrapper
  33. def __init__(self, w=None):
  34. super(TFBN, self).__init__()
  35. self.bn = keras.layers.BatchNormalization(
  36. beta_initializer=keras.initializers.Constant(w.bias.numpy()),
  37. gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
  38. moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
  39. moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
  40. epsilon=w.eps)
  41. def call(self, inputs):
  42. return self.bn(inputs)
  43. class TFPad(keras.layers.Layer):
  44. def __init__(self, pad):
  45. super(TFPad, self).__init__()
  46. self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
  47. def call(self, inputs):
  48. return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  49. class TFConv(keras.layers.Layer):
  50. # Standard convolution
  51. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  52. # ch_in, ch_out, weights, kernel, stride, padding, groups
  53. super(TFConv, self).__init__()
  54. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  55. assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
  56. # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
  57. # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
  58. conv = keras.layers.Conv2D(
  59. c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False if hasattr(w, 'bn') else True,
  60. kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
  61. bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
  62. self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
  63. self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
  64. # YOLOv5 activations
  65. if isinstance(w.act, nn.LeakyReLU):
  66. self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
  67. elif isinstance(w.act, nn.Hardswish):
  68. self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
  69. elif isinstance(w.act, (nn.SiLU, SiLU)):
  70. self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
  71. else:
  72. raise Exception(f'no matching TensorFlow activation found for {w.act}')
  73. def call(self, inputs):
  74. return self.act(self.bn(self.conv(inputs)))
  75. class TFFocus(keras.layers.Layer):
  76. # Focus wh information into c-space
  77. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  78. # ch_in, ch_out, kernel, stride, padding, groups
  79. super(TFFocus, self).__init__()
  80. self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
  81. def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
  82. # inputs = inputs / 255. # normalize 0-255 to 0-1
  83. return self.conv(tf.concat([inputs[:, ::2, ::2, :],
  84. inputs[:, 1::2, ::2, :],
  85. inputs[:, ::2, 1::2, :],
  86. inputs[:, 1::2, 1::2, :]], 3))
  87. class TFBottleneck(keras.layers.Layer):
  88. # Standard bottleneck
  89. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
  90. super(TFBottleneck, self).__init__()
  91. c_ = int(c2 * e) # hidden channels
  92. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  93. self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
  94. self.add = shortcut and c1 == c2
  95. def call(self, inputs):
  96. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  97. class TFConv2d(keras.layers.Layer):
  98. # Substitution for PyTorch nn.Conv2D
  99. def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
  100. super(TFConv2d, self).__init__()
  101. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  102. self.conv = keras.layers.Conv2D(
  103. c2, k, s, 'VALID', use_bias=bias,
  104. kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
  105. bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )
  106. def call(self, inputs):
  107. return self.conv(inputs)
  108. class TFBottleneckCSP(keras.layers.Layer):
  109. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  110. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  111. # ch_in, ch_out, number, shortcut, groups, expansion
  112. super(TFBottleneckCSP, self).__init__()
  113. c_ = int(c2 * e) # hidden channels
  114. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  115. self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
  116. self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
  117. self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
  118. self.bn = TFBN(w.bn)
  119. self.act = lambda x: keras.activations.relu(x, alpha=0.1)
  120. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  121. def call(self, inputs):
  122. y1 = self.cv3(self.m(self.cv1(inputs)))
  123. y2 = self.cv2(inputs)
  124. return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
  125. class TFC3(keras.layers.Layer):
  126. # CSP Bottleneck with 3 convolutions
  127. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  128. # ch_in, ch_out, number, shortcut, groups, expansion
  129. super(TFC3, self).__init__()
  130. c_ = int(c2 * e) # hidden channels
  131. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  132. self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
  133. self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
  134. self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  135. def call(self, inputs):
  136. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  137. class TFSPP(keras.layers.Layer):
  138. # Spatial pyramid pooling layer used in YOLOv3-SPP
  139. def __init__(self, c1, c2, k=(5, 9, 13), w=None):
  140. super(TFSPP, self).__init__()
  141. c_ = c1 // 2 # hidden channels
  142. self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
  143. self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
  144. self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
  145. def call(self, inputs):
  146. x = self.cv1(inputs)
  147. return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
  148. class TFDetect(keras.layers.Layer):
  149. def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
  150. super(TFDetect, self).__init__()
  151. self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
  152. self.nc = nc # number of classes
  153. self.no = nc + 5 # number of outputs per anchor
  154. self.nl = len(anchors) # number of detection layers
  155. self.na = len(anchors[0]) // 2 # number of anchors
  156. self.grid = [tf.zeros(1)] * self.nl # init grid
  157. self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
  158. self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]),
  159. [self.nl, 1, -1, 1, 2])
  160. self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
  161. self.training = False # set to False after building model
  162. self.imgsz = imgsz
  163. for i in range(self.nl):
  164. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  165. self.grid[i] = self._make_grid(nx, ny)
  166. def call(self, inputs):
  167. z = [] # inference output
  168. x = []
  169. for i in range(self.nl):
  170. x.append(self.m[i](inputs[i]))
  171. # x(bs,20,20,255) to x(bs,3,20,20,85)
  172. ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
  173. x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
  174. if not self.training: # inference
  175. y = tf.sigmoid(x[i])
  176. xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  177. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
  178. # Normalize xywh to 0-1 to reduce calibration error
  179. xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  180. wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
  181. y = tf.concat([xy, wh, y[..., 4:]], -1)
  182. z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))
  183. return x if self.training else (tf.concat(z, 1), x)
  184. @staticmethod
  185. def _make_grid(nx=20, ny=20):
  186. # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  187. # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  188. xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
  189. return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  190. class TFUpsample(keras.layers.Layer):
  191. def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
  192. super(TFUpsample, self).__init__()
  193. assert scale_factor == 2, "scale_factor must be 2"
  194. self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
  195. # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
  196. # with default arguments: align_corners=False, half_pixel_centers=False
  197. # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
  198. # size=(x.shape[1] * 2, x.shape[2] * 2))
  199. def call(self, inputs):
  200. return self.upsample(inputs)
  201. class TFConcat(keras.layers.Layer):
  202. def __init__(self, dimension=1, w=None):
  203. super(TFConcat, self).__init__()
  204. assert dimension == 1, "convert only NCHW to NHWC concat"
  205. self.d = 3
  206. def call(self, inputs):
  207. return tf.concat(inputs, self.d)
  208. def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
  209. LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  210. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  211. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  212. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  213. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  214. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  215. m_str = m
  216. m = eval(m) if isinstance(m, str) else m # eval strings
  217. for j, a in enumerate(args):
  218. try:
  219. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  220. except NameError:
  221. pass
  222. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  223. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  224. c1, c2 = ch[f], args[0]
  225. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  226. args = [c1, c2, *args[1:]]
  227. if m in [BottleneckCSP, C3]:
  228. args.insert(2, n)
  229. n = 1
  230. elif m is nn.BatchNorm2d:
  231. args = [ch[f]]
  232. elif m is Concat:
  233. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  234. elif m is Detect:
  235. args.append([ch[x + 1] for x in f])
  236. if isinstance(args[1], int): # number of anchors
  237. args[1] = [list(range(args[1] * 2))] * len(f)
  238. args.append(imgsz)
  239. else:
  240. c2 = ch[f]
  241. tf_m = eval('TF' + m_str.replace('nn.', ''))
  242. m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
  243. else tf_m(*args, w=model.model[i]) # module
  244. torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  245. t = str(m)[8:-2].replace('__main__.', '') # module type
  246. np = sum([x.numel() for x in torch_m_.parameters()]) # number params
  247. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  248. LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  249. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  250. layers.append(m_)
  251. ch.append(c2)
  252. return keras.Sequential(layers), sorted(save)
  253. class TFModel:
  254. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
  255. super(TFModel, self).__init__()
  256. if isinstance(cfg, dict):
  257. self.yaml = cfg # model dict
  258. else: # is *.yaml
  259. import yaml # for torch hub
  260. self.yaml_file = Path(cfg).name
  261. with open(cfg) as f:
  262. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  263. # Define model
  264. if nc and nc != self.yaml['nc']:
  265. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  266. self.yaml['nc'] = nc # override yaml value
  267. self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
  268. def predict(self, inputs, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
  269. conf_thres=0.25):
  270. y = [] # outputs
  271. x = inputs
  272. for i, m in enumerate(self.model.layers):
  273. if m.f != -1: # if not from previous layer
  274. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  275. x = m(x) # run
  276. y.append(x if m.i in self.savelist else None) # save output
  277. # Add TensorFlow NMS
  278. if tf_nms:
  279. boxes = self._xywh2xyxy(x[0][..., :4])
  280. probs = x[0][:, :, 4:5]
  281. classes = x[0][:, :, 5:]
  282. scores = probs * classes
  283. if agnostic_nms:
  284. nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
  285. return nms, x[1]
  286. else:
  287. boxes = tf.expand_dims(boxes, 2)
  288. nms = tf.image.combined_non_max_suppression(
  289. boxes, scores, topk_per_class, topk_all, iou_thres, conf_thres, clip_boxes=False)
  290. return nms, x[1]
  291. return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
  292. # x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
  293. # xywh = x[..., :4] # x(6300,4) boxes
  294. # conf = x[..., 4:5] # x(6300,1) confidences
  295. # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
  296. # return tf.concat([conf, cls, xywh], 1)
  297. @staticmethod
  298. def _xywh2xyxy(xywh):
  299. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  300. x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
  301. return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  302. class AgnosticNMS(keras.layers.Layer):
  303. # TF Agnostic NMS
  304. def call(self, input, topk_all, iou_thres, conf_thres):
  305. # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
  306. return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
  307. fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
  308. name='agnostic_nms')
  309. @staticmethod
  310. def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
  311. boxes, classes, scores = x
  312. class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  313. scores_inp = tf.reduce_max(scores, -1)
  314. selected_inds = tf.image.non_max_suppression(
  315. boxes, scores_inp, max_output_size=topk_all, iou_threshold=iou_thres, score_threshold=conf_thres)
  316. selected_boxes = tf.gather(boxes, selected_inds)
  317. padded_boxes = tf.pad(selected_boxes,
  318. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
  319. mode="CONSTANT", constant_values=0.0)
  320. selected_scores = tf.gather(scores_inp, selected_inds)
  321. padded_scores = tf.pad(selected_scores,
  322. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  323. mode="CONSTANT", constant_values=-1.0)
  324. selected_classes = tf.gather(class_inds, selected_inds)
  325. padded_classes = tf.pad(selected_classes,
  326. paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
  327. mode="CONSTANT", constant_values=-1.0)
  328. valid_detections = tf.shape(selected_inds)[0]
  329. return padded_boxes, padded_scores, padded_classes, valid_detections
  330. def representative_dataset_gen(dataset, ncalib=100):
  331. # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
  332. for n, (path, img, im0s, vid_cap) in enumerate(dataset):
  333. input = np.transpose(img, [1, 2, 0])
  334. input = np.expand_dims(input, axis=0).astype(np.float32)
  335. input /= 255.0
  336. yield [input]
  337. if n >= ncalib:
  338. break
  339. def run(weights=ROOT / 'yolov5s.pt', # weights path
  340. imgsz=(640, 640), # inference size h,w
  341. batch_size=1, # batch size
  342. dynamic=False, # dynamic batch size
  343. ):
  344. # PyTorch model
  345. im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
  346. model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
  347. y = model(im) # inference
  348. model.info()
  349. # TensorFlow model
  350. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
  351. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  352. y = tf_model.predict(im) # inference
  353. # Keras model
  354. im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  355. keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
  356. keras_model.summary()
  357. def parse_opt():
  358. parser = argparse.ArgumentParser()
  359. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  360. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
  361. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  362. parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
  363. opt = parser.parse_args()
  364. opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
  365. print_args(FILE.stem, opt)
  366. return opt
  367. def main(opt):
  368. set_logging()
  369. run(**vars(opt))
  370. if __name__ == "__main__":
  371. opt = parse_opt()
  372. main(opt)