* Global export sort * CleanupmodifyDataloader
@@ -15,13 +15,13 @@ Usage - formats: | |||
$ python path/to/detect.py --weights yolov5s.pt # PyTorch | |||
yolov5s.torchscript # TorchScript | |||
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s.xml # OpenVINO | |||
yolov5s.engine # TensorRT | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s_saved_model # TensorFlow SavedModel | |||
yolov5s.pb # TensorFlow protobuf | |||
yolov5s.pb # TensorFlow GraphDef | |||
yolov5s.tflite # TensorFlow Lite | |||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU | |||
yolov5s.engine # TensorRT | |||
""" | |||
import argparse |
@@ -2,19 +2,19 @@ | |||
""" | |||
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit | |||
Format | Example | `--include ...` argument | |||
Format | `export.py --include` | Model | |||
--- | --- | --- | |||
PyTorch | yolov5s.pt | - | |||
TorchScript | yolov5s.torchscript | `torchscript` | |||
ONNX | yolov5s.onnx | `onnx` | |||
CoreML | yolov5s.mlmodel | `coreml` | |||
OpenVINO | yolov5s_openvino_model/ | `openvino` | |||
TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model` | |||
TensorFlow GraphDef | yolov5s.pb | `pb` | |||
TensorFlow Lite | yolov5s.tflite | `tflite` | |||
TensorFlow Edge TPU | yolov5s_edgetpu.tflite | `edgetpu` | |||
TensorFlow.js | yolov5s_web_model/ | `tfjs` | |||
TensorRT | yolov5s.engine | `engine` | |||
PyTorch | - | yolov5s.pt | |||
TorchScript | `torchscript` | yolov5s.torchscript | |||
ONNX | `onnx` | yolov5s.onnx | |||
OpenVINO | `openvino` | yolov5s_openvino_model/ | |||
TensorRT | `engine` | yolov5s.engine | |||
CoreML | `coreml` | yolov5s.mlmodel | |||
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ | |||
TensorFlow GraphDef | `pb` | yolov5s.pb | |||
TensorFlow Lite | `tflite` | yolov5s.tflite | |||
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite | |||
TensorFlow.js | `tfjs` | yolov5s_web_model/ | |||
Usage: | |||
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs | |||
@@ -23,13 +23,13 @@ Inference: | |||
$ python path/to/detect.py --weights yolov5s.pt # PyTorch | |||
yolov5s.torchscript # TorchScript | |||
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s.xml # OpenVINO | |||
yolov5s.engine # TensorRT | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s_saved_model # TensorFlow SavedModel | |||
yolov5s.pb # TensorFlow protobuf | |||
yolov5s.pb # TensorFlow GraphDef | |||
yolov5s.tflite # TensorFlow Lite | |||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU | |||
yolov5s.engine # TensorRT | |||
TensorFlow.js: | |||
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example | |||
@@ -126,6 +126,23 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst | |||
LOGGER.info(f'{prefix} export failure: {e}') | |||
def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
    """YOLOv5 OpenVINO export.

    Runs the ``mo`` command-line tool (expected to be provided by the ``openvino-dev``
    requirement checked below) on the ONNX file that sits next to ``file``, writing the
    result into a ``*_openvino_model`` directory. The ONNX file must already exist; this
    function does not create it.

    Any exception raised along the way is caught and logged, so the function never
    raises to its caller and returns nothing.

    Args:
        model: unused here; kept so all exporters share the same call shape.
        im: unused here; kept so all exporters share the same call shape.
        file: path object of the ``*.pt`` weights; must support ``with_suffix``.
        prefix: text prepended to every log message.
    """
    try:
        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.inference_engine as ie

        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
        f = str(file).replace('.pt', '_openvino_model' + os.sep)  # output directory

        # Pass an argument list and no shell: a path containing spaces or shell
        # metacharacters reaches `mo` verbatim instead of being split or interpreted.
        cmd = ['mo', '--input_model', str(file.with_suffix('.onnx')), '--output_dir', f]
        subprocess.check_output(cmd)

        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        # Best-effort export: report the failure and let the remaining exports run
        LOGGER.info(f'\n{prefix} export failure: {e}')
def export_coreml(model, im, file, prefix=colorstr('CoreML:')): | |||
# YOLOv5 CoreML export | |||
ct_model = None | |||
@@ -148,27 +165,57 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')): | |||
return ct_model | |||
def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): | |||
# YOLOv5 OpenVINO export | |||
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): | |||
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt | |||
try: | |||
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ | |||
import openvino.inference_engine as ie | |||
check_requirements(('tensorrt',)) | |||
import tensorrt as trt | |||
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') | |||
f = str(file).replace('.pt', '_openvino_model' + os.sep) | |||
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x | |||
export_onnx(model, im, file, opset, train, False, simplify) | |||
onnx = file.with_suffix('.onnx') | |||
assert onnx.exists(), f'failed to export ONNX file: {onnx}' | |||
cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}" | |||
subprocess.check_output(cmd, shell=True) | |||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') | |||
f = file.with_suffix('.engine') # TensorRT engine file | |||
logger = trt.Logger(trt.Logger.INFO) | |||
if verbose: | |||
logger.min_severity = trt.Logger.Severity.VERBOSE | |||
builder = trt.Builder(logger) | |||
config = builder.create_builder_config() | |||
config.max_workspace_size = workspace * 1 << 30 | |||
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) | |||
network = builder.create_network(flag) | |||
parser = trt.OnnxParser(network, logger) | |||
if not parser.parse_from_file(str(onnx)): | |||
raise RuntimeError(f'failed to load ONNX file: {onnx}') | |||
inputs = [network.get_input(i) for i in range(network.num_inputs)] | |||
outputs = [network.get_output(i) for i in range(network.num_outputs)] | |||
LOGGER.info(f'{prefix} Network Description:') | |||
for inp in inputs: | |||
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') | |||
for out in outputs: | |||
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') | |||
half &= builder.platform_has_fast_fp16 | |||
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}') | |||
if half: | |||
config.set_flag(trt.BuilderFlag.FP16) | |||
with builder.build_engine(network, config) as engine, open(f, 'wb') as t: | |||
t.write(engine.serialize()) | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
def export_saved_model(model, im, file, dynamic, | |||
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, | |||
conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')): | |||
# YOLOv5 TensorFlow saved_model export | |||
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): | |||
# YOLOv5 TensorFlow SavedModel export | |||
keras_model = None | |||
try: | |||
import tensorflow as tf | |||
@@ -304,53 +351,6 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    """YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt

    First exports an ONNX file through ``export_onnx``, then parses it with TensorRT,
    builds an engine and writes the serialized engine next to ``file`` with an
    ``.engine`` suffix. Every exception is caught and logged; nothing is returned.

    Args:
        model: forwarded to ``export_onnx``.
        im: forwarded to ``export_onnx``.
        file: path object of the weights; ``.onnx`` / ``.engine`` siblings are derived from it.
        train: forwarded to ``export_onnx``.
        half: request an FP16 engine; dropped when the builder reports no fast FP16 support.
        simplify: forwarded to ``export_onnx``.
        workspace: builder workspace size, shifted left by 30 bits before being set.
        verbose: when true, lower the TensorRT logger severity to VERBOSE.
        prefix: text prepended to every log message.
    """
    try:
        check_requirements(('tensorrt',))
        import tensorrt as trt

        # Opset 13 when the version string starts with '8', opset 12 otherwise
        onnx_opset = 13 if trt.__version__[0] == '8' else 12  # test on TensorRT 7.x and 8.x
        export_onnx(model, im, file, onnx_opset, train, False, simplify)  # sixth argument is `dynamic`
        onnx_file = file.with_suffix('.onnx')
        assert onnx_file.exists(), f'failed to export ONNX file: {onnx_file}'

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        engine_file = file.with_suffix('.engine')  # TensorRT engine file

        trt_logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            trt_logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(trt_logger)
        config = builder.create_builder_config()
        # `*` binds tighter than `<<`, so this evaluates as (workspace * 1) << 30
        config.max_workspace_size = workspace * 1 << 30

        explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(explicit_batch)
        onnx_parser = trt.OnnxParser(network, trt_logger)
        if not onnx_parser.parse_from_file(str(onnx_file)):
            raise RuntimeError(f'failed to load ONNX file: {onnx_file}')

        # Collect the network bindings up front, then describe them in the log
        net_inputs = [network.get_input(idx) for idx in range(network.num_inputs)]
        net_outputs = [network.get_output(idx) for idx in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for tensor in net_inputs:
            LOGGER.info(f'{prefix}\tinput "{tensor.name}" with shape {tensor.shape} and dtype {tensor.dtype}')
        for tensor in net_outputs:
            LOGGER.info(f'{prefix}\toutput "{tensor.name}" with shape {tensor.shape} and dtype {tensor.dtype}')

        # FP16 is only kept when the caller asked for it and the builder reports fast FP16
        half &= builder.platform_has_fast_fp16
        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {engine_file}')
        if half:
            config.set_flag(trt.BuilderFlag.FP16)

        with builder.build_engine(network, config) as engine, open(engine_file, 'wb') as fh:
            fh.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {engine_file} ({file_size(engine_file):.1f} MB)')
    except Exception as e:
        # Best-effort export: report the failure and let the remaining exports run
        LOGGER.info(f'\n{prefix} export failure: {e}')
@torch.no_grad() | |||
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
weights=ROOT / 'yolov5s.pt', # weights path | |||
@@ -417,12 +417,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
export_torchscript(model, im, file, optimize) | |||
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX | |||
export_onnx(model, im, file, opset, train, dynamic, simplify) | |||
if 'openvino' in include: | |||
export_openvino(model, im, file) | |||
if 'engine' in include: | |||
export_engine(model, im, file, train, half, simplify, workspace, verbose) | |||
if 'coreml' in include: | |||
export_coreml(model, im, file) | |||
if 'openvino' in include: | |||
export_openvino(model, im, file) | |||
# TensorFlow Exports | |||
if any(tf_exports): |
@@ -316,17 +316,6 @@ class DetectMultiBackend(nn.Module): | |||
if extra_files['config.txt']: | |||
d = json.loads(extra_files['config.txt']) # extra_files dict | |||
stride, names = int(d['stride']), d['names'] | |||
elif coreml: # CoreML | |||
LOGGER.info(f'Loading {w} for CoreML inference...') | |||
import coremltools as ct | |||
model = ct.models.MLModel(w) | |||
elif xml: # OpenVINO | |||
LOGGER.info(f'Loading {w} for OpenVINO inference...') | |||
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ | |||
import openvino.inference_engine as ie | |||
core = ie.IECore() | |||
network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths | |||
executable_network = core.load_network(network, device_name='CPU', num_requests=1) | |||
elif dnn: # ONNX OpenCV DNN | |||
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') | |||
check_requirements(('opencv-python>=4.5.4',)) | |||
@@ -338,6 +327,13 @@ class DetectMultiBackend(nn.Module): | |||
import onnxruntime | |||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] | |||
session = onnxruntime.InferenceSession(w, providers=providers) | |||
elif xml: # OpenVINO | |||
LOGGER.info(f'Loading {w} for OpenVINO inference...') | |||
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ | |||
import openvino.inference_engine as ie | |||
core = ie.IECore() | |||
network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths | |||
executable_network = core.load_network(network, device_name='CPU', num_requests=1) | |||
elif engine: # TensorRT | |||
LOGGER.info(f'Loading {w} for TensorRT inference...') | |||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download | |||
@@ -356,9 +352,17 @@ class DetectMultiBackend(nn.Module): | |||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) | |||
context = model.create_execution_context() | |||
batch_size = bindings['images'].shape[0] | |||
else: # TensorFlow (TFLite, pb, saved_model) | |||
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | |||
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') | |||
elif coreml: # CoreML | |||
LOGGER.info(f'Loading {w} for CoreML inference...') | |||
import coremltools as ct | |||
model = ct.models.MLModel(w) | |||
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) | |||
if saved_model: # SavedModel | |||
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') | |||
import tensorflow as tf | |||
model = tf.keras.models.load_model(w) | |||
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | |||
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') | |||
import tensorflow as tf | |||
def wrap_frozen_graph(gd, inputs, outputs): | |||
@@ -369,19 +373,15 @@ class DetectMultiBackend(nn.Module): | |||
graph_def = tf.Graph().as_graph_def() | |||
graph_def.ParseFromString(open(w, 'rb').read()) | |||
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0") | |||
elif saved_model: | |||
LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...') | |||
import tensorflow as tf | |||
model = tf.keras.models.load_model(w) | |||
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python | |||
if 'edgetpu' in w.lower(): | |||
if 'edgetpu' in w.lower(): # Edge TPU | |||
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') | |||
import tflite_runtime.interpreter as tfli | |||
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime | |||
'Darwin': 'libedgetpu.1.dylib', | |||
'Windows': 'edgetpu.dll'}[platform.system()] | |||
interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)]) | |||
else: | |||
else: # Lite | |||
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') | |||
import tensorflow as tf | |||
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model | |||
@@ -396,21 +396,13 @@ class DetectMultiBackend(nn.Module): | |||
if self.pt or self.jit: # PyTorch | |||
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize) | |||
return y if val else y[0] | |||
elif self.coreml: # CoreML | |||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) | |||
im = Image.fromarray((im[0] * 255).astype('uint8')) | |||
# im = im.resize((192, 320), Image.ANTIALIAS) | |||
y = self.model.predict({'image': im}) # coordinates are xywh normalized | |||
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels | |||
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float) | |||
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) | |||
elif self.onnx: # ONNX | |||
elif self.dnn: # ONNX OpenCV DNN | |||
im = im.cpu().numpy() # torch to numpy | |||
if self.dnn: # ONNX OpenCV DNN | |||
self.net.setInput(im) | |||
y = self.net.forward() | |||
else: # ONNX Runtime | |||
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0] | |||
self.net.setInput(im) | |||
y = self.net.forward() | |||
elif self.onnx: # ONNX Runtime | |||
im = im.cpu().numpy() # torch to numpy | |||
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0] | |||
elif self.xml: # OpenVINO | |||
im = im.cpu().numpy() # FP32 | |||
desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW') # Tensor Description | |||
@@ -423,13 +415,21 @@ class DetectMultiBackend(nn.Module): | |||
self.binding_addrs['images'] = int(im.data_ptr()) | |||
self.context.execute_v2(list(self.binding_addrs.values())) | |||
y = self.bindings['output'].data | |||
else: # TensorFlow model (TFLite, pb, saved_model) | |||
elif self.coreml: # CoreML | |||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) | |||
if self.pb: | |||
y = self.frozen_func(x=self.tf.constant(im)).numpy() | |||
elif self.saved_model: | |||
im = Image.fromarray((im[0] * 255).astype('uint8')) | |||
# im = im.resize((192, 320), Image.ANTIALIAS) | |||
y = self.model.predict({'image': im}) # coordinates are xywh normalized | |||
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels | |||
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float) | |||
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) | |||
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) | |||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) | |||
if self.saved_model: # SavedModel | |||
y = self.model(im, training=False).numpy() | |||
elif self.tflite: | |||
elif self.pb: # GraphDef | |||
y = self.frozen_func(x=self.tf.constant(im)).numpy() | |||
elif self.tflite: # Lite | |||
input, output = self.input_details[0], self.output_details[0] | |||
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model | |||
if int8: | |||
@@ -451,7 +451,7 @@ class DetectMultiBackend(nn.Module): | |||
def warmup(self, imgsz=(1, 3, 640, 640), half=False): | |||
# Warmup model by running inference once | |||
if self.pt or self.engine or self.onnx: # warmup types | |||
if self.pt or self.jit or self.onnx or self.engine: # warmup types | |||
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models | |||
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image | |||
self.forward(im) # warmup |
@@ -9,13 +9,13 @@ Usage - formats: | |||
$ python path/to/val.py --weights yolov5s.pt # PyTorch | |||
yolov5s.torchscript # TorchScript | |||
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s.xml # OpenVINO | |||
yolov5s.engine # TensorRT | |||
yolov5s.mlmodel # CoreML (under development) | |||
yolov5s_saved_model # TensorFlow SavedModel | |||
yolov5s.pb # TensorFlow protobuf | |||
yolov5s.pb # TensorFlow GraphDef | |||
yolov5s.tflite # TensorFlow Lite | |||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU | |||
yolov5s.engine # TensorRT | |||
""" | |||
import argparse |