You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

559 lines
26KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. TensorFlow/Keras and TFLite versions of YOLOv5
  4. Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
  5. Usage:
  6. $ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml
  7. Export int8 TFLite models:
  8. $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \
  9. --source path/to/images/ --ncalib 100
  10. Detection:
  11. $ python detect.py --weights yolov5s.pb --img 320
  12. $ python detect.py --weights yolov5s_saved_model --img 320
  13. $ python detect.py --weights yolov5s-fp16.tflite --img 320
  14. $ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8
  15. For TensorFlow.js:
  16. $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
  17. $ pip install tensorflowjs
  18. $ tensorflowjs_converter \
  19. --input_format=tf_frozen_model \
  20. --output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
  21. yolov5s.pb \
  22. web_model
  23. $ # Edit web_model/model.json to sort Identity* in ascending order
  24. $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
  25. $ npm install
  26. $ ln -s ../../yolov5/web_model public/web_model
  27. $ npm start
  28. """
  29. import argparse
  30. import logging
  31. import os
  32. import sys
  33. import traceback
  34. from copy import deepcopy
  35. from pathlib import Path
  36. sys.path.append('./') # to run '$ python *.py' files in subdirectories
  37. import numpy as np
  38. import tensorflow as tf
  39. import torch
  40. import torch.nn as nn
  41. import yaml
  42. from tensorflow import keras
  43. from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
  44. from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
  45. from models.experimental import MixConv2d, CrossConv, attempt_load
  46. from models.yolo import Detect
  47. from utils.datasets import LoadImages
  48. from utils.general import make_divisible, check_file, check_dataset
  49. logger = logging.getLogger(__name__)
  50. class tf_BN(keras.layers.Layer):
  51. # TensorFlow BatchNormalization wrapper
  52. def __init__(self, w=None):
  53. super(tf_BN, self).__init__()
  54. self.bn = keras.layers.BatchNormalization(
  55. beta_initializer=keras.initializers.Constant(w.bias.numpy()),
  56. gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
  57. moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
  58. moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
  59. epsilon=w.eps)
  60. def call(self, inputs):
  61. return self.bn(inputs)
  62. class tf_Pad(keras.layers.Layer):
  63. def __init__(self, pad):
  64. super(tf_Pad, self).__init__()
  65. self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
  66. def call(self, inputs):
  67. return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
  68. class tf_Conv(keras.layers.Layer):
  69. # Standard convolution
  70. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  71. # ch_in, ch_out, weights, kernel, stride, padding, groups
  72. super(tf_Conv, self).__init__()
  73. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  74. assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
  75. # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
  76. # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
  77. conv = keras.layers.Conv2D(
  78. c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
  79. kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
  80. self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv])
  81. self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity
  82. # YOLOv5 activations
  83. if isinstance(w.act, nn.LeakyReLU):
  84. self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
  85. elif isinstance(w.act, nn.Hardswish):
  86. self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
  87. elif isinstance(w.act, nn.SiLU):
  88. self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
  89. def call(self, inputs):
  90. return self.act(self.bn(self.conv(inputs)))
  91. class tf_Focus(keras.layers.Layer):
  92. # Focus wh information into c-space
  93. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
  94. # ch_in, ch_out, kernel, stride, padding, groups
  95. super(tf_Focus, self).__init__()
  96. self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv)
  97. def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
  98. # inputs = inputs / 255. # normalize 0-255 to 0-1
  99. return self.conv(tf.concat([inputs[:, ::2, ::2, :],
  100. inputs[:, 1::2, ::2, :],
  101. inputs[:, ::2, 1::2, :],
  102. inputs[:, 1::2, 1::2, :]], 3))
  103. class tf_Bottleneck(keras.layers.Layer):
  104. # Standard bottleneck
  105. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
  106. super(tf_Bottleneck, self).__init__()
  107. c_ = int(c2 * e) # hidden channels
  108. self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
  109. self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2)
  110. self.add = shortcut and c1 == c2
  111. def call(self, inputs):
  112. return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
  113. class tf_Conv2d(keras.layers.Layer):
  114. # Substitution for PyTorch nn.Conv2D
  115. def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
  116. super(tf_Conv2d, self).__init__()
  117. assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
  118. self.conv = keras.layers.Conv2D(
  119. c2, k, s, 'VALID', use_bias=bias,
  120. kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
  121. bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )
  122. def call(self, inputs):
  123. return self.conv(inputs)
  124. class tf_BottleneckCSP(keras.layers.Layer):
  125. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  126. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  127. # ch_in, ch_out, number, shortcut, groups, expansion
  128. super(tf_BottleneckCSP, self).__init__()
  129. c_ = int(c2 * e) # hidden channels
  130. self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
  131. self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
  132. self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
  133. self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4)
  134. self.bn = tf_BN(w.bn)
  135. self.act = lambda x: keras.activations.relu(x, alpha=0.1)
  136. self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  137. def call(self, inputs):
  138. y1 = self.cv3(self.m(self.cv1(inputs)))
  139. y2 = self.cv2(inputs)
  140. return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
  141. class tf_C3(keras.layers.Layer):
  142. # CSP Bottleneck with 3 convolutions
  143. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
  144. # ch_in, ch_out, number, shortcut, groups, expansion
  145. super(tf_C3, self).__init__()
  146. c_ = int(c2 * e) # hidden channels
  147. self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
  148. self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2)
  149. self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3)
  150. self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
  151. def call(self, inputs):
  152. return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
  153. class tf_SPP(keras.layers.Layer):
  154. # Spatial pyramid pooling layer used in YOLOv3-SPP
  155. def __init__(self, c1, c2, k=(5, 9, 13), w=None):
  156. super(tf_SPP, self).__init__()
  157. c_ = c1 // 2 # hidden channels
  158. self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
  159. self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
  160. self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
  161. def call(self, inputs):
  162. x = self.cv1(inputs)
  163. return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
  164. class tf_Detect(keras.layers.Layer):
  165. def __init__(self, nc=80, anchors=(), ch=(), w=None): # detection layer
  166. super(tf_Detect, self).__init__()
  167. self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
  168. self.nc = nc # number of classes
  169. self.no = nc + 5 # number of outputs per anchor
  170. self.nl = len(anchors) # number of detection layers
  171. self.na = len(anchors[0]) // 2 # number of anchors
  172. self.grid = [tf.zeros(1)] * self.nl # init grid
  173. self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
  174. self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
  175. [self.nl, 1, -1, 1, 2])
  176. self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
  177. self.export = False # onnx export
  178. self.training = True # set to False after building model
  179. for i in range(self.nl):
  180. ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
  181. self.grid[i] = self._make_grid(nx, ny)
  182. def call(self, inputs):
  183. # x = x.copy() # for profiling
  184. z = [] # inference output
  185. self.training |= self.export
  186. x = []
  187. for i in range(self.nl):
  188. x.append(self.m[i](inputs[i]))
  189. # x(bs,20,20,255) to x(bs,3,20,20,85)
  190. ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
  191. x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
  192. if not self.training: # inference
  193. y = tf.sigmoid(x[i])
  194. xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
  195. wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
  196. # Normalize xywh to 0-1 to reduce calibration error
  197. xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
  198. wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
  199. y = tf.concat([xy, wh, y[..., 4:]], -1)
  200. z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))
  201. return x if self.training else (tf.concat(z, 1), x)
  202. @staticmethod
  203. def _make_grid(nx=20, ny=20):
  204. # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
  205. # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
  206. xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
  207. return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
  208. class tf_Upsample(keras.layers.Layer):
  209. def __init__(self, size, scale_factor, mode, w=None):
  210. super(tf_Upsample, self).__init__()
  211. assert scale_factor == 2, "scale_factor must be 2"
  212. # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
  213. if opt.tf_raw_resize:
  214. # with default arguments: align_corners=False, half_pixel_centers=False
  215. self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
  216. size=(x.shape[1] * 2, x.shape[2] * 2))
  217. else:
  218. self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
  219. def call(self, inputs):
  220. return self.upsample(inputs)
  221. class tf_Concat(keras.layers.Layer):
  222. def __init__(self, dimension=1, w=None):
  223. super(tf_Concat, self).__init__()
  224. assert dimension == 1, "convert only NCHW to NHWC concat"
  225. self.d = 3
  226. def call(self, inputs):
  227. return tf.concat(inputs, self.d)
  228. def parse_model(d, ch, model): # model_dict, input_channels(3)
  229. logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
  230. anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
  231. na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
  232. no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
  233. layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
  234. for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
  235. m_str = m
  236. m = eval(m) if isinstance(m, str) else m # eval strings
  237. for j, a in enumerate(args):
  238. try:
  239. args[j] = eval(a) if isinstance(a, str) else a # eval strings
  240. except:
  241. pass
  242. n = max(round(n * gd), 1) if n > 1 else n # depth gain
  243. if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
  244. c1, c2 = ch[f], args[0]
  245. c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
  246. args = [c1, c2, *args[1:]]
  247. if m in [BottleneckCSP, C3]:
  248. args.insert(2, n)
  249. n = 1
  250. elif m is nn.BatchNorm2d:
  251. args = [ch[f]]
  252. elif m is Concat:
  253. c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
  254. elif m is Detect:
  255. args.append([ch[x + 1] for x in f])
  256. if isinstance(args[1], int): # number of anchors
  257. args[1] = [list(range(args[1] * 2))] * len(f)
  258. else:
  259. c2 = ch[f]
  260. tf_m = eval('tf_' + m_str.replace('nn.', ''))
  261. m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
  262. else tf_m(*args, w=model.model[i]) # module
  263. torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
  264. t = str(m)[8:-2].replace('__main__.', '') # module type
  265. np = sum([x.numel() for x in torch_m_.parameters()]) # number params
  266. m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
  267. logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
  268. save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
  269. layers.append(m_)
  270. ch.append(c2)
  271. return keras.Sequential(layers), sorted(save)
  272. class tf_Model():
  273. def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None): # model, input channels, number of classes
  274. super(tf_Model, self).__init__()
  275. if isinstance(cfg, dict):
  276. self.yaml = cfg # model dict
  277. else: # is *.yaml
  278. import yaml # for torch hub
  279. self.yaml_file = Path(cfg).name
  280. with open(cfg) as f:
  281. self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
  282. # Define model
  283. if nc and nc != self.yaml['nc']:
  284. print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
  285. self.yaml['nc'] = nc # override yaml value
  286. self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model) # model, savelist, ch_out
  287. def predict(self, inputs, profile=False):
  288. y = [] # outputs
  289. x = inputs
  290. for i, m in enumerate(self.model.layers):
  291. if m.f != -1: # if not from previous layer
  292. x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
  293. x = m(x) # run
  294. y.append(x if m.i in self.savelist else None) # save output
  295. # Add TensorFlow NMS
  296. if opt.tf_nms:
  297. boxes = xywh2xyxy(x[0][..., :4])
  298. probs = x[0][:, :, 4:5]
  299. classes = x[0][:, :, 5:]
  300. scores = probs * classes
  301. if opt.agnostic_nms:
  302. nms = agnostic_nms_layer()((boxes, classes, scores))
  303. return nms, x[1]
  304. else:
  305. boxes = tf.expand_dims(boxes, 2)
  306. nms = tf.image.combined_non_max_suppression(
  307. boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False)
  308. return nms, x[1]
  309. return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
  310. # x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
  311. # xywh = x[..., :4] # x(6300,4) boxes
  312. # conf = x[..., 4:5] # x(6300,1) confidences
  313. # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
  314. # return tf.concat([conf, cls, xywh], 1)
  315. class agnostic_nms_layer(keras.layers.Layer):
  316. # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
  317. def call(self, input):
  318. return tf.map_fn(agnostic_nms, input,
  319. fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
  320. name='agnostic_nms')
  321. def agnostic_nms(x):
  322. boxes, classes, scores = x
  323. class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  324. scores_inp = tf.reduce_max(scores, -1)
  325. selected_inds = tf.image.non_max_suppression(
  326. boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres)
  327. selected_boxes = tf.gather(boxes, selected_inds)
  328. padded_boxes = tf.pad(selected_boxes,
  329. paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
  330. mode="CONSTANT", constant_values=0.0)
  331. selected_scores = tf.gather(scores_inp, selected_inds)
  332. padded_scores = tf.pad(selected_scores,
  333. paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
  334. mode="CONSTANT", constant_values=-1.0)
  335. selected_classes = tf.gather(class_inds, selected_inds)
  336. padded_classes = tf.pad(selected_classes,
  337. paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
  338. mode="CONSTANT", constant_values=-1.0)
  339. valid_detections = tf.shape(selected_inds)[0]
  340. return padded_boxes, padded_scores, padded_classes, valid_detections
  341. def xywh2xyxy(xywh):
  342. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  343. x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
  344. return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
  345. def representative_dataset_gen():
  346. # Representative dataset for use with converter.representative_dataset
  347. n = 0
  348. for path, img, im0s, vid_cap in dataset:
  349. # Get sample input data as a numpy array in a method of your choosing.
  350. n += 1
  351. input = np.transpose(img, [1, 2, 0])
  352. input = np.expand_dims(input, axis=0).astype(np.float32)
  353. input /= 255.0
  354. yield [input]
  355. if n >= opt.ncalib:
  356. break
  357. if __name__ == "__main__":
  358. parser = argparse.ArgumentParser()
  359. parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path')
  360. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path')
  361. parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size') # height, width
  362. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  363. parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size')
  364. parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
  365. parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
  366. parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
  367. parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
  368. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  369. parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
  370. help='use tf.raw_ops.ResizeNearestNeighbor for resize')
  371. parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
  372. parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
  373. parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
  374. parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
  375. opt = parser.parse_args()
  376. opt.cfg = check_file(opt.cfg) # check file
  377. opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
  378. print(opt)
  379. # Input
  380. img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection
  381. # Load PyTorch model
  382. model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
  383. model.model[-1].export = False # set Detect() layer export=True
  384. y = model(img) # dry run
  385. nc = y[0].shape[-1] - 5
  386. # TensorFlow saved_model export
  387. try:
  388. print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
  389. tf_model = tf_Model(opt.cfg, model=model, nc=nc)
  390. img = tf.zeros((opt.batch_size, *opt.img_size, 3)) # NHWC Input for TensorFlow
  391. m = tf_model.model.layers[-1]
  392. assert isinstance(m, tf_Detect), "the last layer must be Detect"
  393. m.training = False
  394. y = tf_model.predict(img)
  395. inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
  396. keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
  397. keras_model.summary()
  398. path = opt.weights.replace('.pt', '_saved_model') # filename
  399. keras_model.save(path, save_format='tf')
  400. print('TensorFlow saved_model export success, saved as %s' % path)
  401. except Exception as e:
  402. print('TensorFlow saved_model export failure: %s' % e)
  403. traceback.print_exc(file=sys.stdout)
  404. # TensorFlow GraphDef export
  405. try:
  406. print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)
  407. # https://github.com/leimao/Frozen_Graph_TensorFlow
  408. full_model = tf.function(lambda x: keras_model(x))
  409. full_model = full_model.get_concrete_function(
  410. tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
  411. frozen_func = convert_variables_to_constants_v2(full_model)
  412. frozen_func.graph.as_graph_def()
  413. f = opt.weights.replace('.pt', '.pb') # filename
  414. tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
  415. logdir=os.path.dirname(f),
  416. name=os.path.basename(f),
  417. as_text=False)
  418. print('TensorFlow GraphDef export success, saved as %s' % f)
  419. except Exception as e:
  420. print('TensorFlow GraphDef export failure: %s' % e)
  421. traceback.print_exc(file=sys.stdout)
  422. # TFLite model export
  423. if not opt.tf_nms:
  424. try:
  425. print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)
  426. # fp32 TFLite model export ---------------------------------------------------------------------------------
  427. # converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  428. # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  429. # converter.allow_custom_ops = False
  430. # converter.experimental_new_converter = True
  431. # tflite_model = converter.convert()
  432. # f = opt.weights.replace('.pt', '.tflite') # filename
  433. # open(f, "wb").write(tflite_model)
  434. # fp16 TFLite model export ---------------------------------------------------------------------------------
  435. converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  436. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  437. # converter.representative_dataset = representative_dataset_gen
  438. # converter.target_spec.supported_types = [tf.float16]
  439. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  440. converter.allow_custom_ops = False
  441. converter.experimental_new_converter = True
  442. tflite_model = converter.convert()
  443. f = opt.weights.replace('.pt', '-fp16.tflite') # filename
  444. open(f, "wb").write(tflite_model)
  445. print('\nTFLite export success, saved as %s' % f)
  446. # int8 TFLite model export ---------------------------------------------------------------------------------
  447. if opt.tfl_int8:
  448. # Representative Dataset
  449. if opt.source.endswith('.yaml'):
  450. with open(check_file(opt.source)) as f:
  451. data = yaml.load(f, Loader=yaml.FullLoader) # data dict
  452. check_dataset(data) # check
  453. opt.source = data['train']
  454. dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
  455. converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  456. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  457. converter.representative_dataset = representative_dataset_gen
  458. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  459. converter.inference_input_type = tf.uint8 # or tf.int8
  460. converter.inference_output_type = tf.uint8 # or tf.int8
  461. converter.allow_custom_ops = False
  462. converter.experimental_new_converter = True
  463. converter.experimental_new_quantizer = False
  464. tflite_model = converter.convert()
  465. f = opt.weights.replace('.pt', '-int8.tflite') # filename
  466. open(f, "wb").write(tflite_model)
  467. print('\nTFLite (int8) export success, saved as %s' % f)
  468. except Exception as e:
  469. print('\nTFLite export failure: %s' % e)
  470. traceback.print_exc(file=sys.stdout)