From 53349dac8e9fb447bb43319811699bb72d1c2470 Mon Sep 17 00:00:00 2001 From: Phil2020 <35833843+phodgers@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:54:00 +0000 Subject: [PATCH] Scope TF imports in `DetectMultiBackend()` (#5792) * tensorflow or tflite exclusively as interpreter As per bug report https://github.com/ultralytics/yolov5/issues/5709 I think there should be only one attempt to assign interpreter, and it appears tflite is only ever needed for the case of edgetpu model. * Scope imports * Nested definition line fix * Update common.py Co-authored-by: Glenn Jocher --- models/common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index 8836a65..284f03e 100644 --- a/models/common.py +++ b/models/common.py @@ -337,19 +337,21 @@ class DetectMultiBackend(nn.Module): context = model.create_execution_context() batch_size = bindings['images'].shape[0] else: # TensorFlow model (TFLite, pb, saved_model) - import tensorflow as tf if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') + import tensorflow as tf + def wrap_frozen_graph(gd, inputs, outputs): x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs), tf.nest.map_structure(x.graph.as_graph_element, outputs)) - LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') graph_def = tf.Graph().as_graph_def() graph_def.ParseFromString(open(w, 'rb').read()) frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0") elif saved_model: LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...') + import tensorflow as tf model = tf.keras.models.load_model(w) elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python if 'edgetpu' in w.lower(): @@ -361,6 +363,7 @@ class DetectMultiBackend(nn.Module): interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)]) else: LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') + import tensorflow as tf interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model interpreter.allocate_tensors() # allocate input_details = interpreter.get_input_details() # inputs