Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

488 lines
23KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
  4. Format | `export.py --include` | Model
  5. --- | --- | ---
  6. PyTorch | - | yolov5s.pt
  7. TorchScript | `torchscript` | yolov5s.torchscript
  8. ONNX | `onnx` | yolov5s.onnx
  9. OpenVINO | `openvino` | yolov5s_openvino_model/
  10. TensorRT | `engine` | yolov5s.engine
  11. CoreML | `coreml` | yolov5s.mlmodel
  12. TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
  13. TensorFlow GraphDef | `pb` | yolov5s.pb
  14. TensorFlow Lite | `tflite` | yolov5s.tflite
  15. TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
  16. TensorFlow.js | `tfjs` | yolov5s_web_model/
  17. Usage:
  18. $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs
  19. Inference:
  20. $ python path/to/detect.py --weights yolov5s.pt # PyTorch
  21. yolov5s.torchscript # TorchScript
  22. yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
  23. yolov5s.xml # OpenVINO
  24. yolov5s.engine # TensorRT
  25. yolov5s.mlmodel # CoreML (under development)
  26. yolov5s_saved_model # TensorFlow SavedModel
  27. yolov5s.pb # TensorFlow GraphDef
  28. yolov5s.tflite # TensorFlow Lite
  29. yolov5s_edgetpu.tflite # TensorFlow Edge TPU
  30. TensorFlow.js:
  31. $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
  32. $ npm install
  33. $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
  34. $ npm start
  35. """
  36. import argparse
  37. import json
  38. import os
  39. import subprocess
  40. import sys
  41. import time
  42. from pathlib import Path
  43. import torch
  44. import torch.nn as nn
  45. from torch.utils.mobile_optimizer import optimize_for_mobile
  # Resolve this script's directory and make it importable, so the project-local
  # `models` and `utils` packages imported below can be found regardless of CWD.
  FILE = Path(__file__).resolve()
  ROOT = FILE.parent  # YOLOv5 root directory
  if str(ROOT) not in sys.path:
      sys.path.append(str(ROOT))  # add ROOT to PATH
  ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative to the current working directory
  51. from models.common import Conv
  52. from models.experimental import attempt_load
  53. from models.yolo import Detect
  54. from utils.activations import SiLU
  55. from utils.datasets import LoadImages
  56. from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args,
  57. url2file)
  58. from utils.torch_utils import select_device
  def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
      """Export `model` to TorchScript by tracing it on the example input `im`.

      The traced module is saved next to `file` with a ``.torchscript`` suffix. It embeds a
      ``config.txt`` extra file containing JSON with the input shape, the maximum model stride
      and the class names. When `optimize` is true the traced module is first passed through
      ``optimize_for_mobile``. Any failure is logged, not raised.
      """
      # YOLOv5 TorchScript model export
      try:
          LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
          out_path = file.with_suffix('.torchscript')
          traced = torch.jit.trace(model, im, strict=False)
          meta = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
          extra = {'config.txt': json.dumps(meta)}  # torch._C.ExtraFilesMap()
          module = optimize_for_mobile(traced) if optimize else traced
          module.save(str(out_path), _extra_files=extra)
          LOGGER.info(f'{prefix} export success, saved as {out_path} ({file_size(out_path):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'{prefix} export failure: {e}')
  def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
      """Export `model` to ONNX next to `file` (``.onnx`` suffix) and validate it.

      The graph input is named 'images' and the output 'output'. With `dynamic`, the input's
      batch/height/width axes and the output's batch/anchors axes are exported as dynamic. With
      `train`, the export uses TrainingMode.TRAINING and constant folding is disabled. The result
      is checked with ``onnx.checker``; with `simplify` it is also run through onnx-simplifier.
      A simplifier failure is logged and leaves the unsimplified model on disk. Any other failure
      is logged, not raised.
      """
      # YOLOv5 ONNX export
      try:
          check_requirements(('onnx',))
          import onnx

          LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
          out_path = file.with_suffix('.onnx')
          mode = torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL
          axes = None
          if dynamic:
              axes = {'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                      'output': {0: 'batch', 1: 'anchors'}}  # shape(1,25200,85)
          torch.onnx.export(model, im, out_path, verbose=False, opset_version=opset,
                            training=mode,
                            do_constant_folding=not train,
                            input_names=['images'],
                            output_names=['output'],
                            dynamic_axes=axes)

          # Checks: reload the written file and validate the graph
          onnx_model = onnx.load(out_path)
          onnx.checker.check_model(onnx_model)

          # Optional simplification; failures here are non-fatal
          if simplify:
              try:
                  check_requirements(('onnx-simplifier',))
                  import onnxsim

                  LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                  shapes = {'images': list(im.shape)} if dynamic else None
                  onnx_model, ok = onnxsim.simplify(onnx_model, dynamic_input_shape=dynamic, input_shapes=shapes)
                  assert ok, 'assert check failed'
                  onnx.save(onnx_model, out_path)
              except Exception as e:
                  LOGGER.info(f'{prefix} simplifier failure: {e}')
          LOGGER.info(f'{prefix} export success, saved as {out_path} ({file_size(out_path):.1f} MB)')
          LOGGER.info(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {out_path}'")
      except Exception as e:
          LOGGER.info(f'{prefix} export failure: {e}')
  def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
      """Convert the ONNX file next to `file` into an OpenVINO model directory using `mo`.

      Expects ``file.with_suffix('.onnx')`` to already exist (run() exports ONNX first whenever
      'openvino' is requested). `model` and `im` are unused and are kept only for a uniform
      exporter signature. Any failure, including a non-zero exit of `mo`, is logged, not raised.
      """
      # YOLOv5 OpenVINO export
      try:
          check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
          import openvino.inference_engine as ie

          LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
          f = str(file).replace('.pt', '_openvino_model' + os.sep)
          # Pass an argument list (no shell) so paths containing spaces or shell
          # metacharacters reach the Model Optimizer verbatim.
          cmd = ['mo', '--input_model', str(file.with_suffix('.onnx')), '--output_dir', f]
          subprocess.check_output(cmd)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  def export_coreml(model, im, file, prefix=colorstr('CoreML:')):
      """Export `model` to a CoreML ``.mlmodel`` file next to `file`.

      The model is temporarily switched to train() mode (run() notes that training mode skips
      Detect() grid construction). It is then traced with TorchScript on the example input `im`
      and converted with coremltools, using an image input scaled by 1/255. The caller's
      train/eval mode is restored afterwards, so later exports on the same model object are
      unaffected.

      Returns:
          The coremltools model, or None if the export failed (failures are logged, not raised).
      """
      # YOLOv5 CoreML export
      ct_model = None
      was_training = model.training  # remember the caller's mode; train() below mutates the model
      try:
          check_requirements(('coremltools',))
          import coremltools as ct

          LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
          f = file.with_suffix('.mlmodel')
          model.train()  # CoreML exports should be placed in model.train() mode
          ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
          ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
          ct_model.save(f)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
      finally:
          model.train(was_training)  # undo the mode switch even if the export failed
      return ct_model
  def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
      """Build a serialized TensorRT engine (``.engine``) for `model` via an intermediate ONNX file.

      export_onnx is called first with static shapes; its output is parsed by TensorRT's ONNX parser
      into an explicit-batch network. The engine is built in FP16 when `half` is requested and the
      builder reports fast FP16 support, otherwise in FP32. `workspace` is the builder workspace
      limit in GiB. `verbose` raises the TensorRT logger to VERBOSE. Any failure is logged, not
      raised.
      """
      # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
      try:
          check_requirements(('tensorrt',))
          import tensorrt as trt

          # Opset 13 for TensorRT 8 and newer, opset 12 for older releases (tested on 7.x and 8.x).
          # Parse the major version numerically: checking only the first character
          # mis-detects 10.x and later as an old release.
          trt_major = int(trt.__version__.split('.')[0])
          opset = 13 if trt_major >= 8 else 12
          export_onnx(model, im, file, opset, train, False, simplify)
          onnx_file = file.with_suffix('.onnx')
          assert onnx_file.exists(), f'failed to export ONNX file: {onnx_file}'

          LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
          f = file.with_suffix('.engine')  # TensorRT engine file
          logger = trt.Logger(trt.Logger.INFO)
          if verbose:
              logger.min_severity = trt.Logger.Severity.VERBOSE

          builder = trt.Builder(logger)
          config = builder.create_builder_config()
          config.max_workspace_size = int(workspace * (1 << 30))  # GiB -> bytes

          flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
          network = builder.create_network(flag)
          parser = trt.OnnxParser(network, logger)
          if not parser.parse_from_file(str(onnx_file)):
              raise RuntimeError(f'failed to load ONNX file: {onnx_file}')

          inputs = [network.get_input(i) for i in range(network.num_inputs)]
          outputs = [network.get_output(i) for i in range(network.num_outputs)]
          LOGGER.info(f'{prefix} Network Description:')
          for inp in inputs:
              LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
          for out in outputs:
              LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

          half &= builder.platform_has_fast_fp16  # silently fall back to FP32 without fast FP16
          LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
          if half:
              config.set_flag(trt.BuilderFlag.FP16)
          engine = builder.build_engine(network, config)
          if engine is None:  # build_engine reports failure by returning None rather than raising
              raise RuntimeError(f'failed to build TensorRT engine: {f}')
          with engine, open(f, 'wb') as t:
              t.write(engine.serialize())
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  def export_saved_model(model, im, file, dynamic,
                         tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
                         conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
      """Rebuild `model` as a Keras model and save it in TensorFlow SavedModel format.

      The output directory is derived from `file` ('.pt' replaced by '_saved_model'). The Keras
      input is channels-last (H, W, 3); its batch size is None when `dynamic`, otherwise the
      batch size of `im`. The NMS-related arguments are forwarded unchanged to TFModel.predict.

      Returns:
          The Keras model, or None if the export failed (failures are logged, not raised).
      """
      # YOLOv5 TensorFlow SavedModel export
      keras_model = None
      try:
          import tensorflow as tf
          from tensorflow import keras

          from models.tf import TFDetect, TFModel

          LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
          f = str(file).replace('.pt', '_saved_model')
          batch_size, _, *imgsz = list(im.shape)  # BCHW
          tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
          nms_args = (tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
          # Dry run on a zero tensor laid out BHWC (TensorFlow order) before building the Keras graph
          tf_model.predict(tf.zeros((batch_size, *imgsz, 3)), *nms_args)
          inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
          keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs, *nms_args))
          keras_model.trainable = False
          keras_model.summary()
          keras_model.save(f, save_format='tf')
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
      return keras_model
  def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
      """Freeze `keras_model` into a binary TensorFlow GraphDef written next to `file` as ``.pb``.

      The model is wrapped in a tf.function, specialised to the Keras model's first input
      spec, and has its variables converted to constants. `im` is unused. Any failure is
      logged, not raised.
      """
      # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
      try:
          import tensorflow as tf
          from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

          LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
          f = file.with_suffix('.pb')
          first_input = keras_model.inputs[0]
          spec = tf.TensorSpec(first_input.shape, first_input.dtype)
          concrete = tf.function(lambda x: keras_model(x)).get_concrete_function(spec)  # full model
          frozen_func = convert_variables_to_constants_v2(concrete)
          frozen_func.graph.as_graph_def()
          tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
      """Convert `keras_model` to TensorFlow Lite and write it next to `file`.

      By default the model is converted with float16 weights to ``<name>-fp16.tflite``. With
      `int8`, full integer quantization is used instead, with uint8 input/output tensors,
      written to ``<name>-int8.tflite``. Calibration uses a representative dataset built from
      the 'train' images of the dataset described by `data`, passed with `ncalib` to
      representative_dataset_gen (presumably the number of calibration images; confirm in
      models.tf). The image size comes from the H, W of `im`. Any failure is logged, not raised.
      """
      # YOLOv5 TensorFlow Lite export
      try:
          import tensorflow as tf

          from models.tf import representative_dataset_gen

          LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
          _, _, *imgsz = list(im.shape)  # BCHW -> [H, W]
          f = str(file).replace('.pt', '-fp16.tflite')

          converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
          converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
          converter.target_spec.supported_types = [tf.float16]
          converter.optimizations = [tf.lite.Optimize.DEFAULT]
          if int8:
              dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False)  # representative data
              converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
              converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
              converter.target_spec.supported_types = []
              converter.inference_input_type = tf.uint8  # or tf.int8
              converter.inference_output_type = tf.uint8  # or tf.int8
              converter.experimental_new_quantizer = False
              f = str(file).replace('.pt', '-int8.tflite')

          tflite_model = converter.convert()
          with open(f, 'wb') as fh:  # close the handle deterministically (was leaked before)
              fh.write(tflite_model)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
      """Compile the INT8 TFLite model next to `file` for the Coral Edge TPU.

      Expects ``<name>-int8.tflite`` to already exist (export_tflite with int8). Runs
      `edgetpu_compiler` on it and writes ``<name>-int8_edgetpu.tflite`` into the same directory
      as `file`. `keras_model` and `im` are unused. Any failure, including a missing compiler,
      is logged, not raised.
      """
      # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
      try:
          out = subprocess.run(['edgetpu_compiler', '--version'], capture_output=True, check=True)
          ver = out.stdout.decode().split()[-1]
          LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
          f = str(file).replace('.pt', '-int8_edgetpu.tflite')
          f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model
          # edgetpu_compiler writes into the current directory unless --out_dir is given; point it
          # at the weights' directory so the compiled model actually lands at `f`. An argument list
          # (no shell) keeps paths with spaces intact.
          subprocess.run(['edgetpu_compiler', '-s', '--out_dir', str(file.parent), f_tfl], check=True)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
      """Convert the frozen GraphDef next to `file` into a TensorFlow.js web-model directory.

      Expects ``file.with_suffix('.pb')`` to already exist (export_pb). Runs
      `tensorflowjs_converter` with output nodes Identity..Identity_3. It then rewrites the
      'outputs' entry of the generated model.json so those four nodes are listed in ascending
      order. `keras_model` and `im` are unused. Any failure, including a non-zero converter
      exit, is logged, not raised.
      """
      # YOLOv5 TensorFlow.js export
      try:
          check_requirements(('tensorflowjs',))
          import re

          import tensorflowjs as tfjs

          LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
          f = str(file).replace('.pt', '_web_model')  # js dir
          f_pb = file.with_suffix('.pb')  # *.pb path
          f_json = f + '/model.json'  # *.json path

          # Argument list (no shell): quoting is not needed, paths with spaces work, and
          # check=True surfaces a failed conversion instead of a later confusing open() error.
          cmd = ['tensorflowjs_converter', '--input_format=tf_frozen_model',
                 '--output_node_names=Identity,Identity_1,Identity_2,Identity_3', str(f_pb), f]
          subprocess.run(cmd, check=True)

          with open(f_json) as j:
              model_json = j.read()
          # Sort JSON Identity_* in ascending order. Compute the substitution before reopening
          # for write, so a failure here cannot leave model.json truncated.
          subst = re.sub(
              r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
              r'"Identity.?.?": {"name": "Identity.?.?"}, '
              r'"Identity.?.?": {"name": "Identity.?.?"}, '
              r'"Identity.?.?": {"name": "Identity.?.?"}}}',
              r'{"outputs": {"Identity": {"name": "Identity"}, '
              r'"Identity_1": {"name": "Identity_1"}, '
              r'"Identity_2": {"name": "Identity_2"}, '
              r'"Identity_3": {"name": "Identity_3"}}}',
              model_json)
          with open(f_json, 'w') as j:
              j.write(subst)
          LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
      except Exception as e:
          LOGGER.info(f'\n{prefix} export failure: {e}')
  @torch.no_grad()
  def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
          weights=ROOT / 'yolov5s.pt',  # weights path
          imgsz=(640, 640),  # image (height, width)
          batch_size=1,  # batch size
          device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
          include=('torchscript', 'onnx'),  # include formats
          half=False,  # FP16 half-precision export
          inplace=False,  # set YOLOv5 Detect() inplace=True
          train=False,  # model.train() mode
          optimize=False,  # TorchScript: optimize for mobile
          int8=False,  # CoreML/TF INT8 quantization
          dynamic=False,  # ONNX/TF: dynamic axes
          simplify=False,  # ONNX: simplify model
          opset=12,  # ONNX: opset version
          verbose=False,  # TensorRT: verbose log
          workspace=4,  # TensorRT: workspace size (GB)
          nms=False,  # TF: add NMS to model
          agnostic_nms=False,  # TF: add agnostic NMS to model
          topk_per_class=100,  # TF.js NMS: topk per class to keep
          topk_all=100,  # TF.js NMS: topk for all classes to keep
          iou_thres=0.45,  # TF.js NMS: IoU threshold
          conf_thres=0.25  # TF.js NMS: confidence threshold
          ):
      """Load a YOLOv5 PyTorch checkpoint and export it to every format named in `include`.

      Recognised formats: torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite,
      edgetpu, tfjs (case-insensitive). Unrecognised names are reported and skipped.

      `imgsz` may be a single int or a 1- or 2-element sequence; a single value is used for both
      height and width. Each size is checked against the model's maximum stride via
      check_img_size. The caller's `imgsz` sequence is not modified.

      Each exporter logs its own success or failure, so one failing format does not stop the
      others.

      Raises:
          AssertionError: `half` requested on a CPU device, or both tflite and tfjs requested.
      """
      t = time.time()
      include = [x.lower() for x in include]
      tf_formats = ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')
      known_formats = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml') + tf_formats
      unknown = [x for x in include if x not in known_formats]
      if unknown:  # previously ignored silently, so a typo exported nothing without any message
          LOGGER.info(f'WARNING: ignoring unknown --include format(s) {unknown}, valid formats are {known_formats}')
      tf_exports = [x in include for x in tf_formats]  # TensorFlow exports
      file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)

      # Checks
      # Build a new list (accepting a bare int too) so a caller-supplied list is never mutated in place
      imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
      if len(imgsz) == 1:
          imgsz = imgsz * 2  # expand
      opset = 12 if ('openvino' in include) else opset  # OpenVINO requires opset <= 12

      # Load PyTorch model
      device = select_device(device)
      assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
      model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 model

      # Input
      gs = int(max(model.stride))  # grid size (max stride)
      imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
      im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

      # Update model
      if half:
          im, model = im.half(), model.half()  # to FP16
      model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
      for m in model.modules():
          if isinstance(m, Conv):  # assign export-friendly activations
              if isinstance(m.act, nn.SiLU):
                  m.act = SiLU()
          elif isinstance(m, Detect):
              m.inplace = inplace
              m.onnx_dynamic = dynamic
              # m.forward = m.forward_export  # assign forward (optional)

      for _ in range(2):
          model(im)  # dry runs
      LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")

      # Exports
      if 'torchscript' in include:
          export_torchscript(model, im, file, optimize)
      if ('onnx' in include) or ('openvino' in include):  # OpenVINO requires ONNX
          export_onnx(model, im, file, opset, train, dynamic, simplify)
      if 'openvino' in include:
          export_openvino(model, im, file)
      if 'engine' in include:
          export_engine(model, im, file, train, half, simplify, workspace, verbose)
      if 'coreml' in include:
          export_coreml(model, im, file)

      # TensorFlow Exports
      if any(tf_exports):
          pb, tflite, edgetpu, tfjs = tf_exports[1:]
          assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
          # TF.js always gets agnostic NMS baked into the graph
          model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
                                     agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
                                     conf_thres=conf_thres, iou_thres=iou_thres)  # keras model
          if pb or tfjs:  # pb prerequisite to tfjs
              export_pb(model, im, file)
          if tflite or edgetpu:  # Edge TPU compiles the INT8 TFLite model
              export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
          if edgetpu:
              export_edgetpu(model, im, file)
          if tfjs:
              export_tfjs(model, im, file)

      # Finish
      LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
                  f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                  f'\nVisualize with https://netron.app')
  def parse_opt():
      """Parse the export command-line options, print them, and return the argparse namespace.

      Option names and defaults mirror the keyword arguments of run().
      """
      parser = argparse.ArgumentParser()
      parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
      parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
      parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
      parser.add_argument('--batch-size', type=int, default=1, help='batch size')
      parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
      parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
      parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
      parser.add_argument('--train', action='store_true', help='model.train() mode')
      parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
      parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
      parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
      parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
      parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
      parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
      parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
      parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
      parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
      parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
      parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
      parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
      parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
      # Help lists every format run() handles (openvino and edgetpu were previously missing)
      parser.add_argument('--include', nargs='+',
                          default=['torchscript', 'onnx'],
                          help='available formats are (torchscript, onnx, openvino, engine, coreml, '
                               'saved_model, pb, tflite, edgetpu, tfjs)')
      opt = parser.parse_args()
      print_args(FILE.stem, opt)
      return opt
  def main(opt):
      """Call run() once per weights file in `opt.weights`, which may be a list or a single value.

      `opt.weights` is overwritten with the current file before each call, so after the loop it
      holds the last weights processed.
      """
      weights_list = opt.weights if isinstance(opt.weights, list) else [opt.weights]
      for w in weights_list:
          opt.weights = w
          run(**vars(opt))
  if __name__ == "__main__":
      # Script entry point: parse CLI options, then export each requested weights file.
      opt = parse_opt()
      main(opt)