You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

337 lines
15KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
Export a YOLOv5 PyTorch model to TorchScript, ONNX, CoreML and TensorFlow (saved_model, pb, TFLite, TF.js) formats
  4. TensorFlow exports authored by https://github.com/zldrobit
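Usage:
    $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite
    $ python path/to/export.py --weights yolov5s.pt --include tfjs  (TFLite and TF.js must be exported in separate runs)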
Inference (pass any one of the exported files as --weights):
    $ python path/to/detect.py --weights yolov5s.pt
                                         yolov5s.onnx  (must export with --dynamic)
                                         yolov5s_saved_model
                                         yolov5s.pb
                                         yolov5s.tflite
TensorFlow.js:
    $ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
    $ npm start
  19. """
  20. import argparse
  21. import subprocess
  22. import sys
  23. import time
  24. from pathlib import Path
  25. import torch
  26. import torch.nn as nn
  27. from torch.utils.mobile_optimizer import optimize_for_mobile
  28. FILE = Path(__file__).resolve()
  29. ROOT = FILE.parents[0] # yolov5/ dir
  30. if str(ROOT) not in sys.path:
  31. sys.path.append(str(ROOT)) # add ROOT to PATH
  32. from models.common import Conv
  33. from models.experimental import attempt_load
  34. from models.yolo import Detect
  35. from utils.activations import SiLU
  36. from utils.datasets import LoadImages
  37. from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging, url2file
  38. from utils.torch_utils import select_device
  39. def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
  40. # YOLOv5 TorchScript model export
  41. try:
  42. print(f'\n{prefix} starting export with torch {torch.__version__}...')
  43. f = file.with_suffix('.torchscript.pt')
  44. ts = torch.jit.trace(model, im, strict=False)
  45. (optimize_for_mobile(ts) if optimize else ts).save(f)
  46. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  47. except Exception as e:
  48. print(f'{prefix} export failure: {e}')
  49. def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
  50. # YOLOv5 ONNX export
  51. try:
  52. check_requirements(('onnx',))
  53. import onnx
  54. print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
  55. f = file.with_suffix('.onnx')
  56. torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
  57. training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
  58. do_constant_folding=not train,
  59. input_names=['images'],
  60. output_names=['output'],
  61. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  62. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  63. } if dynamic else None)
  64. # Checks
  65. model_onnx = onnx.load(f) # load onnx model
  66. onnx.checker.check_model(model_onnx) # check onnx model
  67. # print(onnx.helper.printable_graph(model_onnx.graph)) # print
  68. # Simplify
  69. if simplify:
  70. try:
  71. check_requirements(('onnx-simplifier',))
  72. import onnxsim
  73. print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
  74. model_onnx, check = onnxsim.simplify(
  75. model_onnx,
  76. dynamic_input_shape=dynamic,
  77. input_shapes={'images': list(im.shape)} if dynamic else None)
  78. assert check, 'assert check failed'
  79. onnx.save(model_onnx, f)
  80. except Exception as e:
  81. print(f'{prefix} simplifier failure: {e}')
  82. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  83. print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
  84. except Exception as e:
  85. print(f'{prefix} export failure: {e}')
  86. def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
  87. # YOLOv5 CoreML export
  88. ct_model = None
  89. try:
  90. check_requirements(('coremltools',))
  91. import coremltools as ct
  92. print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
  93. f = file.with_suffix('.mlmodel')
  94. model.train() # CoreML exports should be placed in model.train() mode
  95. ts = torch.jit.trace(model, im, strict=False) # TorchScript model
  96. ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255.0, bias=[0, 0, 0])])
  97. ct_model.save(f)
  98. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  99. except Exception as e:
  100. print(f'\n{prefix} export failure: {e}')
  101. return ct_model
  102. def export_saved_model(model, im, file, dynamic,
  103. tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
  104. conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')):
  105. # YOLOv5 TensorFlow saved_model export
  106. keras_model = None
  107. try:
  108. import tensorflow as tf
  109. from tensorflow import keras
  110. from models.tf import TFModel, TFDetect
  111. print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  112. f = str(file).replace('.pt', '_saved_model')
  113. batch_size, ch, *imgsz = list(im.shape) # BCHW
  114. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  115. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
  116. y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
  117. inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  118. outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
  119. keras_model = keras.Model(inputs=inputs, outputs=outputs)
  120. keras_model.trainable = False
  121. keras_model.summary()
  122. keras_model.save(f, save_format='tf')
  123. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  124. except Exception as e:
  125. print(f'\n{prefix} export failure: {e}')
  126. return keras_model
  127. def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
  128. # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
  129. try:
  130. import tensorflow as tf
  131. from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
  132. print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  133. f = file.with_suffix('.pb')
  134. m = tf.function(lambda x: keras_model(x)) # full model
  135. m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
  136. frozen_func = convert_variables_to_constants_v2(m)
  137. frozen_func.graph.as_graph_def()
  138. tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
  139. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  140. except Exception as e:
  141. print(f'\n{prefix} export failure: {e}')
  142. def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
  143. # YOLOv5 TensorFlow Lite export
  144. try:
  145. import tensorflow as tf
  146. from models.tf import representative_dataset_gen
  147. print(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  148. batch_size, ch, *imgsz = list(im.shape) # BCHW
  149. f = str(file).replace('.pt', '-fp16.tflite')
  150. converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  151. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  152. converter.target_spec.supported_types = [tf.float16]
  153. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  154. if int8:
  155. dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
  156. converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
  157. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  158. converter.target_spec.supported_types = []
  159. converter.inference_input_type = tf.uint8 # or tf.int8
  160. converter.inference_output_type = tf.uint8 # or tf.int8
  161. converter.experimental_new_quantizer = False
  162. f = str(file).replace('.pt', '-int8.tflite')
  163. tflite_model = converter.convert()
  164. open(f, "wb").write(tflite_model)
  165. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  166. except Exception as e:
  167. print(f'\n{prefix} export failure: {e}')
  168. def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
  169. # YOLOv5 TensorFlow.js export
  170. try:
  171. check_requirements(('tensorflowjs',))
  172. import tensorflowjs as tfjs
  173. print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
  174. f = str(file).replace('.pt', '_web_model') # js dir
  175. f_pb = file.with_suffix('.pb') # *.pb path
  176. cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
  177. f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
  178. subprocess.run(cmd, shell=True)
  179. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  180. except Exception as e:
  181. print(f'\n{prefix} export failure: {e}')
  182. @torch.no_grad()
  183. def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
  184. weights=ROOT / 'yolov5s.pt', # weights path
  185. imgsz=(640, 640), # image (height, width)
  186. batch_size=1, # batch size
  187. device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  188. include=('torchscript', 'onnx', 'coreml'), # include formats
  189. half=False, # FP16 half-precision export
  190. inplace=False, # set YOLOv5 Detect() inplace=True
  191. train=False, # model.train() mode
  192. optimize=False, # TorchScript: optimize for mobile
  193. int8=False, # CoreML/TF INT8 quantization
  194. dynamic=False, # ONNX/TF: dynamic axes
  195. simplify=False, # ONNX: simplify model
  196. opset=12, # ONNX: opset version
  197. ):
  198. t = time.time()
  199. include = [x.lower() for x in include]
  200. tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
  201. imgsz *= 2 if len(imgsz) == 1 else 1 # expand
  202. file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
  203. # Load PyTorch model
  204. device = select_device(device)
  205. assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
  206. model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
  207. nc, names = model.nc, model.names # number of classes, class names
  208. # Input
  209. gs = int(max(model.stride)) # grid size (max stride)
  210. imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
  211. im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
  212. # Update model
  213. if half:
  214. im, model = im.half(), model.half() # to FP16
  215. model.train() if train else model.eval() # training mode = no Detect() layer grid construction
  216. for k, m in model.named_modules():
  217. if isinstance(m, Conv): # assign export-friendly activations
  218. if isinstance(m.act, nn.SiLU):
  219. m.act = SiLU()
  220. elif isinstance(m, Detect):
  221. m.inplace = inplace
  222. m.onnx_dynamic = dynamic
  223. # m.forward = m.forward_export # assign forward (optional)
  224. for _ in range(2):
  225. y = model(im) # dry runs
  226. print(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
  227. # Exports
  228. if 'torchscript' in include:
  229. export_torchscript(model, im, file, optimize)
  230. if 'onnx' in include:
  231. export_onnx(model, im, file, opset, train, dynamic, simplify)
  232. if 'coreml' in include:
  233. export_coreml(model, im, file)
  234. # TensorFlow Exports
  235. if any(tf_exports):
  236. pb, tflite, tfjs = tf_exports[1:]
  237. assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
  238. model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model
  239. if pb or tfjs: # pb prerequisite to tfjs
  240. export_pb(model, im, file)
  241. if tflite:
  242. export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
  243. if tfjs:
  244. export_tfjs(model, im, file)
  245. # Finish
  246. print(f'\nExport complete ({time.time() - t:.2f}s)'
  247. f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
  248. f'\nVisualize with https://netron.app')
  249. def parse_opt():
  250. parser = argparse.ArgumentParser()
  251. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  252. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path')
  253. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
  254. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  255. parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  256. parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
  257. parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
  258. parser.add_argument('--train', action='store_true', help='model.train() mode')
  259. parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
  260. parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
  261. parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
  262. parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
  263. parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
  264. parser.add_argument('--include', nargs='+',
  265. default=['torchscript', 'onnx'],
  266. help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
  267. opt = parser.parse_args()
  268. return opt
  269. def main(opt):
  270. set_logging()
  271. print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  272. run(**vars(opt))
  273. if __name__ == "__main__":
  274. opt = parse_opt()
  275. main(opt)