Skip to content

Commit

Permalink
[TF FE] Allow any model format extension (openvinotoolkit#21508)
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored and akuporos committed Dec 8, 2023
1 parent 1413799 commit 8a27382
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 4 additions & 5 deletions src/frontends/tensorflow/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
// avoid parsing of checkpoints here
if (variants[0].is<std::string>()) {
std::string model_path = variants[0].as<std::string>();
if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) {
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
} else if (ov::util::ends_with(model_path, ".meta") && GraphIteratorMeta::is_supported(model_path)) {
} else if (GraphIteratorMeta::is_supported(model_path)) {
return true;
} else if (GraphIteratorProtoTxt::is_supported(model_path)) {
// handle text protobuf format
Expand All @@ -161,15 +161,14 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (variants[0].is<std::wstring>()) {
std::wstring model_path = variants[0].as<std::wstring>();
if (ov::util::ends_with(model_path, std::wstring(L".pb")) && GraphIteratorProto::is_supported(model_path)) {
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format with a path in Unicode
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
} else if (ov::util::ends_with(model_path, std::wstring(L".meta")) &&
GraphIteratorMeta::is_supported(model_path)) {
} else if (GraphIteratorMeta::is_supported(model_path)) {
return true;
} else if (GraphIteratorProtoTxt::is_supported(model_path)) {
// handle text protobuf format
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/tensorflow/tests/convert_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest;

static const std::vector<std::string> models{
std::string("2in_2out/2in_2out.pb"),
std::string("2in_2out/2in_2out.pb.frozen"),
std::string("2in_2out/2in_2out.pb.frozen_text"),
std::string("forward_edge_model/forward_edge_model.pbtxt"),
std::string("forward_edge_model2/forward_edge_model2.pbtxt"),
std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import os
import sys

import numpy as np
import tensorflow as tf

tf.compat.v1.reset_default_graph()
Expand Down Expand Up @@ -33,3 +34,5 @@
tf_net = sess.graph_def

tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb', False)
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen', False)
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen_text', True)

0 comments on commit 8a27382

Please sign in to comment.