
Export, detection and validation with TensorRT engine file (#5699)

* Export and detect with TensorRT engine file

* Resolve `isort`

* Make validation work with TensorRT engine

* feat: update export docstring

* feat: change suffix from *.trt to *.engine

* feat: get rid of pycuda

* feat: make compatible with val.py

* feat: support detect with fp16 engine

* Add Lite to Edge TPU string

* Remove *.trt comment

* Revert to standard success logger.info string

* Fix Deprecation Warning

```
export.py:310: DeprecationWarning: Use build_serialized_network instead.
  with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
```

* Revert deprecation warning fix

@imyhxy it seems we can't apply the deprecation warning fix because then export fails, so I'm reverting my previous change here.
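
For reference, a minimal sketch of what the attempted fix looked like, assuming TensorRT 8, where `build_serialized_network()` returns an `IHostMemory` buffer that can be written to disk directly. That call does not exist on TensorRT 7.x, which may be why export failed after the change:

```python
# Sketch of the reverted fix (TensorRT 8 only): build_serialized_network()
# replaces the deprecated builder.build_engine() and returns a serialized
# IHostMemory buffer, so no intermediate ICudaEngine is needed.
serialized_engine = builder.build_serialized_network(network, config)
with open(f, 'wb') as t:
    t.write(serialized_engine)
```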

* Update export.py

* Update export.py

* Update common.py

* Export ONNX to file before building TensorRT engine file

* feat: trigger ONNX export failure early

* feat: load ONNX model from file

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
imyhxy committed 3 years ago · commit 7a39803476
4 changed files with 90 additions and 11 deletions:
  1. detect.py (+2 -2)
  2. export.py (+54 -1)
  3. models/common.py (+28 -4)
  4. val.py (+6 -4)

detect.py (+2 -2)

```diff
     # Load model
     device = select_device(device)
     model = DetectMultiBackend(weights, device=device, dnn=dnn)
-    stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
+    stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
     imgsz = check_img_size(imgsz, s=stride)  # check image size

     # Half
-    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+    half &= (pt or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
     if pt:
         model.model.half() if half else model.model.float()
```

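With this change, detection can run directly on an exported engine. A hypothetical invocation (the engine must have been built for the same image size):

```
$ python detect.py --weights yolov5s.engine --imgsz 640 --source data/images
```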


export.py (+54 -1)

```diff
     TensorFlow GraphDef         | yolov5s.pb            | 'pb'
     TensorFlow Lite             | yolov5s.tflite        | 'tflite'
     TensorFlow.js               | yolov5s_web_model/    | 'tfjs'
+    TensorRT                    | yolov5s.engine        | 'engine'

 Usage:
     $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs

         yolov5s_saved_model
         yolov5s.pb
         yolov5s.tflite
+        yolov5s.engine

 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
```
```diff
         LOGGER.info(f'\n{prefix} export failure: {e}')


+def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+    try:
+        check_requirements(('tensorrt',))
+        import tensorrt as trt
+
+        opset = (12, 13)[trt.__version__[0] == '8']  # test on TensorRT 7.x and 8.x
+        export_onnx(model, im, file, opset, train, False, simplify)
+        onnx = file.with_suffix('.onnx')
+        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
+
+        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
+        f = str(file).replace('.pt', '.engine')  # TensorRT engine file
+        logger = trt.Logger(trt.Logger.INFO)
+        if verbose:
+            logger.min_severity = trt.Logger.Severity.VERBOSE
+
+        builder = trt.Builder(logger)
+        config = builder.create_builder_config()
+        config.max_workspace_size = workspace * 1 << 30
+
+        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        network = builder.create_network(flag)
+        parser = trt.OnnxParser(network, logger)
+        if not parser.parse_from_file(str(onnx)):
+            raise RuntimeError(f'failed to load ONNX file: {onnx}')
+
+        inputs = [network.get_input(i) for i in range(network.num_inputs)]
+        outputs = [network.get_output(i) for i in range(network.num_outputs)]
+        LOGGER.info(f'{prefix} Network Description:')
+        for inp in inputs:
+            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
+        for out in outputs:
+            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
+
+        half &= builder.platform_has_fast_fp16
+        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
+        if half:
+            config.set_flag(trt.BuilderFlag.FP16)
+        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+            t.write(engine.serialize())
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
```

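Not part of the diff, but a quick way to sanity-check the file that `export_engine` writes is to deserialize it again. A minimal sketch, assuming TensorRT is installed and `f` is the `.engine` path from above:

```python
import tensorrt as trt

# Deserialize the freshly written engine to confirm it is valid
logger = trt.Logger(trt.Logger.INFO)
with open(f, 'rb') as fh, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(fh.read())
assert engine is not None, f'failed to deserialize {f}'
print(f'engine OK with {engine.num_bindings} bindings')
```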
```diff
 @torch.no_grad()
 def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         weights=ROOT / 'yolov5s.pt',  # weights path

         dynamic=False,  # ONNX/TF: dynamic axes
         simplify=False,  # ONNX: simplify model
         opset=12,  # ONNX: opset version
+        verbose=False,  # TensorRT: verbose log
+        workspace=4,  # TensorRT: workspace size (GB)
         topk_per_class=100,  # TF.js NMS: topk per class to keep
         topk_all=100,  # TF.js NMS: topk for all classes to keep
         iou_thres=0.45,  # TF.js NMS: IoU threshold

         export_torchscript(model, im, file, optimize)
     if 'onnx' in include:
         export_onnx(model, im, file, opset, train, dynamic, simplify)
+    if 'engine' in include:
+        export_engine(model, im, file, train, half, simplify, workspace, verbose)
     if 'coreml' in include:
         export_coreml(model, im, file)

     parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
     parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
+    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
+    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
     parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
     parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
     parser.add_argument('--include', nargs='+',
                         default=['torchscript', 'onnx'],
-                        help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
+                        help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
     opt = parser.parse_args()
     print_args(FILE.stem, opt)
     return opt
```

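Putting the new flags together, a TensorRT export might be invoked as below. A hypothetical example: `--device 0` assumes the script's existing CUDA device flag, since TensorRT export requires a GPU:

```
$ python export.py --weights yolov5s.pt --include engine --workspace 4 --verbose --device 0
```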
models/common.py (+28 -4)

```diff
 import math
 import platform
 import warnings
+from collections import namedtuple
 from copy import copy
 from pathlib import Path
```


```diff
         #   TensorFlow Lite: *.tflite
         #   ONNX Runtime: *.onnx
         #   OpenCV DNN: *.onnx with dnn=True
+        #   TensorRT: *.engine
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
+        suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
         check_suffix(w, suffixes)  # check weights have acceptable suffix
-        pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
+        pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
         jit = pt and 'torchscript' in w.lower()
         stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
```


```diff
             check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
             import onnxruntime
             session = onnxruntime.InferenceSession(w, None)
+        elif engine:  # TensorRT
+            LOGGER.info(f'Loading {w} for TensorRT inference...')
+            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
+            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            logger = trt.Logger(trt.Logger.INFO)
+            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
+                model = runtime.deserialize_cuda_engine(f.read())
+            bindings = dict()
+            for index in range(model.num_bindings):
+                name = model.get_binding_name(index)
+                dtype = trt.nptype(model.get_binding_dtype(index))
+                shape = tuple(model.get_binding_shape(index))
+                data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
+                bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
+            binding_addrs = {n: d.ptr for n, d in bindings.items()}
+            context = model.create_execution_context()
+            batch_size = bindings['images'].shape[0]
         else:  # TensorFlow model (TFLite, pb, saved_model)
             import tensorflow as tf
             if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt

                 model = tf.keras.models.load_model(w)
             elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
                 if 'edgetpu' in w.lower():
-                    LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                     import tflite_runtime.interpreter as tfli
                     delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
                                 'Darwin': 'libedgetpu.1.dylib',
```
```diff
                 y = self.net.forward()
             else:  # ONNX Runtime
                 y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+        elif self.engine:  # TensorRT
+            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
+            self.binding_addrs['images'] = int(im.data_ptr())
+            self.context.execute_v2(list(self.binding_addrs.values()))
+            y = self.bindings['output'].data
         else:  # TensorFlow model (TFLite, pb, saved_model)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.pb:

             y[..., 1] *= h  # y
             y[..., 2] *= w  # w
             y[..., 3] *= h  # h
-        y = torch.tensor(y)
+        y = torch.tensor(y) if isinstance(y, np.ndarray) else y
         return (y, []) if val else y
```

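The binding scheme above is what lets the backend drop pycuda: each engine binding gets a pre-allocated CUDA `torch` tensor, and the tensors' `data_ptr()` addresses are handed to `execute_v2`, so outputs land directly in GPU-resident torch tensors. A self-contained sketch of the same pattern, assuming a fixed-shape engine at `yolov5s.engine` with bindings named `images` and `output` (as produced by the YOLOv5 ONNX export):

```python
from collections import namedtuple

import numpy as np
import tensorrt as trt
import torch

Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
device = torch.device('cuda:0')

# Deserialize the engine and allocate one CUDA tensor per binding
logger = trt.Logger(trt.Logger.INFO)
with open('yolov5s.engine', 'rb') as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
bindings = {}
for i in range(engine.num_bindings):
    name = engine.get_binding_name(i)
    dtype = trt.nptype(engine.get_binding_dtype(i))
    shape = tuple(engine.get_binding_shape(i))
    data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
    bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
context = engine.create_execution_context()

# Inference: point the input binding at our tensor and run synchronously.
# execute_v2 expects device addresses in binding-index order, which the
# insertion order of `bindings` preserves.
im = torch.zeros(bindings['images'].shape, dtype=bindings['images'].data.dtype, device=device)  # dummy input
addrs = {n: b.ptr for n, b in bindings.items()}
addrs['images'] = int(im.data_ptr())
context.execute_v2(list(addrs.values()))
y = bindings['output'].data  # output is already a torch tensor on the GPU
print(y.shape)
```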




val.py (+6 -4)

```diff
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
-        device, pt = next(model.parameters()).device, True  # get model device, PyTorch model
+        device, pt, engine = next(model.parameters()).device, True, False  # get model device, PyTorch model

         half &= device.type != 'cpu'  # half precision only supported on CUDA
         model.half() if half else model.float()

         # Load model
         model = DetectMultiBackend(weights, device=device, dnn=dnn)
-        stride, pt = model.stride, model.pt
+        stride, pt, engine = model.stride, model.pt, model.engine
         imgsz = check_img_size(imgsz, s=stride)  # check image size
-        half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+        half &= (pt or engine) and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
         if pt:
             model.model.half() if half else model.model.float()
+        elif engine:
+            batch_size = model.batch_size
         else:
             half = False
             batch_size = 1  # export.py models default to batch-size 1

     pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
     for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
         t1 = time_sync()
-        if pt:
+        if pt or engine:
             im = im.to(device, non_blocking=True)
             targets = targets.to(device)
             im = im.half() if half else im.float()  # uint8 to fp16/32
```

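Validation then accepts the engine directly; the fixed batch size baked into the engine is read back via `model.batch_size`, so the dataloader is built to match. A hypothetical invocation:

```
$ python val.py --weights yolov5s.engine --data coco128.yaml --imgsz 640 --device 0
```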