diff --git a/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp b/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp index af03b24519c29d..35c2cd1b7f23e1 100644 --- a/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp +++ b/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp @@ -15,6 +15,7 @@ #include "openvino/op/if.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/result.hpp" +#include "openvino/op/util/multi_subgraph_base.hpp" #include "tf_utils.hpp" using namespace ov; @@ -151,6 +152,16 @@ void insert_result_before_merge(const shared_ptr& merge_node, } // namespace bool pass::SwitchMergeResolver::run_on_model(const shared_ptr& m) { + // run this transformation recursively since this is a model pass + for (const auto& op : m->get_ordered_ops()) { + auto multisubgraph_op = as_type_ptr(op); + if (multisubgraph_op) { + for (size_t i = 0; i < multisubgraph_op->get_internal_subgraphs_size(); ++i) { + run_on_model(multisubgraph_op->get_function(static_cast(i))); + } + } + } + // split set of Switch and Merge nodes to clusters // where each cluster of Switch and Merge nodes will represent // the single If operation for fusing diff --git a/src/frontends/tensorflow_common/src/helper_transforms/const_to_result_remover.cpp b/src/frontends/tensorflow_common/src/helper_transforms/const_to_result_remover.cpp index 1963bcf47dae22..d16152ca492246 100644 --- a/src/frontends/tensorflow_common/src/helper_transforms/const_to_result_remover.cpp +++ b/src/frontends/tensorflow_common/src/helper_transforms/const_to_result_remover.cpp @@ -16,6 +16,8 @@ namespace tensorflow { namespace pass { bool ConstToResultRemover::run_on_model(const std::shared_ptr& m) { + // Note: need to perform this transformation only on the main ov::Model graph + // no need to apply it for sub-graphs! ResultVector results_to_remove; // look for isolated UnsupportedConst->Result sub-graphs to remove // also, find isolated Constant->Result sub-graphs to remove diff --git a/tests/layer_tests/tensorflow_tests/test_tf_While.py b/tests/layer_tests/tensorflow_tests/test_tf_While.py index 2a112700f30ad5..d4aaedf86854e6 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_While.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_While.py @@ -50,6 +50,7 @@ def body(x, y): test_data_basic = [ dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False), + dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True), dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False), dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True) ] @@ -109,6 +110,7 @@ def body(x, y): test_data_basic = [ dict(y_shape=[2, 3], lower_control_flow=False), + dict(y_shape=[2, 3], lower_control_flow=True), dict(y_shape=[2, 1, 4], lower_control_flow=False), dict(y_shape=[2, 1, 4], lower_control_flow=True) ] @@ -122,3 +124,77 @@ def test_while_basic(self, params, ie_device, precision, ir_version, temp_dir, self._test(*self.create_while_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestWhileWithNestedIf(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x' in inputs_info, "Test error: inputs_info must contain `x`" + assert 'y' in inputs_info, "Test error: inputs_info must contain `y`" + x_shape = inputs_info['x'] + y_shape = inputs_info['y'] + inputs_data = {} + inputs_data['x'] = np.random.randint(1, 10, x_shape).astype(np.int32) + inputs_data['y'] = np.random.randint(-50, 50, y_shape).astype(np.int32) + return inputs_data + + def create_while_with_nested_if_net(self, y_shape, data_type, lower_control_flow): + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 + def while_function(x, y): + @tf.function + def cond(x, y): + return tf.less(x, 10) + + @tf.function + def body(x, y): + # create If operation inside While body + # use different logic for updating y based on x + def if_op(cond, y): + def then_branch(): + y_new = tf.multiply(y, tf.constant(2, dtype=data_type)) + return y_new + + def else_branch(): + y_new = tf.subtract(y, tf.constant(55, dtype=data_type)) + return y_new + + if_op = tf.cond(cond, then_branch, else_branch) + output = tf.identity(if_op, name='if_op') + return output + + y_new = tf.add(y, tf.constant(2, dtype=data_type)) + cond = tf.less(x, 5) + y_new = if_op(cond, y_new) + x_new = tf.add(x, 1) + return x_new, y_new + + return tf.while_loop(cond, body, [x, y]) + + tf_while_graph = tf.function(while_function) + x = np.random.randint(9, 10, []).astype(data_type) + y = np.random.randint(-50, 50, y_shape).astype(data_type) + concrete_func = tf_while_graph.get_concrete_function(x, y) + + # lower_control_flow defines representation of While operation + # in case of lower_control_flow=True it is decomposed into LoopCond, NextIteration and TensorArray operations + frozen_func = convert_variables_to_constants_v2(concrete_func, + lower_control_flow=lower_control_flow) + + graph_def = frozen_func.graph.as_graph_def(add_shapes=True) + return graph_def, None + + test_data_basic = [ + dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=False), + dict(y_shape=[2, 3], data_type=np.int32, lower_control_flow=True), + dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=False), + dict(y_shape=[2, 1, 4], data_type=np.int32, lower_control_flow=True) + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182") + def test_while_with_nested_if_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_while_with_nested_if_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)