Export, detect and validation with TensorRT engine file (#5699)
* Export and detect with TensorRT engine file * Resolve `isort` * Make validation works with TensorRT engine * feat: update export docstring * feat: change suffix from *.trt to *.engine * feat: get rid of pycuda * feat: make compatiable with val.py * feat: support detect with fp16 engine * Add Lite to Edge TPU string * Remove *.trt comment * Revert to standard success logger.info string * Fix Deprecation Warning ``` export.py:310: DeprecationWarning: Use build_serialized_network instead. with builder.build_engine(network, config) as engine, open(f, 'wb') as t: ``` * Revert deprecation warning fix @imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here. * Update export.py * Update export.py * Update common.py * export onnx to file before building TensorRT engine file * feat: triger ONNX export failed early * feat: load ONNX model from file Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
f17c86b7f0
commit
7a39803476
|
|
@ -77,11 +77,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
||||||
# Load model
|
# Load model
|
||||||
device = select_device(device)
|
device = select_device(device)
|
||||||
model = DetectMultiBackend(weights, device=device, dnn=dnn)
|
model = DetectMultiBackend(weights, device=device, dnn=dnn)
|
||||||
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
|
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
|
||||||
imgsz = check_img_size(imgsz, s=stride) # check image size
|
imgsz = check_img_size(imgsz, s=stride) # check image size
|
||||||
|
|
||||||
# Half
|
# Half
|
||||||
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
|
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
|
||||||
if pt:
|
if pt:
|
||||||
model.model.half() if half else model.model.float()
|
model.model.half() if half else model.model.float()
|
||||||
|
|
||||||
|
|
|
||||||
55
export.py
55
export.py
|
|
@ -12,6 +12,7 @@ TensorFlow SavedModel | yolov5s_saved_model/ | 'saved_model'
|
||||||
TensorFlow GraphDef | yolov5s.pb | 'pb'
|
TensorFlow GraphDef | yolov5s.pb | 'pb'
|
||||||
TensorFlow Lite | yolov5s.tflite | 'tflite'
|
TensorFlow Lite | yolov5s.tflite | 'tflite'
|
||||||
TensorFlow.js | yolov5s_web_model/ | 'tfjs'
|
TensorFlow.js | yolov5s_web_model/ | 'tfjs'
|
||||||
|
TensorRT | yolov5s.engine | 'engine'
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
|
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
|
||||||
|
|
@ -24,6 +25,7 @@ Inference:
|
||||||
yolov5s_saved_model
|
yolov5s_saved_model
|
||||||
yolov5s.pb
|
yolov5s.pb
|
||||||
yolov5s.tflite
|
yolov5s.tflite
|
||||||
|
yolov5s.engine
|
||||||
|
|
||||||
TensorFlow.js:
|
TensorFlow.js:
|
||||||
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
||||||
|
|
@ -263,6 +265,51 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
|
||||||
LOGGER.info(f'\n{prefix} export failure: {e}')
|
LOGGER.info(f'\n{prefix} export failure: {e}')
|
||||||
|
|
||||||
|
|
||||||
|
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
||||||
|
try:
|
||||||
|
check_requirements(('tensorrt',))
|
||||||
|
import tensorrt as trt
|
||||||
|
|
||||||
|
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
|
||||||
|
export_onnx(model, im, file, opset, train, False, simplify)
|
||||||
|
onnx = file.with_suffix('.onnx')
|
||||||
|
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
||||||
|
|
||||||
|
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
||||||
|
f = str(file).replace('.pt', '.engine') # TensorRT engine file
|
||||||
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
|
if verbose:
|
||||||
|
logger.min_severity = trt.Logger.Severity.VERBOSE
|
||||||
|
|
||||||
|
builder = trt.Builder(logger)
|
||||||
|
config = builder.create_builder_config()
|
||||||
|
config.max_workspace_size = workspace * 1 << 30
|
||||||
|
|
||||||
|
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||||
|
network = builder.create_network(flag)
|
||||||
|
parser = trt.OnnxParser(network, logger)
|
||||||
|
if not parser.parse_from_file(str(onnx)):
|
||||||
|
raise RuntimeError(f'failed to load ONNX file: {onnx}')
|
||||||
|
|
||||||
|
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
||||||
|
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
||||||
|
LOGGER.info(f'{prefix} Network Description:')
|
||||||
|
for inp in inputs:
|
||||||
|
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
|
||||||
|
for out in outputs:
|
||||||
|
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
|
||||||
|
|
||||||
|
half &= builder.platform_has_fast_fp16
|
||||||
|
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
|
||||||
|
if half:
|
||||||
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
|
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
||||||
|
t.write(engine.serialize())
|
||||||
|
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.info(f'\n{prefix} export failure: {e}')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
weights=ROOT / 'yolov5s.pt', # weights path
|
weights=ROOT / 'yolov5s.pt', # weights path
|
||||||
|
|
@ -278,6 +325,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
dynamic=False, # ONNX/TF: dynamic axes
|
dynamic=False, # ONNX/TF: dynamic axes
|
||||||
simplify=False, # ONNX: simplify model
|
simplify=False, # ONNX: simplify model
|
||||||
opset=12, # ONNX: opset version
|
opset=12, # ONNX: opset version
|
||||||
|
verbose=False, # TensorRT: verbose log
|
||||||
|
workspace=4, # TensorRT: workspace size (GB)
|
||||||
topk_per_class=100, # TF.js NMS: topk per class to keep
|
topk_per_class=100, # TF.js NMS: topk per class to keep
|
||||||
topk_all=100, # TF.js NMS: topk for all classes to keep
|
topk_all=100, # TF.js NMS: topk for all classes to keep
|
||||||
iou_thres=0.45, # TF.js NMS: IoU threshold
|
iou_thres=0.45, # TF.js NMS: IoU threshold
|
||||||
|
|
@ -322,6 +371,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
||||||
export_torchscript(model, im, file, optimize)
|
export_torchscript(model, im, file, optimize)
|
||||||
if 'onnx' in include:
|
if 'onnx' in include:
|
||||||
export_onnx(model, im, file, opset, train, dynamic, simplify)
|
export_onnx(model, im, file, opset, train, dynamic, simplify)
|
||||||
|
if 'engine' in include:
|
||||||
|
export_engine(model, im, file, train, half, simplify, workspace, verbose)
|
||||||
if 'coreml' in include:
|
if 'coreml' in include:
|
||||||
export_coreml(model, im, file)
|
export_coreml(model, im, file)
|
||||||
|
|
||||||
|
|
@ -360,13 +411,15 @@ def parse_opt():
|
||||||
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
|
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
|
||||||
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
|
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
|
||||||
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
|
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
|
||||||
|
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
|
||||||
|
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
|
||||||
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
|
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
|
||||||
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
|
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
|
||||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
|
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
|
||||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
||||||
parser.add_argument('--include', nargs='+',
|
parser.add_argument('--include', nargs='+',
|
||||||
default=['torchscript', 'onnx'],
|
default=['torchscript', 'onnx'],
|
||||||
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
|
help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print_args(FILE.stem, opt)
|
print_args(FILE.stem, opt)
|
||||||
return opt
|
return opt
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import json
|
||||||
import math
|
import math
|
||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import namedtuple
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -285,11 +286,12 @@ class DetectMultiBackend(nn.Module):
|
||||||
# TensorFlow Lite: *.tflite
|
# TensorFlow Lite: *.tflite
|
||||||
# ONNX Runtime: *.onnx
|
# ONNX Runtime: *.onnx
|
||||||
# OpenCV DNN: *.onnx with dnn=True
|
# OpenCV DNN: *.onnx with dnn=True
|
||||||
|
# TensorRT: *.engine
|
||||||
super().__init__()
|
super().__init__()
|
||||||
w = str(weights[0] if isinstance(weights, list) else weights)
|
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||||
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
|
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
|
||||||
check_suffix(w, suffixes) # check weights have acceptable suffix
|
check_suffix(w, suffixes) # check weights have acceptable suffix
|
||||||
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
|
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
|
||||||
jit = pt and 'torchscript' in w.lower()
|
jit = pt and 'torchscript' in w.lower()
|
||||||
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
|
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
|
||||||
|
|
||||||
|
|
@ -317,6 +319,23 @@ class DetectMultiBackend(nn.Module):
|
||||||
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
|
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
session = onnxruntime.InferenceSession(w, None)
|
session = onnxruntime.InferenceSession(w, None)
|
||||||
|
elif engine: # TensorRT
|
||||||
|
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||||
|
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
|
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
||||||
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
|
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
||||||
|
model = runtime.deserialize_cuda_engine(f.read())
|
||||||
|
bindings = dict()
|
||||||
|
for index in range(model.num_bindings):
|
||||||
|
name = model.get_binding_name(index)
|
||||||
|
dtype = trt.nptype(model.get_binding_dtype(index))
|
||||||
|
shape = tuple(model.get_binding_shape(index))
|
||||||
|
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
|
||||||
|
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
|
||||||
|
binding_addrs = {n: d.ptr for n, d in bindings.items()}
|
||||||
|
context = model.create_execution_context()
|
||||||
|
batch_size = bindings['images'].shape[0]
|
||||||
else: # TensorFlow model (TFLite, pb, saved_model)
|
else: # TensorFlow model (TFLite, pb, saved_model)
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
||||||
|
|
@ -334,7 +353,7 @@ class DetectMultiBackend(nn.Module):
|
||||||
model = tf.keras.models.load_model(w)
|
model = tf.keras.models.load_model(w)
|
||||||
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||||
if 'edgetpu' in w.lower():
|
if 'edgetpu' in w.lower():
|
||||||
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
|
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
||||||
import tflite_runtime.interpreter as tfli
|
import tflite_runtime.interpreter as tfli
|
||||||
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
|
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
|
||||||
'Darwin': 'libedgetpu.1.dylib',
|
'Darwin': 'libedgetpu.1.dylib',
|
||||||
|
|
@ -369,6 +388,11 @@ class DetectMultiBackend(nn.Module):
|
||||||
y = self.net.forward()
|
y = self.net.forward()
|
||||||
else: # ONNX Runtime
|
else: # ONNX Runtime
|
||||||
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
|
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
|
||||||
|
elif self.engine: # TensorRT
|
||||||
|
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
|
||||||
|
self.binding_addrs['images'] = int(im.data_ptr())
|
||||||
|
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||||
|
y = self.bindings['output'].data
|
||||||
else: # TensorFlow model (TFLite, pb, saved_model)
|
else: # TensorFlow model (TFLite, pb, saved_model)
|
||||||
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
|
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
|
||||||
if self.pb:
|
if self.pb:
|
||||||
|
|
@ -391,7 +415,7 @@ class DetectMultiBackend(nn.Module):
|
||||||
y[..., 1] *= h # y
|
y[..., 1] *= h # y
|
||||||
y[..., 2] *= w # w
|
y[..., 2] *= w # w
|
||||||
y[..., 3] *= h # h
|
y[..., 3] *= h # h
|
||||||
y = torch.tensor(y)
|
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
|
||||||
return (y, []) if val else y
|
return (y, []) if val else y
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
10
val.py
10
val.py
|
|
@ -111,7 +111,7 @@ def run(data,
|
||||||
# Initialize/load model and set device
|
# Initialize/load model and set device
|
||||||
training = model is not None
|
training = model is not None
|
||||||
if training: # called by train.py
|
if training: # called by train.py
|
||||||
device, pt = next(model.parameters()).device, True # get model device, PyTorch model
|
device, pt, engine = next(model.parameters()).device, True, False # get model device, PyTorch model
|
||||||
|
|
||||||
half &= device.type != 'cpu' # half precision only supported on CUDA
|
half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||||
model.half() if half else model.float()
|
model.half() if half else model.float()
|
||||||
|
|
@ -124,11 +124,13 @@ def run(data,
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model = DetectMultiBackend(weights, device=device, dnn=dnn)
|
model = DetectMultiBackend(weights, device=device, dnn=dnn)
|
||||||
stride, pt = model.stride, model.pt
|
stride, pt, engine = model.stride, model.pt, model.engine
|
||||||
imgsz = check_img_size(imgsz, s=stride) # check image size
|
imgsz = check_img_size(imgsz, s=stride) # check image size
|
||||||
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
|
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
|
||||||
if pt:
|
if pt:
|
||||||
model.model.half() if half else model.model.float()
|
model.model.half() if half else model.model.float()
|
||||||
|
elif engine:
|
||||||
|
batch_size = model.batch_size
|
||||||
else:
|
else:
|
||||||
half = False
|
half = False
|
||||||
batch_size = 1 # export.py models default to batch-size 1
|
batch_size = 1 # export.py models default to batch-size 1
|
||||||
|
|
@ -165,7 +167,7 @@ def run(data,
|
||||||
pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
|
pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
|
||||||
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
|
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
|
||||||
t1 = time_sync()
|
t1 = time_sync()
|
||||||
if pt:
|
if pt or engine:
|
||||||
im = im.to(device, non_blocking=True)
|
im = im.to(device, non_blocking=True)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
im = im.half() if half else im.float() # uint8 to fp16/32
|
im = im.half() if half else im.float() # uint8 to fp16/32
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue