Browse Source

OpenVINO Export (#6057)

* OpenVINO export

* Remove timeout

* Add 3 files

* str

* Constrain opset to 12

* Default ONNX opset to 12

* Make dir

* Make dir

* Cleanup

* Cleanup

* check_requirements(('openvino-dev',))
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
95c7bc25d3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 6 deletions
  1. +30
    -6
      export.py
  2. +1
    -0
      requirements.txt

+ 30
- 6
export.py View File

TorchScript | yolov5s.torchscript | `torchscript` TorchScript | yolov5s.torchscript | `torchscript`
ONNX | yolov5s.onnx | `onnx` ONNX | yolov5s.onnx | `onnx`
CoreML | yolov5s.mlmodel | `coreml` CoreML | yolov5s.mlmodel | `coreml`
OpenVINO | yolov5s_openvino_model/ | `openvino`
TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model` TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model`
TensorFlow GraphDef | yolov5s.pb | `pb` TensorFlow GraphDef | yolov5s.pb | `pb`
TensorFlow Lite | yolov5s.tflite | `tflite` TensorFlow Lite | yolov5s.tflite | `tflite`
TensorRT | yolov5s.engine | `engine` TensorRT | yolov5s.engine | `engine`


Usage: Usage:
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs


Inference: Inference:
$ python path/to/detect.py --weights yolov5s.pt $ python path/to/detect.py --weights yolov5s.pt
yolov5s.torchscript yolov5s.torchscript
yolov5s.onnx yolov5s.onnx
yolov5s.mlmodel (under development) yolov5s.mlmodel (under development)
yolov5s_openvino_model (under development)
yolov5s_saved_model yolov5s_saved_model
yolov5s.pb yolov5s.pb
yolov5s.tflite yolov5s.tflite
return ct_model return ct_model




def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
try:
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie

LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
f = str(file).replace('.pt', '_openvino_model' + os.sep)

cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
subprocess.check_output(cmd, shell=True)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_saved_model(model, im, file, dynamic, def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')): conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')):
imgsz=(640, 640), # image (height, width) imgsz=(640, 640), # image (height, width)
batch_size=1, # batch size batch_size=1, # batch size
device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
include=('torchscript', 'onnx', 'coreml'), # include formats
include=('torchscript', 'onnx'), # include formats
half=False, # FP16 half-precision export half=False, # FP16 half-precision export
inplace=False, # set YOLOv5 Detect() inplace=True inplace=False, # set YOLOv5 Detect() inplace=True
train=False, # model.train() mode train=False, # model.train() mode
int8=False, # CoreML/TF INT8 quantization int8=False, # CoreML/TF INT8 quantization
dynamic=False, # ONNX/TF: dynamic axes dynamic=False, # ONNX/TF: dynamic axes
simplify=False, # ONNX: simplify model simplify=False, # ONNX: simplify model
opset=14, # ONNX: opset version
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB) workspace=4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model nms=False, # TF: add NMS to model
t = time.time() t = time.time()
include = [x.lower() for x in include] include = [x.lower() for x in include]
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)


# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12

# Load PyTorch model # Load PyTorch model
device = select_device(device) device = select_device(device)
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
# Exports # Exports
if 'torchscript' in include: if 'torchscript' in include:
export_torchscript(model, im, file, optimize) export_torchscript(model, im, file, optimize)
if 'onnx' in include:
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
export_onnx(model, im, file, opset, train, dynamic, simplify) export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'engine' in include: if 'engine' in include:
export_engine(model, im, file, train, half, simplify, workspace, verbose) export_engine(model, im, file, train, half, simplify, workspace, verbose)
if 'coreml' in include: if 'coreml' in include:
export_coreml(model, im, file) export_coreml(model, im, file)
if 'openvino' in include:
export_openvino(model, im, file)


# TensorFlow Exports # TensorFlow Exports
if any(tf_exports): if any(tf_exports):
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--nms', action='store_true', help='TF: add NMS to model') parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')

+ 1
- 0
requirements.txt View File

# scikit-learn==0.19.2 # CoreML quantization # scikit-learn==0.19.2 # CoreML quantization
# tensorflow>=2.4.1 # TFLite export # tensorflow>=2.4.1 # TFLite export
# tensorflowjs>=3.9.0 # TF.js export # tensorflowjs>=3.9.0 # TF.js export
# openvino-dev # OpenVINO export


# Extras -------------------------------------- # Extras --------------------------------------
# albumentations>=1.0.3 # albumentations>=1.0.3

Loading…
Cancel
Save