* Assert engine precision #6777
* Default to FP32 inputs for TensorRT engines
* Default to FP16 TensorRT exports #6777
* Remove wrong line #6777
* Automatically adjust detect.py input precision #6777
* Automatically adjust val.py input precision #6777
* Add missing colon
* Cleanup
* Cleanup
* Remove default trt_fp16_input definition
* Experiment
* Reorder detect.py if statement to after half checks
* Update common.py
* Update export.py
* Cleanup

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
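Summary: export_engine() now builds FP16 engines whenever the GPU reports fast FP16 support, and detect.py/val.py reconcile their --half flag with the engine's actual input precision instead of failing on a dtype mismatch. A rough end-to-end sketch of the new behavior (the .engine path is illustrative, and trt_fp16_input is the attribute this PR adds to DetectMultiBackend):

    # Sketch, not verbatim detect.py: load an exported engine and match input dtype to it
    import torch
    from models.common import DetectMultiBackend

    model = DetectMultiBackend('yolov5s.engine', device=torch.device('cuda:0'))
    half = model.trt_fp16_input                 # the engine's input dtype wins over the --half flag
    im = torch.zeros(1, 3, 640, 640, device='cuda:0')
    im = im.half() if half else im.float()      # input dtype must match the engine's input binding
    pred = model(im)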
@@ -97,6 +97,10 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
     half &= (pt or jit or onnx or engine) and device.type != 'cpu'  # FP16 supported on limited backends with CUDA
     if pt or jit:
         model.model.half() if half else model.model.float()
+    elif engine and model.trt_fp16_input != half:
+        LOGGER.info('model ' + (
+            'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
+        half = model.trt_fp16_input

     # Dataloader
     if webcam:
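Why adjust rather than assert: a TensorRT engine's input binding dtype is fixed at build time, so an FP32 tensor fed to an FP16 engine (or vice versa) is misinterpreted when copied into the binding buffer, producing garbage or an error. A minimal standalone sketch of the reconciliation above (reconcile_half is a hypothetical helper; the diff inlines this logic):

    def reconcile_half(half: bool, trt_fp16_input: bool) -> bool:
        # Log when the user's --half flag disagrees with the engine, then follow the engine
        if trt_fp16_input != half:
            print(f"model {'requires' if trt_fp16_input else 'incompatible with'} --half. Adjusting automatically.")
        return trt_fp16_input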
@@ -233,9 +233,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False):
         for out in outputs:
             LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

-        half &= builder.platform_has_fast_fp16
-        LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
-        if half:
+        LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 else 32} engine in {f}')
+        if builder.platform_has_fast_fp16:
             config.set_flag(trt.BuilderFlag.FP16)
         with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
             t.write(engine.serialize())
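On the export side, precision is now decoupled from the --half flag: FP16 is enabled whenever the platform supports fast FP16. Isolated, the TensorRT builder pattern looks like this (network construction and serialization omitted; TensorRT >= 7 API):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    if builder.platform_has_fast_fp16:      # hardware capability, not the user flag, gates FP16
        config.set_flag(trt.BuilderFlag.FP16)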
@@ -338,6 +338,7 @@ class DetectMultiBackend(nn.Module):
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
             check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
             Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            trt_fp16_input = False
             logger = trt.Logger(trt.Logger.INFO)
             with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                 model = runtime.deserialize_cuda_engine(f.read())
@@ -348,6 +349,8 @@ class DetectMultiBackend(nn.Module):
                 shape = tuple(model.get_binding_shape(index))
                 data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
                 bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
+                if model.binding_is_input(index) and dtype == np.float16:
+                    trt_fp16_input = True
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             context = model.create_execution_context()
             batch_size = bindings['images'].shape[0]
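With these two hunks, DetectMultiBackend records whether any input binding of the deserialized engine is FP16 and exposes it as trt_fp16_input. The same inspection can be done standalone against a serialized engine (the .engine path is illustrative):

    import numpy as np
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    with open('yolov5s.engine', 'rb') as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    fp16_input = any(engine.binding_is_input(i) and trt.nptype(engine.get_binding_dtype(i)) == np.float16
                     for i in range(engine.num_bindings))
    print(f"engine expects FP{16 if fp16_input else 32} inputs")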
@@ -144,6 +144,10 @@ def run(data,
             model.model.half() if half else model.model.float()
         elif engine:
             batch_size = model.batch_size
+            if model.trt_fp16_input != half:
+                LOGGER.info('model ' + (
+                    'requires' if model.trt_fp16_input else 'incompatible with') + ' --half. Adjusting automatically.')
+                half = model.trt_fp16_input
         else:
             half = False
             batch_size = 1  # export.py models default to batch-size 1
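val.py applies the same reconciliation as detect.py, after reading the engine's fixed batch size. Design note: both scripts adjust half with a log line rather than asserting (the approach tried in the first commit above), since engines exported after this change default to FP16 and existing FP32 command lines would otherwise start failing.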