@@ -327,7 +327,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
         LOGGER.info(f'\n{prefix} export failure: {e}')
 
 
-def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
+def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
     # YOLOv5 TensorFlow Lite export
     try:
         import tensorflow as tf
@@ -343,13 +343,15 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
         if int8:
             from models.tf import representative_dataset_gen
             dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False)  # representative data
-            converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
+            converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
             converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
             converter.target_spec.supported_types = []
             converter.inference_input_type = tf.uint8  # or tf.int8
             converter.inference_output_type = tf.uint8  # or tf.int8
             converter.experimental_new_quantizer = True
             f = str(file).replace('.pt', '-int8.tflite')
+        if nms or agnostic_nms:
+            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
 
         tflite_model = converter.convert()
         open(f, "wb").write(tflite_model)
@@ -524,7 +526,7 @@ def run(
         if pb or tfjs:  # pb prerequisite to tfjs
             f[6] = export_pb(model, im, file)
         if tflite or edgetpu:
-            f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
+            f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
         if edgetpu:
             f[8] = export_edgetpu(model, im, file)
         if tfjs:
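As a quick follow-up check, the exported artifact can be opened with the TFLite interpreter to confirm its input/output types. A small sketch, assuming an output file named 'yolov5s-int8.tflite' (the name follows the '.pt' -> '-int8.tflite' replacement in export_tflite; substitute your actual path):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='yolov5s-int8.tflite')  # assumed example path
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
print(inp['dtype'], inp['shape'])   # uint8 input is expected for the int8 export
print(out['dtype'], out['shape'])

Models exported with SELECT_TF_OPS rely on Flex op support, which ships with the full tensorflow package's interpreter but not with the slim tflite_runtime wheel.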