diff --git a/src/common/transformations/src/transformations/common_optimizations/reverse_shape_and_type_infer.cpp b/src/common/transformations/src/transformations/common_optimizations/reverse_shape_and_type_infer.cpp index 211f351da34024..9a06201f688675 100644 --- a/src/common/transformations/src/transformations/common_optimizations/reverse_shape_and_type_infer.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/reverse_shape_and_type_infer.cpp @@ -282,6 +282,15 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptrget_input_tensor(0).m_element_type = element::boolean; is_changed = true; } + + // in case TensorFlow models, we can deduce predicate shape that must be a scalar + // If operations created by fusing Switch-Merge sub-graph contain tf_switch_merge_if rt-info + if (if_op->get_rt_info().count("tf_switch_merge_if") && + if_op->get_rt_info()["tf_switch_merge_if"].as() && + if_op->input_value(0).get_partial_shape().rank().is_dynamic()) { + if_op->get_input_tensor(0).m_partial_shape = ov::PartialShape({}); + is_changed = true; + } } else if (ov::as_type_ptr(op)) { is_changed |= inherit_output_shape(op, {0}); is_changed |= inherit_output_type(op, {1}); diff --git a/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp b/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp index 34b2a82152ccfc..cbdc506671aa67 100644 --- a/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp +++ b/src/frontends/tensorflow/src/transformations/switch_merge_resolve.cpp @@ -235,6 +235,9 @@ bool pass::SwitchMergeResolver::run_on_model(const shared_ptr& m) { auto else_body = make_shared(else_results, else_params); auto if_op = make_shared(cond); + // in case TensorFlow models, we can deduce predicate shape that must be a scalar + if_op->get_rt_info()["tf_switch_merge_if"] = true; + set_cf_marker(if_cf_marker, if_op); if_op->set_then_body(then_body); if_op->set_else_body(else_body); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py b/tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py index 96b73dd2134575..3747ab7a726aec 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py @@ -63,3 +63,48 @@ def test_merge_eliminating_several_cond_flows(self, params, cond_value, x_type, self._test(*self.merge_eliminating_several_cond_flows_net(**params, cond_value=cond_value, x_type=x_type), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestSwitchMergeWithVariablePredicate(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x:0' in inputs_info + x_shape = inputs_info['x:0'] + inputs_data = {} + rng = np.random.default_rng() + inputs_data['x:0'] = rng.integers(-10, 10, x_shape).astype(np.float32) + inputs_data['cond:0'] = np.array(self.cond_value, dtype=bool) + return inputs_data + + def switch_merge_with_variable_predicate_net(self, x_shape, cond_shape, cond_value): + self.cond_value = cond_value + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x') + cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond') + const_add = tf.constant(3, dtype=tf.float32) + const_sub = tf.constant(1, dtype=tf.float32) + switch_false, switch_true = tf.raw_ops.Switch(data=x, pred=cond) + add = tf.raw_ops.AddV2(x=switch_false, y=const_add) + sub = tf.raw_ops.Sub(x=switch_true, y=const_sub) + merge = tf.raw_ops.Merge(inputs=[add, sub]) + const_main = tf.constant(1, dtype=tf.float32) + tf.raw_ops.AddV2(x=merge[0], y=const_main, name='add_res') + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + @pytest.mark.parametrize('x_shape', [[], [2], [3, 2]]) + @pytest.mark.parametrize('cond_shape', [None, []]) + @pytest.mark.parametrize('cond_value', [True, False]) + @pytest.mark.precommit + @pytest.mark.nightly + def test_switch_merge_with_variable_predicate(self, x_shape, cond_shape, cond_value, + ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + if ie_device == 'GPU': + pytest.skip("156244: accuracy error on GPU") + self._test(*self.switch_merge_with_variable_predicate_net(x_shape, cond_shape, cond_value), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend)