|
|
@@ -33,7 +33,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile |
|
|
|
|
|
|
|
FILE = Path(__file__).resolve() |
|
|
|
ROOT = FILE.parents[0] # yolov5/ dir |
|
|
|
sys.path.append(ROOT.as_posix()) # add yolov5/ to path |
|
|
|
if str(ROOT) not in sys.path: |
|
|
|
sys.path.append(str(ROOT)) # add ROOT to PATH |
|
|
|
|
|
|
|
from models.common import Conv |
|
|
|
from models.experimental import attempt_load |
|
|
@@ -174,7 +175,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')): |
|
|
|
print(f'\n{prefix} export failure: {e}') |
|
|
|
|
|
|
|
|
|
|
|
def export_tflite(keras_model, im, file, tfl_int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')): |
|
|
|
def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')): |
|
|
|
# YOLOv5 TensorFlow Lite export |
|
|
|
try: |
|
|
|
import tensorflow as tf |
|
|
@@ -187,7 +188,7 @@ def export_tflite(keras_model, im, file, tfl_int8, data, ncalib, prefix=colorstr |
|
|
|
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) |
|
|
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] |
|
|
|
converter.optimizations = [tf.lite.Optimize.DEFAULT] |
|
|
|
if tfl_int8: |
|
|
|
if int8: |
|
|
|
dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data |
|
|
|
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib) |
|
|
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] |
|
|
@@ -234,7 +235,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' |
|
|
|
inplace=False, # set YOLOv5 Detect() inplace=True |
|
|
|
train=False, # model.train() mode |
|
|
|
optimize=False, # TorchScript: optimize for mobile |
|
|
|
dynamic=False, # ONNX: dynamic axes |
|
|
|
int8=False, # CoreML/TF INT8 quantization |
|
|
|
dynamic=False, # ONNX/TF: dynamic axes |
|
|
|
simplify=False, # ONNX: simplify model |
|
|
|
opset=12, # ONNX: opset version |
|
|
|
): |
|
|
@@ -288,7 +290,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' |
|
|
|
if pb or tfjs: # pb prerequisite to tfjs |
|
|
|
export_pb(model, im, file) |
|
|
|
if tflite: |
|
|
|
export_tflite(model, im, file, tfl_int8=False, data=data, ncalib=100) |
|
|
|
export_tflite(model, im, file, int8=int8, data=data, ncalib=100) |
|
|
|
if tfjs: |
|
|
|
export_tfjs(model, im, file) |
|
|
|
|
|
|
@@ -309,6 +311,7 @@ def parse_opt(): |
|
|
|
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') |
|
|
|
parser.add_argument('--train', action='store_true', help='model.train() mode') |
|
|
|
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') |
|
|
|
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') |
|
|
|
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') |
|
|
|
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') |
|
|
|
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') |