* Export and detect with TensorRT engine file
* Resolve `isort`
* Make validation work with TensorRT engine
* feat: update export docstring
* feat: change suffix from *.trt to *.engine
* feat: get rid of pycuda
* feat: make compatible with val.py
* feat: support detect with fp16 engine
* Add Lite to Edge TPU string
* Remove *.trt comment
* Revert to standard success logger.info string
* Fix Deprecation Warning
  ```
  export.py:310: DeprecationWarning: Use build_serialized_network instead.
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
  ```
* Revert deprecation warning fix
  @imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here.
* Update export.py
* Update export.py
* Update common.py
* export onnx to file before building TensorRT engine file
* feat: trigger ONNX export failure early
* feat: load ONNX model from file

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

modify Dataloader
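The commits above add an end-to-end TensorRT path: export.py writes a `.engine` file and DetectMultiBackend loads it for detect.py and val.py. Below is a minimal sketch of triggering that export programmatically, assuming a CUDA machine with TensorRT installed; keyword names not visible in the hunks below (`device`) are assumptions.

```python
# Minimal workflow sketch (assumes CUDA + TensorRT installed; yolov5s.pt is the stock checkpoint).
from export import run as export_run  # export.py at the repo root

# Writes yolov5s.onnx first, then parses it into an FP16 yolov5s.engine (see export_engine below).
export_run(weights='yolov5s.pt',  # weights path, as in the run() signature below
           include=('engine',),   # 'engine' added to --include by this change
           half=True,             # FP16 engine if the platform supports fast fp16
           device='0')            # assumption: TensorRT export needs a CUDA device
```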
# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride)  # check image size

# Half
half &= (pt or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
if pt:
    model.model.half() if half else model.model.float()
TensorFlow GraphDef | yolov5s.pb          | 'pb'
TensorFlow Lite     | yolov5s.tflite      | 'tflite'
TensorFlow.js       | yolov5s_web_model/  | 'tfjs'
TensorRT            | yolov5s.engine      | 'engine'
Usage:
    $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
        yolov5s_saved_model
        yolov5s.pb
        yolov5s.tflite
        yolov5s.engine

TensorFlow.js:
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
LOGGER.info(f'\n{prefix} export failure: {e}')
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    try:
        check_requirements(('tensorrt',))
        import tensorrt as trt

        opset = (12, 13)[trt.__version__[0] == '8']  # test on TensorRT 7.x and 8.x
        export_onnx(model, im, file, opset, train, False, simplify)
        onnx = file.with_suffix('.onnx')
        assert onnx.exists(), f'failed to export ONNX file: {onnx}'

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        f = str(file).replace('.pt', '.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = workspace * 1 << 30

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx)):
            raise RuntimeError(f'failed to load ONNX file: {onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info(f'{prefix} Network Description:')
        for inp in inputs:
            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        half &= builder.platform_has_fast_fp16
        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
        if half:
            config.set_flag(trt.BuilderFlag.FP16)
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            t.write(engine.serialize())
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        LOGGER.info(f'\n{prefix} export failure: {e}')
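For illustration, export_engine can also be driven outside of run(); the sketch below assumes attempt_load from models/experimental.py and a 640x640 dummy input, neither of which appears in the hunks here.

```python
# Hypothetical standalone call (normally export.py's run() handles this); attempt_load and the
# dummy input shape are assumptions for this sketch.
from pathlib import Path
import torch
from models.experimental import attempt_load

device = torch.device('cuda:0')                  # TensorRT export requires a CUDA device
model = attempt_load('yolov5s.pt', map_location=device)
im = torch.zeros(1, 3, 640, 640, device=device)  # dummy BCHW input traced into yolov5s.onnx
export_engine(model, im, Path('yolov5s.pt'), train=False, half=True, simplify=False,
              workspace=4, verbose=False)        # writes yolov5s.onnx, then yolov5s.engine
```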
@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
        weights=ROOT / 'yolov5s.pt',  # weights path
        dynamic=False,  # ONNX/TF: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
        topk_per_class=100,  # TF.js NMS: topk per class to keep
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold

    export_torchscript(model, im, file, optimize)
if 'onnx' in include:
    export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'engine' in include:
    export_engine(model, im, file, train, half, simplify, workspace, verbose)
if 'coreml' in include:
    export_coreml(model, im, file)
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', nargs='+',
                    default=['torchscript', 'onnx'],
                    help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt
import math
import platform
import warnings
from collections import namedtuple
from copy import copy
from pathlib import Path

#   TensorFlow Lite: *.tflite
#   ONNX Runtime: *.onnx
#   OpenCV DNN: *.onnx with dnn=True
#   TensorRT: *.engine
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes)  # check weights have acceptable suffix
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults

check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
elif engine:  # TensorRT
    LOGGER.info(f'Loading {w} for TensorRT inference...')
    import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
    Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
    logger = trt.Logger(trt.Logger.INFO)
    with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
        model = runtime.deserialize_cuda_engine(f.read())
    bindings = dict()
    for index in range(model.num_bindings):
        name = model.get_binding_name(index)
        dtype = trt.nptype(model.get_binding_dtype(index))
        shape = tuple(model.get_binding_shape(index))
        data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
        bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
    binding_addrs = {n: d.ptr for n, d in bindings.items()}
    context = model.create_execution_context()
    batch_size = bindings['images'].shape[0]
else:  # TensorFlow model (TFLite, pb, saved_model)
    import tensorflow as tf
    if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt

    model = tf.keras.models.load_model(w)
    elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
        if 'edgetpu' in w.lower():
            LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
            import tflite_runtime.interpreter as tfli
            delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
                        'Darwin': 'libedgetpu.1.dylib',
        y = self.net.forward()
    else:  # ONNX Runtime
        y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
elif self.engine:  # TensorRT
    assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
    self.binding_addrs['images'] = int(im.data_ptr())
    self.context.execute_v2(list(self.binding_addrs.values()))
    y = self.bindings['output'].data
else:  # TensorFlow model (TFLite, pb, saved_model)
    im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
    if self.pb:

    y[..., 1] *= h  # y
    y[..., 2] *= w  # w
    y[..., 3] *= h  # h
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y
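The `elif self.engine:` branch above copies the input's device pointer into `binding_addrs`, runs `execute_v2`, and reads the result from the pre-allocated `output` binding. Below is a minimal sketch of exercising that path, assuming a 640x640 FP16 engine and `non_max_suppression` from utils/general.py (not shown in these hunks).

```python
# Sketch: inference through the TensorRT branch of DetectMultiBackend.forward().
# Assumptions: the engine was built for a 1x3x640x640 FP16 input; non_max_suppression is in utils.general.
import torch
from models.common import DetectMultiBackend
from utils.general import non_max_suppression

device = torch.device('cuda:0')
model = DetectMultiBackend('yolov5s.engine', device=device)
im = torch.zeros(1, 3, 640, 640, device=device).half()  # dtype/shape must match the 'images' binding
y = model(im)                                            # raw predictions read from the 'output' binding
det = non_max_suppression(y, conf_thres=0.25, iou_thres=0.45)[0]  # detections for image 0
```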
# Initialize/load model and set device
training = model is not None
if training:  # called by train.py
    device, pt, engine = next(model.parameters()).device, True, False  # get model device, PyTorch model
    half &= device.type != 'cpu'  # half precision only supported on CUDA
    model.half() if half else model.float()

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, pt, engine = model.stride, model.pt, model.engine
imgsz = check_img_size(imgsz, s=stride)  # check image size
half &= (pt or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
if pt:
    model.model.half() if half else model.model.float()
elif engine:
    batch_size = model.batch_size
else:
    half = False
    batch_size = 1  # export.py models default to batch-size 1

pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
    t1 = time_sync()
    if pt or engine:
        im = im.to(device, non_blocking=True)
        targets = targets.to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
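With these val.py changes, an exported engine can be evaluated directly, and the batch size is read from the engine's `images` binding rather than from `--batch-size`. A hedged sketch of the programmatic call; the keyword names mirror val.py's CLI flags, and treating them as run() parameters is an assumption here.

```python
# Hedged sketch: evaluating a TensorRT engine with val.py.
import val

results = val.run(data='data/coco128.yaml',
                  weights='yolov5s.engine',  # routed to DetectMultiBackend, engine=True
                  imgsz=640,
                  half=True,                 # FP16 allowed on the engine path per the change above
                  device='0')
# batch_size comes from model.batch_size (the engine's 'images' binding), not from --batch-size
```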