[TF FE] Fix conversion of TF1 OD models out-of-the-box (openvinotoolkit#20916)

* [TF FE] Fix conversion of TF1 OD models out-of-the-box

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add test While with nested If operation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update tests/layer_tests/tensorflow_tests/test_tf_While.py

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
rkazants authored Nov 7, 2023
1 parent ac1fb7b commit c6ca786
Showing 3 changed files with 89 additions and 0 deletions.
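
For context on the commit title: the fix is meant to let a TF1 Object Detection frozen graph convert without extra configuration files. Below is a minimal sketch of such an out-of-the-box conversion with the OpenVINO Python API; it is an illustration only, not part of this commit, and frozen_inference_graph.pb is merely the usual TF1 OD API export name used here as a placeholder path.

import openvino as ov

# placeholder path: the frozen graph exported by the TF1 Object Detection API
ov_model = ov.convert_model("frozen_inference_graph.pb")
ov.save_model(ov_model, "model.xml")
print(len(ov_model.inputs), len(ov_model.outputs))
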
@@ -15,6 +15,7 @@
#include "openvino/op/if.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "tf_utils.hpp"

using namespace ov;
@@ -151,6 +152,16 @@ void insert_result_before_merge(const shared_ptr<Merge>& merge_node,
} // namespace

bool pass::SwitchMergeResolver::run_on_model(const shared_ptr<Model>& m) {
    // run this transformation recursively since this is a model pass
    for (const auto& op : m->get_ordered_ops()) {
        auto multisubgraph_op = as_type_ptr<ov::op::util::MultiSubGraphOp>(op);
        if (multisubgraph_op) {
            for (size_t i = 0; i < multisubgraph_op->get_internal_subgraphs_size(); ++i) {
                run_on_model(multisubgraph_op->get_function(static_cast<int>(i)));
            }
        }
    }

    // split the set of Switch and Merge nodes into clusters,
    // where each cluster of Switch and Merge nodes will be fused
    // into a single If operation
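
The recursion added above lets SwitchMergeResolver fuse Switch/Merge pairs that end up inside the body of an already-created Loop or If operation, not only in the main graph. Below is a minimal TF1-style sketch of the kind of graph that produces such nesting; it is an illustration only (shapes and constants are arbitrary), not taken from the commit, and it just prints the node types to show that Switch, Merge and LoopCond primitives appear.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()  # force TF1-style Switch/Merge/LoopCond lowering

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.int32, [], name='x')
    y = tf.placeholder(tf.int32, [2, 3], name='y')

    def cond(i, acc):
        return tf.less(i, 10)

    def body(i, acc):
        # the nested conditional becomes Switch/Merge pairs inside the loop body
        acc_new = tf.cond(tf.less(i, 5), lambda: acc * 2, lambda: acc - 55)
        return i + 1, acc_new

    out_i, out_y = tf.while_loop(cond, body, [x, y])
    print(sorted({node.op for node in graph.as_graph_def().node}))
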
Expand Up @@ -16,6 +16,8 @@ namespace tensorflow {
namespace pass {

bool ConstToResultRemover::run_on_model(const std::shared_ptr<ov::Model>& m) {
    // Note: this transformation needs to be performed only on the main ov::Model graph;
    // there is no need to apply it to sub-graphs
    ResultVector results_to_remove;
    // look for isolated UnsupportedConst->Result sub-graphs to remove
    // also, find isolated Constant->Result sub-graphs to remove
tests/layer_tests/tensorflow_tests/test_tf_While.py: 76 additions, 0 deletions
@@ -50,6 +50,7 @@ def body(x, y):

    test_data_basic = [
        dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False),
        dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True),
        dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False),
        dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True)
    ]
@@ -109,6 +110,7 @@ def body(x, y):

    test_data_basic = [
        dict(y_shape=[2, 3], lower_control_flow=False),
        dict(y_shape=[2, 3], lower_control_flow=True),
        dict(y_shape=[2, 1, 4], lower_control_flow=False),
        dict(y_shape=[2, 1, 4], lower_control_flow=True)
    ]
@@ -122,3 +124,77 @@ def test_while_basic(self, params, ie_device, precision, ir_version, temp_dir,
        self._test(*self.create_while_net(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)


class TestWhileWithNestedIf(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'x' in inputs_info, "Test error: inputs_info must contain `x`"
        assert 'y' in inputs_info, "Test error: inputs_info must contain `y`"
        x_shape = inputs_info['x']
        y_shape = inputs_info['y']
        inputs_data = {}
        inputs_data['x'] = np.random.randint(1, 10, x_shape).astype(np.int32)
        inputs_data['y'] = np.random.randint(-50, 50, y_shape).astype(np.int32)
        return inputs_data

    def create_while_with_nested_if_net(self, y_shape, data_type, lower_control_flow):
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        def while_function(x, y):
            @tf.function
            def cond(x, y):
                return tf.less(x, 10)

            @tf.function
            def body(x, y):
                # create an If operation inside the While body
                # that uses different logic for updating y depending on x
                def if_op(cond, y):
                    def then_branch():
                        y_new = tf.multiply(y, tf.constant(2, dtype=data_type))
                        return y_new

                    def else_branch():
                        y_new = tf.subtract(y, tf.constant(55, dtype=data_type))
                        return y_new

                    if_op = tf.cond(cond, then_branch, else_branch)
                    output = tf.identity(if_op, name='if_op')
                    return output

                y_new = tf.add(y, tf.constant(2, dtype=data_type))
                cond = tf.less(x, 5)
                y_new = if_op(cond, y_new)
                x_new = tf.add(x, 1)
                return x_new, y_new

            return tf.while_loop(cond, body, [x, y])

        tf_while_graph = tf.function(while_function)
        x = np.random.randint(9, 10, []).astype(data_type)
        y = np.random.randint(-50, 50, y_shape).astype(data_type)
        concrete_func = tf_while_graph.get_concrete_function(x, y)

        # lower_control_flow defines the representation of the While operation:
        # with lower_control_flow=True it is decomposed into graph-mode primitives
        # such as LoopCond, NextIteration, Switch and Merge
        frozen_func = convert_variables_to_constants_v2(concrete_func,
                                                        lower_control_flow=lower_control_flow)

        graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
        return graph_def, None

    test_data_basic = [
        dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False),
        dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True),
        dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False),
        dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True)
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
    def test_while_with_nested_if_basic(self, params, ie_device, precision, ir_version, temp_dir,
                                        use_new_frontend, use_old_api):
        self._test(*self.create_while_with_nested_if_net(**params),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
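
As a side note on the lower_control_flow flag exercised by these tests: the standalone sketch below (an illustration only, not part of the test suite; the tiny two-variable loop is arbitrary) freezes the same kind of loop with both values of the flag and prints the top-level node types, making the difference between the functional While form and the lowered LoopCond/Switch/Merge form visible.

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


@tf.function
def simple_loop(x, y):
    # two loop variables, mirroring the tests above
    return tf.while_loop(lambda i, acc: tf.less(i, 10),
                         lambda i, acc: (tf.add(i, 1), tf.add(acc, acc)),
                         [x, y])


concrete = simple_loop.get_concrete_function(tf.TensorSpec([], tf.int32),
                                             tf.TensorSpec([2, 3], tf.int32))
for lower in (False, True):
    frozen = convert_variables_to_constants_v2(concrete, lower_control_flow=lower)
    print('lower_control_flow =', lower, '->',
          sorted({node.op for node in frozen.graph.as_graph_def().node}))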
