* Export and detect with TensorRT engine file * Resolve `isort` * Make validation works with TensorRT engine * feat: update export docstring * feat: change suffix from *.trt to *.engine * feat: get rid of pycuda * feat: make compatiable with val.py * feat: support detect with fp16 engine * Add Lite to Edge TPU string * Remove *.trt comment * Revert to standard success logger.info string * Fix Deprecation Warning ``` export.py:310: DeprecationWarning: Use build_serialized_network instead. with builder.build_engine(network, config) as engine, open(f, 'wb') as t: ``` * Revert deprecation warning fix @imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here. * Update export.py * Update export.py * Update common.py * export onnx to file before building TensorRT engine file * feat: triger ONNX export failed early * feat: load ONNX model from file Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>modifyDataloader
@@ -77,11 +77,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) | |||
# Load model | |||
device = select_device(device) | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn) | |||
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx | |||
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine | |||
imgsz = check_img_size(imgsz, s=stride) # check image size | |||
# Half | |||
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA | |||
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA | |||
if pt: | |||
model.model.half() if half else model.model.float() | |||
@@ -12,6 +12,7 @@ TensorFlow SavedModel | yolov5s_saved_model/ | 'saved_model' | |||
TensorFlow GraphDef | yolov5s.pb | 'pb' | |||
TensorFlow Lite | yolov5s.tflite | 'tflite' | |||
TensorFlow.js | yolov5s_web_model/ | 'tfjs' | |||
TensorRT | yolov5s.engine | 'engine' | |||
Usage: | |||
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs | |||
@@ -24,6 +25,7 @@ Inference: | |||
yolov5s_saved_model | |||
yolov5s.pb | |||
yolov5s.tflite | |||
yolov5s.engine | |||
TensorFlow.js: | |||
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example | |||
@@ -263,6 +265,51 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): | |||
try: | |||
check_requirements(('tensorrt',)) | |||
import tensorrt as trt | |||
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x | |||
export_onnx(model, im, file, opset, train, False, simplify) | |||
onnx = file.with_suffix('.onnx') | |||
assert onnx.exists(), f'failed to export ONNX file: {onnx}' | |||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') | |||
f = str(file).replace('.pt', '.engine') # TensorRT engine file | |||
logger = trt.Logger(trt.Logger.INFO) | |||
if verbose: | |||
logger.min_severity = trt.Logger.Severity.VERBOSE | |||
builder = trt.Builder(logger) | |||
config = builder.create_builder_config() | |||
config.max_workspace_size = workspace * 1 << 30 | |||
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) | |||
network = builder.create_network(flag) | |||
parser = trt.OnnxParser(network, logger) | |||
if not parser.parse_from_file(str(onnx)): | |||
raise RuntimeError(f'failed to load ONNX file: {onnx}') | |||
inputs = [network.get_input(i) for i in range(network.num_inputs)] | |||
outputs = [network.get_output(i) for i in range(network.num_outputs)] | |||
LOGGER.info(f'{prefix} Network Description:') | |||
for inp in inputs: | |||
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') | |||
for out in outputs: | |||
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') | |||
half &= builder.platform_has_fast_fp16 | |||
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}') | |||
if half: | |||
config.set_flag(trt.BuilderFlag.FP16) | |||
with builder.build_engine(network, config) as engine, open(f, 'wb') as t: | |||
t.write(engine.serialize()) | |||
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') | |||
except Exception as e: | |||
LOGGER.info(f'\n{prefix} export failure: {e}') | |||
@torch.no_grad() | |||
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
weights=ROOT / 'yolov5s.pt', # weights path | |||
@@ -278,6 +325,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
dynamic=False, # ONNX/TF: dynamic axes | |||
simplify=False, # ONNX: simplify model | |||
opset=12, # ONNX: opset version | |||
verbose=False, # TensorRT: verbose log | |||
workspace=4, # TensorRT: workspace size (GB) | |||
topk_per_class=100, # TF.js NMS: topk per class to keep | |||
topk_all=100, # TF.js NMS: topk for all classes to keep | |||
iou_thres=0.45, # TF.js NMS: IoU threshold | |||
@@ -322,6 +371,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' | |||
export_torchscript(model, im, file, optimize) | |||
if 'onnx' in include: | |||
export_onnx(model, im, file, opset, train, dynamic, simplify) | |||
if 'engine' in include: | |||
export_engine(model, im, file, train, half, simplify, workspace, verbose) | |||
if 'coreml' in include: | |||
export_coreml(model, im, file) | |||
@@ -360,13 +411,15 @@ def parse_opt(): | |||
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') | |||
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') | |||
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') | |||
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') | |||
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') | |||
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') | |||
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') | |||
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') | |||
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') | |||
parser.add_argument('--include', nargs='+', | |||
default=['torchscript', 'onnx'], | |||
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)') | |||
help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)') | |||
opt = parser.parse_args() | |||
print_args(FILE.stem, opt) | |||
return opt |
@@ -7,6 +7,7 @@ import json | |||
import math | |||
import platform | |||
import warnings | |||
from collections import namedtuple | |||
from copy import copy | |||
from pathlib import Path | |||
@@ -285,11 +286,12 @@ class DetectMultiBackend(nn.Module): | |||
# TensorFlow Lite: *.tflite | |||
# ONNX Runtime: *.onnx | |||
# OpenCV DNN: *.onnx with dnn=True | |||
# TensorRT: *.engine | |||
super().__init__() | |||
w = str(weights[0] if isinstance(weights, list) else weights) | |||
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel'] | |||
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel'] | |||
check_suffix(w, suffixes) # check weights have acceptable suffix | |||
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans | |||
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans | |||
jit = pt and 'torchscript' in w.lower() | |||
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults | |||
@@ -317,6 +319,23 @@ class DetectMultiBackend(nn.Module): | |||
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime')) | |||
import onnxruntime | |||
session = onnxruntime.InferenceSession(w, None) | |||
elif engine: # TensorRT | |||
LOGGER.info(f'Loading {w} for TensorRT inference...') | |||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download | |||
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) | |||
logger = trt.Logger(trt.Logger.INFO) | |||
with open(w, 'rb') as f, trt.Runtime(logger) as runtime: | |||
model = runtime.deserialize_cuda_engine(f.read()) | |||
bindings = dict() | |||
for index in range(model.num_bindings): | |||
name = model.get_binding_name(index) | |||
dtype = trt.nptype(model.get_binding_dtype(index)) | |||
shape = tuple(model.get_binding_shape(index)) | |||
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) | |||
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) | |||
binding_addrs = {n: d.ptr for n, d in bindings.items()} | |||
context = model.create_execution_context() | |||
batch_size = bindings['images'].shape[0] | |||
else: # TensorFlow model (TFLite, pb, saved_model) | |||
import tensorflow as tf | |||
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt | |||
@@ -334,7 +353,7 @@ class DetectMultiBackend(nn.Module): | |||
model = tf.keras.models.load_model(w) | |||
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python | |||
if 'edgetpu' in w.lower(): | |||
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...') | |||
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') | |||
import tflite_runtime.interpreter as tfli | |||
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime | |||
'Darwin': 'libedgetpu.1.dylib', | |||
@@ -369,6 +388,11 @@ class DetectMultiBackend(nn.Module): | |||
y = self.net.forward() | |||
else: # ONNX Runtime | |||
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0] | |||
elif self.engine: # TensorRT | |||
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape) | |||
self.binding_addrs['images'] = int(im.data_ptr()) | |||
self.context.execute_v2(list(self.binding_addrs.values())) | |||
y = self.bindings['output'].data | |||
else: # TensorFlow model (TFLite, pb, saved_model) | |||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) | |||
if self.pb: | |||
@@ -391,7 +415,7 @@ class DetectMultiBackend(nn.Module): | |||
y[..., 1] *= h # y | |||
y[..., 2] *= w # w | |||
y[..., 3] *= h # h | |||
y = torch.tensor(y) | |||
y = torch.tensor(y) if isinstance(y, np.ndarray) else y | |||
return (y, []) if val else y | |||
@@ -111,7 +111,7 @@ def run(data, | |||
# Initialize/load model and set device | |||
training = model is not None | |||
if training: # called by train.py | |||
device, pt = next(model.parameters()).device, True # get model device, PyTorch model | |||
device, pt, engine = next(model.parameters()).device, True, False # get model device, PyTorch model | |||
half &= device.type != 'cpu' # half precision only supported on CUDA | |||
model.half() if half else model.float() | |||
@@ -124,11 +124,13 @@ def run(data, | |||
# Load model | |||
model = DetectMultiBackend(weights, device=device, dnn=dnn) | |||
stride, pt = model.stride, model.pt | |||
stride, pt, engine = model.stride, model.pt, model.engine | |||
imgsz = check_img_size(imgsz, s=stride) # check image size | |||
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA | |||
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA | |||
if pt: | |||
model.model.half() if half else model.model.float() | |||
elif engine: | |||
batch_size = model.batch_size | |||
else: | |||
half = False | |||
batch_size = 1 # export.py models default to batch-size 1 | |||
@@ -165,7 +167,7 @@ def run(data, | |||
pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar | |||
for batch_i, (im, targets, paths, shapes) in enumerate(pbar): | |||
t1 = time_sync() | |||
if pt: | |||
if pt or engine: | |||
im = im.to(device, non_blocking=True) | |||
targets = targets.to(device) | |||
im = im.half() if half else im.float() # uint8 to fp16/32 |