From 33b2e6bb518e22f87575aca24d50412bee550520 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 7 Dec 2023 09:21:38 +0400 Subject: [PATCH] [TF FE] Allow any model format extension (#21508) Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/frontend.cpp | 9 ++++----- src/frontends/tensorflow/tests/convert_model.cpp | 2 ++ .../tests/test_models/gen_scripts/generate_2in_2out.py | 5 ++++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index b52931d726e692..020e8cd3ecf4db 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -129,14 +129,14 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { // avoid parsing of checkpoints here if (variants[0].is()) { std::string model_path = variants[0].as(); - if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) { + if (GraphIteratorProto::is_supported(model_path)) { // handle binary protobuf format // for automatic deduction of the frontend to convert the model // we have more strict rule that is to have `.pb` extension in the path return true; } else if (GraphIteratorSavedModel::is_supported(model_path)) { return true; - } else if (ov::util::ends_with(model_path, ".meta") && GraphIteratorMeta::is_supported(model_path)) { + } else if (GraphIteratorMeta::is_supported(model_path)) { return true; } else if (GraphIteratorProtoTxt::is_supported(model_path)) { // handle text protobuf format @@ -161,15 +161,14 @@ bool FrontEnd::supported_impl(const std::vector& variants) const { #if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) else if (variants[0].is()) { std::wstring model_path = variants[0].as(); - if (ov::util::ends_with(model_path, std::wstring(L".pb")) && GraphIteratorProto::is_supported(model_path)) { + if (GraphIteratorProto::is_supported(model_path)) { // handle binary protobuf format with a path in Unicode // for automatic deduction of the frontend to convert the model // we have more strict rule that is to have `.pb` extension in the path return true; } else if (GraphIteratorSavedModel::is_supported(model_path)) { return true; - } else if (ov::util::ends_with(model_path, std::wstring(L".meta")) && - GraphIteratorMeta::is_supported(model_path)) { + } else if (GraphIteratorMeta::is_supported(model_path)) { return true; } else if (GraphIteratorProtoTxt::is_supported(model_path)) { // handle text protobuf format diff --git a/src/frontends/tensorflow/tests/convert_model.cpp b/src/frontends/tensorflow/tests/convert_model.cpp index f6ec18cf9cc12c..5419b2c4f77c6d 100644 --- a/src/frontends/tensorflow/tests/convert_model.cpp +++ b/src/frontends/tensorflow/tests/convert_model.cpp @@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest; static const std::vector models{ std::string("2in_2out/2in_2out.pb"), + std::string("2in_2out/2in_2out.pb.frozen"), + std::string("2in_2out/2in_2out.pb.frozen_text"), std::string("forward_edge_model/forward_edge_model.pbtxt"), std::string("forward_edge_model2/forward_edge_model2.pbtxt"), std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"), diff --git a/src/frontends/tensorflow/tests/test_models/gen_scripts/generate_2in_2out.py b/src/frontends/tensorflow/tests/test_models/gen_scripts/generate_2in_2out.py index 33264d2c6c749b..42f022c001c262 100644 --- a/src/frontends/tensorflow/tests/test_models/gen_scripts/generate_2in_2out.py +++ b/src/frontends/tensorflow/tests/test_models/gen_scripts/generate_2in_2out.py @@ -1,9 +1,10 @@ # Copyright (C) 2018-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import numpy as np import os import sys + +import numpy as np import tensorflow as tf tf.compat.v1.reset_default_graph() @@ -33,3 +34,5 @@ tf_net = sess.graph_def tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb', False) +tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen', False) +tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen_text', True)