Quellcode durchsuchen

Fix TF exports >= 2GB (#6292)

* Fix exporting saved_model: pb exceeds 2GB

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace TF v1.x API with TF v2.x API for saved_model export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clean up

* Remove lambda in tf.function()

* Revert "Remove lambda in tf.function()" to be compatible with TF v2.4

This reverts commit 46c7931f11.

* Fix for pre-commit.ci

* Cleanup1

* Cleanup2

* Backwards compatibility update

* Update common.py

* Update common.py

* Cleanup3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Jiacong Fang GitHub vor 2 Jahren
Ursprung
Commit
66e5d794c7
Es konnte kein GPG-Schlüssel zu dieser Signatur gefunden werden GPG-Schlüssel-ID: 4AEE18F83AFDEB23
2 geänderte Dateien mit 22 neuen und 8 gelöschten Zeilen
  1. +19
    -6
      export.py
  2. +3
    -2
      models/common.py

+ 19
- 6
export.py Datei anzeigen

@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F

def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.tf import TFDetect, TFModel

@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic,
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = keras.Model(inputs=inputs, outputs=outputs)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False
keras_model.summary()
keras_model.save(f, save_format='tf')

if keras:
keras_model.save(f, save_format='tf')
else:
m = tf.function(lambda x: keras_model(x)) # full model
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f
except Exception as e:

+ 3
- 2
models/common.py Datei anzeigen

@@ -359,7 +359,8 @@ class DetectMultiBackend(nn.Module):
if saved_model: # SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w)
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
@@ -431,7 +432,7 @@ class DetectMultiBackend(nn.Module):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = self.model(im, training=False).numpy()
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.tflite: # Lite

Laden…
Abbrechen
Speichern