Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Deduce Switch-Merge predicate shape #27277

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
if_op->get_input_tensor(0).m_element_type = element::boolean;
is_changed = true;
}

// in case TensorFlow models, we can deduce predicate shape that must be a scalar
// If operations created by fusing Switch-Merge sub-graph contain tf_switch_merge_if rt-info
if (if_op->get_rt_info().count("tf_switch_merge_if") &&
rkazants marked this conversation as resolved.
Show resolved Hide resolved
if_op->get_rt_info()["tf_switch_merge_if"].as<bool>() &&
if_op->input_value(0).get_partial_shape().rank().is_dynamic()) {
if_op->get_input_tensor(0).m_partial_shape = ov::PartialShape({});
is_changed = true;
}
} else if (ov::as_type_ptr<ov::op::v1::ConvertLike>(op)) {
is_changed |= inherit_output_shape(op, {0});
is_changed |= inherit_output_type(op, {1});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ bool pass::SwitchMergeResolver::run_on_model(const shared_ptr<Model>& m) {
auto else_body = make_shared<Model>(else_results, else_params);

auto if_op = make_shared<v8::If>(cond);
// in case TensorFlow models, we can deduce predicate shape that must be a scalar
if_op->get_rt_info()["tf_switch_merge_if"] = true;

set_cf_marker(if_cf_marker, if_op);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_SwitchMerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,48 @@ def test_merge_eliminating_several_cond_flows(self, params, cond_value, x_type,
self._test(*self.merge_eliminating_several_cond_flows_net(**params, cond_value=cond_value, x_type=x_type),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)


class TestSwitchMergeWithVariablePredicate(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'x:0' in inputs_info
x_shape = inputs_info['x:0']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['x:0'] = rng.integers(-10, 10, x_shape).astype(np.float32)
inputs_data['cond:0'] = np.array(self.cond_value, dtype=bool)
return inputs_data

def switch_merge_with_variable_predicate_net(self, x_shape, cond_shape, cond_value):
self.cond_value = cond_value
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond')
const_add = tf.constant(3, dtype=tf.float32)
const_sub = tf.constant(1, dtype=tf.float32)
switch_false, switch_true = tf.raw_ops.Switch(data=x, pred=cond)
add = tf.raw_ops.AddV2(x=switch_false, y=const_add)
sub = tf.raw_ops.Sub(x=switch_true, y=const_sub)
merge = tf.raw_ops.Merge(inputs=[add, sub])
const_main = tf.constant(1, dtype=tf.float32)
tf.raw_ops.AddV2(x=merge[0], y=const_main, name='add_res')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

@pytest.mark.parametrize('x_shape', [[], [2], [3, 2]])
@pytest.mark.parametrize('cond_shape', [None, []])
@pytest.mark.parametrize('cond_value', [True, False])
@pytest.mark.precommit
@pytest.mark.nightly
def test_switch_merge_with_variable_predicate(self, x_shape, cond_shape, cond_value,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU':
pytest.skip("156244: accuracy error on GPU")
self._test(*self.switch_merge_with_variable_predicate_net(x_shape, cond_shape, cond_value),
rkazants marked this conversation as resolved.
Show resolved Hide resolved
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
Loading