|
|
@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F |
|
|
|
|
|
|
|
def export_saved_model(model, im, file, dynamic, |
|
|
|
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, |
|
|
|
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): |
|
|
|
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')): |
|
|
|
# YOLOv5 TensorFlow SavedModel export |
|
|
|
try: |
|
|
|
import tensorflow as tf |
|
|
|
from tensorflow import keras |
|
|
|
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 |
|
|
|
|
|
|
|
from models.tf import TFDetect, TFModel |
|
|
|
|
|
|
@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic, |
|
|
|
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) |
|
|
|
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow |
|
|
|
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) |
|
|
|
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) |
|
|
|
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) |
|
|
|
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) |
|
|
|
keras_model = keras.Model(inputs=inputs, outputs=outputs) |
|
|
|
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) |
|
|
|
keras_model.trainable = False |
|
|
|
keras_model.summary() |
|
|
|
keras_model.save(f, save_format='tf') |
|
|
|
|
|
|
|
if keras: |
|
|
|
keras_model.save(f, save_format='tf') |
|
|
|
else: |
|
|
|
m = tf.function(lambda x: keras_model(x)) # full model |
|
|
|
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) |
|
|
|
m = m.get_concrete_function(spec) |
|
|
|
frozen_func = convert_variables_to_constants_v2(m) |
|
|
|
tfm = tf.Module() |
|
|
|
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec]) |
|
|
|
tfm.__call__(im) |
|
|
|
tf.saved_model.save( |
|
|
|
tfm, |
|
|
|
f, |
|
|
|
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if |
|
|
|
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions()) |
|
|
|
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') |
|
|
|
return keras_model, f |
|
|
|
except Exception as e: |