You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

514 lines
25KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
  4. Format | `export.py --include` | Model
  5. --- | --- | ---
  6. PyTorch | - | yolov5s.pt
  7. TorchScript | `torchscript` | yolov5s.torchscript
  8. ONNX | `onnx` | yolov5s.onnx
  9. OpenVINO | `openvino` | yolov5s_openvino_model/
  10. TensorRT | `engine` | yolov5s.engine
  11. CoreML | `coreml` | yolov5s.mlmodel
  12. TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
  13. TensorFlow GraphDef | `pb` | yolov5s.pb
  14. TensorFlow Lite | `tflite` | yolov5s.tflite
  15. TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
  16. TensorFlow.js | `tfjs` | yolov5s_web_model/
  17. Usage:
  18. $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
  19. Inference:
  20. $ python path/to/detect.py --weights yolov5s.pt # PyTorch
  21. yolov5s.torchscript # TorchScript
  22. yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
  23. yolov5s.xml # OpenVINO
  24. yolov5s.engine # TensorRT
  25. yolov5s.mlmodel # CoreML (MacOS-only)
  26. yolov5s_saved_model # TensorFlow SavedModel
  27. yolov5s.pb # TensorFlow GraphDef
  28. yolov5s.tflite # TensorFlow Lite
  29. yolov5s_edgetpu.tflite # TensorFlow Edge TPU
  30. TensorFlow.js:
  31. $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
  32. $ npm install
  33. $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
  34. $ npm start
  35. """
  36. import argparse
  37. import json
  38. import os
  39. import platform
  40. import subprocess
  41. import sys
  42. import time
  43. from pathlib import Path
  44. import torch
  45. import torch.nn as nn
  46. from torch.utils.mobile_optimizer import optimize_for_mobile
  47. FILE = Path(__file__).resolve()
  48. ROOT = FILE.parents[0] # YOLOv5 root directory
  49. if str(ROOT) not in sys.path:
  50. sys.path.append(str(ROOT)) # add ROOT to PATH
  51. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  52. from models.common import Conv
  53. from models.experimental import attempt_load
  54. from models.yolo import Detect
  55. from utils.activations import SiLU
  56. from utils.datasets import LoadImages
  57. from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
  58. file_size, print_args, url2file)
  59. from utils.torch_utils import select_device
  60. def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
  61. # YOLOv5 TorchScript model export
  62. try:
  63. LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
  64. f = file.with_suffix('.torchscript')
  65. ts = torch.jit.trace(model, im, strict=False)
  66. d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
  67. extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
  68. if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
  69. optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
  70. else:
  71. ts.save(str(f), _extra_files=extra_files)
  72. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  73. return f
  74. except Exception as e:
  75. LOGGER.info(f'{prefix} export failure: {e}')
  76. def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
  77. # YOLOv5 ONNX export
  78. try:
  79. check_requirements(('onnx',))
  80. import onnx
  81. LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
  82. f = file.with_suffix('.onnx')
  83. torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
  84. training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
  85. do_constant_folding=not train,
  86. input_names=['images'],
  87. output_names=['output'],
  88. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  89. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  90. } if dynamic else None)
  91. # Checks
  92. model_onnx = onnx.load(f) # load onnx model
  93. onnx.checker.check_model(model_onnx) # check onnx model
  94. # LOGGER.info(onnx.helper.printable_graph(model_onnx.graph)) # print
  95. # Simplify
  96. if simplify:
  97. try:
  98. check_requirements(('onnx-simplifier',))
  99. import onnxsim
  100. LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
  101. model_onnx, check = onnxsim.simplify(
  102. model_onnx,
  103. dynamic_input_shape=dynamic,
  104. input_shapes={'images': list(im.shape)} if dynamic else None)
  105. assert check, 'assert check failed'
  106. onnx.save(model_onnx, f)
  107. except Exception as e:
  108. LOGGER.info(f'{prefix} simplifier failure: {e}')
  109. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  110. return f
  111. except Exception as e:
  112. LOGGER.info(f'{prefix} export failure: {e}')
  113. def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
  114. # YOLOv5 OpenVINO export
  115. try:
  116. check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
  117. import openvino.inference_engine as ie
  118. LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
  119. f = str(file).replace('.pt', '_openvino_model' + os.sep)
  120. cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
  121. subprocess.check_output(cmd, shell=True)
  122. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  123. return f
  124. except Exception as e:
  125. LOGGER.info(f'\n{prefix} export failure: {e}')
  126. def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
  127. # YOLOv5 CoreML export
  128. try:
  129. check_requirements(('coremltools',))
  130. import coremltools as ct
  131. LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
  132. f = file.with_suffix('.mlmodel')
  133. ts = torch.jit.trace(model, im, strict=False) # TorchScript model
  134. ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
  135. ct_model.save(f)
  136. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  137. return ct_model, f
  138. except Exception as e:
  139. LOGGER.info(f'\n{prefix} export failure: {e}')
  140. return None, None
  141. def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
  142. # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
  143. try:
  144. check_requirements(('tensorrt',))
  145. import tensorrt as trt
  146. if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
  147. grid = model.model[-1].anchor_grid
  148. model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
  149. export_onnx(model, im, file, 12, train, False, simplify) # opset 12
  150. model.model[-1].anchor_grid = grid
  151. else: # TensorRT >= 8
  152. check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
  153. export_onnx(model, im, file, 13, train, False, simplify) # opset 13
  154. onnx = file.with_suffix('.onnx')
  155. assert onnx.exists(), f'failed to export ONNX file: {onnx}'
  156. LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
  157. f = file.with_suffix('.engine') # TensorRT engine file
  158. logger = trt.Logger(trt.Logger.INFO)
  159. if verbose:
  160. logger.min_severity = trt.Logger.Severity.VERBOSE
  161. builder = trt.Builder(logger)
  162. config = builder.create_builder_config()
  163. config.max_workspace_size = workspace * 1 << 30
  164. flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
  165. network = builder.create_network(flag)
  166. parser = trt.OnnxParser(network, logger)
  167. if not parser.parse_from_file(str(onnx)):
  168. raise RuntimeError(f'failed to load ONNX file: {onnx}')
  169. inputs = [network.get_input(i) for i in range(network.num_inputs)]
  170. outputs = [network.get_output(i) for i in range(network.num_outputs)]
  171. LOGGER.info(f'{prefix} Network Description:')
  172. for inp in inputs:
  173. LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
  174. for out in outputs:
  175. LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
  176. half &= builder.platform_has_fast_fp16
  177. LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
  178. if half:
  179. config.set_flag(trt.BuilderFlag.FP16)
  180. with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
  181. t.write(engine.serialize())
  182. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  183. return f
  184. except Exception as e:
  185. LOGGER.info(f'\n{prefix} export failure: {e}')
  186. def export_saved_model(model, im, file, dynamic,
  187. tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
  188. conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
  189. # YOLOv5 TensorFlow SavedModel export
  190. try:
  191. import tensorflow as tf
  192. from tensorflow import keras
  193. from models.tf import TFDetect, TFModel
  194. LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  195. f = str(file).replace('.pt', '_saved_model')
  196. batch_size, ch, *imgsz = list(im.shape) # BCHW
  197. tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
  198. im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
  199. y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
  200. inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
  201. outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
  202. keras_model = keras.Model(inputs=inputs, outputs=outputs)
  203. keras_model.trainable = False
  204. keras_model.summary()
  205. keras_model.save(f, save_format='tf')
  206. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  207. return keras_model, f
  208. except Exception as e:
  209. LOGGER.info(f'\n{prefix} export failure: {e}')
  210. return None, None
  211. def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
  212. # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
  213. try:
  214. import tensorflow as tf
  215. from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
  216. LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  217. f = file.with_suffix('.pb')
  218. m = tf.function(lambda x: keras_model(x)) # full model
  219. m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
  220. frozen_func = convert_variables_to_constants_v2(m)
  221. frozen_func.graph.as_graph_def()
  222. tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
  223. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  224. return f
  225. except Exception as e:
  226. LOGGER.info(f'\n{prefix} export failure: {e}')
  227. def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
  228. # YOLOv5 TensorFlow Lite export
  229. try:
  230. import tensorflow as tf
  231. LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
  232. batch_size, ch, *imgsz = list(im.shape) # BCHW
  233. f = str(file).replace('.pt', '-fp16.tflite')
  234. converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  235. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
  236. converter.target_spec.supported_types = [tf.float16]
  237. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  238. if int8:
  239. from models.tf import representative_dataset_gen
  240. dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
  241. converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
  242. converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
  243. converter.target_spec.supported_types = []
  244. converter.inference_input_type = tf.uint8 # or tf.int8
  245. converter.inference_output_type = tf.uint8 # or tf.int8
  246. converter.experimental_new_quantizer = False
  247. f = str(file).replace('.pt', '-int8.tflite')
  248. tflite_model = converter.convert()
  249. open(f, "wb").write(tflite_model)
  250. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  251. return f
  252. except Exception as e:
  253. LOGGER.info(f'\n{prefix} export failure: {e}')
  254. def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
  255. # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
  256. try:
  257. cmd = 'edgetpu_compiler --version'
  258. help_url = 'https://coral.ai/docs/edgetpu/compiler/'
  259. assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
  260. if subprocess.run(cmd, shell=True).returncode != 0:
  261. LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
  262. for c in ['curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
  263. 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
  264. 'sudo apt-get update',
  265. 'sudo apt-get install edgetpu-compiler']:
  266. subprocess.run(c, shell=True, check=True)
  267. ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
  268. LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
  269. f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
  270. f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
  271. cmd = f"edgetpu_compiler -s {f_tfl}"
  272. subprocess.run(cmd, shell=True, check=True)
  273. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  274. return f
  275. except Exception as e:
  276. LOGGER.info(f'\n{prefix} export failure: {e}')
  277. def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
  278. # YOLOv5 TensorFlow.js export
  279. try:
  280. check_requirements(('tensorflowjs',))
  281. import re
  282. import tensorflowjs as tfjs
  283. LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
  284. f = str(file).replace('.pt', '_web_model') # js dir
  285. f_pb = file.with_suffix('.pb') # *.pb path
  286. f_json = f + '/model.json' # *.json path
  287. cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
  288. f'--output_node_names="Identity,Identity_1,Identity_2,Identity_3" {f_pb} {f}'
  289. subprocess.run(cmd, shell=True)
  290. json = open(f_json).read()
  291. with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
  292. subst = re.sub(
  293. r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
  294. r'"Identity.?.?": {"name": "Identity.?.?"}, '
  295. r'"Identity.?.?": {"name": "Identity.?.?"}, '
  296. r'"Identity.?.?": {"name": "Identity.?.?"}}}',
  297. r'{"outputs": {"Identity": {"name": "Identity"}, '
  298. r'"Identity_1": {"name": "Identity_1"}, '
  299. r'"Identity_2": {"name": "Identity_2"}, '
  300. r'"Identity_3": {"name": "Identity_3"}}}',
  301. json)
  302. j.write(subst)
  303. LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  304. return f
  305. except Exception as e:
  306. LOGGER.info(f'\n{prefix} export failure: {e}')
  307. @torch.no_grad()
  308. def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
  309. weights=ROOT / 'yolov5s.pt', # weights path
  310. imgsz=(640, 640), # image (height, width)
  311. batch_size=1, # batch size
  312. device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  313. include=('torchscript', 'onnx'), # include formats
  314. half=False, # FP16 half-precision export
  315. inplace=False, # set YOLOv5 Detect() inplace=True
  316. train=False, # model.train() mode
  317. optimize=False, # TorchScript: optimize for mobile
  318. int8=False, # CoreML/TF INT8 quantization
  319. dynamic=False, # ONNX/TF: dynamic axes
  320. simplify=False, # ONNX: simplify model
  321. opset=12, # ONNX: opset version
  322. verbose=False, # TensorRT: verbose log
  323. workspace=4, # TensorRT: workspace size (GB)
  324. nms=False, # TF: add NMS to model
  325. agnostic_nms=False, # TF: add agnostic NMS to model
  326. topk_per_class=100, # TF.js NMS: topk per class to keep
  327. topk_all=100, # TF.js NMS: topk for all classes to keep
  328. iou_thres=0.45, # TF.js NMS: IoU threshold
  329. conf_thres=0.25 # TF.js NMS: confidence threshold
  330. ):
  331. t = time.time()
  332. include = [x.lower() for x in include]
  333. tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
  334. file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
  335. # Checks
  336. imgsz *= 2 if len(imgsz) == 1 else 1 # expand
  337. opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12
  338. # Load PyTorch model
  339. device = select_device(device)
  340. assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
  341. model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
  342. nc, names = model.nc, model.names # number of classes, class names
  343. # Input
  344. gs = int(max(model.stride)) # grid size (max stride)
  345. imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
  346. im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
  347. # Update model
  348. if half:
  349. im, model = im.half(), model.half() # to FP16
  350. model.train() if train else model.eval() # training mode = no Detect() layer grid construction
  351. for k, m in model.named_modules():
  352. if isinstance(m, Conv): # assign export-friendly activations
  353. if isinstance(m.act, nn.SiLU):
  354. m.act = SiLU()
  355. elif isinstance(m, Detect):
  356. m.inplace = inplace
  357. m.onnx_dynamic = dynamic
  358. # m.forward = m.forward_export # assign forward (optional)
  359. for _ in range(2):
  360. y = model(im) # dry runs
  361. LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
  362. # Exports
  363. if 'torchscript' in include:
  364. f = export_torchscript(model, im, file, optimize)
  365. if 'engine' in include: # TensorRT required before ONNX
  366. f = export_engine(model, im, file, train, half, simplify, workspace, verbose)
  367. if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
  368. f = export_onnx(model, im, file, opset, train, dynamic, simplify)
  369. if 'openvino' in include:
  370. f = export_openvino(model, im, file)
  371. if 'coreml' in include:
  372. _, f = export_coreml(model, im, file)
  373. # TensorFlow Exports
  374. if any(tf_exports):
  375. pb, tflite, edgetpu, tfjs = tf_exports[1:]
  376. if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
  377. check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
  378. assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
  379. model, f = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
  380. agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,
  381. topk_all=topk_all,
  382. conf_thres=conf_thres, iou_thres=iou_thres) # keras model
  383. if pb or tfjs: # pb prerequisite to tfjs
  384. f = export_pb(model, im, file)
  385. if tflite or edgetpu:
  386. f = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
  387. if edgetpu:
  388. f = export_edgetpu(model, im, file)
  389. if tfjs:
  390. f = export_tfjs(model, im, file)
  391. # Finish
  392. LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
  393. f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
  394. f"\nVisualize with https://netron.app"
  395. f"\nDetect with `python detect.py --weights {f}`"
  396. f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f}')"
  397. f"\nValidate with `python val.py --weights {f}`")
  398. def parse_opt():
  399. parser = argparse.ArgumentParser()
  400. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  401. parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
  402. parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
  403. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  404. parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  405. parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
  406. parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
  407. parser.add_argument('--train', action='store_true', help='model.train() mode')
  408. parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
  409. parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
  410. parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
  411. parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
  412. parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
  413. parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
  414. parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
  415. parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
  416. parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
  417. parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
  418. parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
  419. parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
  420. parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
  421. parser.add_argument('--include', nargs='+',
  422. default=['torchscript', 'onnx'],
  423. help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
  424. opt = parser.parse_args()
  425. print_args(FILE.stem, opt)
  426. return opt
  427. def main(opt):
  428. for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
  429. run(**vars(opt))
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then export once per --weights entry.
    opt = parse_opt()
    main(opt)