Browse Source

Automatic TFLite uint8 determination (#4515)

* Auto TFLite uint8 detection

This PR automatically determines if TFLite models are uint8 quantized rather than accepting a manual argument.

The quantization determination is based on @zldrobit comment https://github.com/ultralytics/yolov5/pull/1127#issuecomment-901713847

* Cleanup
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
79af1144c2
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 6 deletions
  1. +5
    -6
      detect.py

+ 5
- 6
detect.py View File

@@ -52,7 +52,6 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
@@ -104,6 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
imgsz = check_img_size(imgsz, s=stride) # check image size

# Dataloader
@@ -145,15 +145,15 @@ def run(weights='yolov5s.pt', # model.pt path(s)
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
if int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
if int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
@@ -268,7 +268,6 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt

Loading…
Cancel
Save