Skip to content

Commit

Permalink
Add layer test for Case operation
Browse files Browse the repository at this point in the history
  • Loading branch information
shubdas9902 committed Dec 16, 2024
1 parent dcd5a16 commit ef61090
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions tests/layer_tests/tensorflow_tests/test_tf_Case_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

class TestCaseOp(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
"""
Prepares input data based on the given input shapes and data types.
"""
assert 'cond' in inputs_info
assert 'input_data' in inputs_info
inputs_data = {
Expand All @@ -14,7 +17,17 @@ def _prepare_input(self, inputs_info):
}
return inputs_data

def create_case_net(self, input_shape, branches, default_branch):
def create_case_net(self, input_shape, cond_value):
"""
Creates a TensorFlow model with a Case operation.
Args:
input_shape: Shape of the input tensor.
cond_value: The condition value to select the branch.
Returns:
TensorFlow graph definition and None.
"""
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
# Inputs
Expand All @@ -32,22 +45,27 @@ def branch_fn_2():

# Create Case operation
case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32)

tf.identity(case_op, name="output")

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

# Test parameters
test_data_basic = [
dict(input_shape=[1, 2], branches=2, default_branch=None, cond=True),
dict(input_shape=[3, 3], branches=2, default_branch=None, cond=False),
dict(input_shape=[1, 2], cond=True),
dict(input_shape=[3, 3], cond=False),
]

@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_case_op(self, params, ie_device, precision, ir_version, temp_dir,
def test_case_op(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
"""
Executes the test for the Case operation.
"""
self._test(*self.create_case_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

0 comments on commit ef61090

Please sign in to comment.