From ef61090cff7ac7b9f3d75ca1db2128d748b912b1 Mon Sep 17 00:00:00 2001 From: shubdas9902 Date: Mon, 16 Dec 2024 21:06:52 +0530 Subject: [PATCH] Add layer test for Case operation --- .../tensorflow_tests/test_tf_Case_op.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py index d69f8185786375..12a7815247a26a 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py @@ -6,6 +6,9 @@ class TestCaseOp(CommonTFLayerTest): def _prepare_input(self, inputs_info): + """ + Prepares input data based on the given input shapes and data types. + """ assert 'cond' in inputs_info assert 'input_data' in inputs_info inputs_data = { @@ -14,7 +17,17 @@ def _prepare_input(self, inputs_info): } return inputs_data - def create_case_net(self, input_shape, branches, default_branch): + def create_case_net(self, input_shape, cond_value): + """ + Creates a TensorFlow model with a Case operation. + + Args: + input_shape: Shape of the input tensor. + cond_value: The condition value to select the branch. + + Returns: + TensorFlow graph definition and None. + """ tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: # Inputs @@ -32,22 +45,27 @@ def branch_fn_2(): # Create Case operation case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32) - + tf.identity(case_op, name="output") + tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def return tf_net, None + # Test parameters test_data_basic = [ - dict(input_shape=[1, 2], branches=2, default_branch=None, cond=True), - dict(input_shape=[3, 3], branches=2, default_branch=None, cond=False), + dict(input_shape=[1, 2], cond=True), + dict(input_shape=[3, 3], cond=False), ] @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit_tf_fe @pytest.mark.nightly - def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, + def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): + """ + Executes the test for the Case operation. + """ self._test(*self.create_case_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api)