|
|
@@ -285,12 +285,12 @@ def export_saved_model(model, |
|
|
|
if keras: |
|
|
|
keras_model.save(f, save_format='tf') |
|
|
|
else: |
|
|
|
m = tf.function(lambda x: keras_model(x)) # full model |
|
|
|
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) |
|
|
|
m = tf.function(lambda x: keras_model(x)) # full model |
|
|
|
m = m.get_concrete_function(spec) |
|
|
|
frozen_func = convert_variables_to_constants_v2(m) |
|
|
|
tfm = tf.Module() |
|
|
|
tfm.__call__ = tf.function(lambda x: frozen_func(x)[0], [spec]) |
|
|
|
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec]) |
|
|
|
tfm.__call__(im) |
|
|
|
tf.saved_model.save(tfm, |
|
|
|
f, |