Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Allow any model format extension #21508

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/frontends/tensorflow/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
// avoid parsing of checkpoints here
if (variants[0].is<std::string>()) {
std::string model_path = variants[0].as<std::string>();
if (ov::util::ends_with(model_path, ".pb") && GraphIteratorProto::is_supported(model_path)) {
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
} else if (ov::util::ends_with(model_path, ".meta") && GraphIteratorMeta::is_supported(model_path)) {
} else if (GraphIteratorMeta::is_supported(model_path)) {
return true;
} else if (GraphIteratorProtoTxt::is_supported(model_path)) {
// handle text protobuf format
Expand All @@ -149,15 +149,14 @@ bool FrontEnd::supported_impl(const std::vector<ov::Any>& variants) const {
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
else if (variants[0].is<std::wstring>()) {
std::wstring model_path = variants[0].as<std::wstring>();
if (ov::util::ends_with(model_path, std::wstring(L".pb")) && GraphIteratorProto::is_supported(model_path)) {
if (GraphIteratorProto::is_supported(model_path)) {
// handle binary protobuf format with a path in Unicode
// for automatic deduction of the frontend to convert the model
// we have more strict rule that is to have `.pb` extension in the path
return true;
} else if (GraphIteratorSavedModel::is_supported(model_path)) {
return true;
} else if (ov::util::ends_with(model_path, std::wstring(L".meta")) &&
GraphIteratorMeta::is_supported(model_path)) {
} else if (GraphIteratorMeta::is_supported(model_path)) {
return true;
} else if (GraphIteratorProtoTxt::is_supported(model_path)) {
// handle text protobuf format
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/tensorflow/tests/convert_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using TFConvertModelTest = FrontEndConvertModelTest;

static const std::vector<std::string> models{
std::string("2in_2out/2in_2out.pb"),
std::string("2in_2out/2in_2out.pb.frozen"),
std::string("2in_2out/2in_2out.pb.frozen_text"),
std::string("forward_edge_model/forward_edge_model.pbtxt"),
std::string("forward_edge_model2/forward_edge_model2.pbtxt"),
std::string("concat_with_non_constant_axis/concat_with_non_constant_axis.pbtxt"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import os
import sys

import numpy as np
import tensorflow as tf

tf.compat.v1.reset_default_graph()
Expand Down Expand Up @@ -33,3 +34,5 @@
tf_net = sess.graph_def

tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb', False)
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen', False)
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], "2in_2out"), '2in_2out.pb.frozen_text', True)
Loading