Skip to content

Commit

Permalink
[TF FE] Stabilize binary comparison ops (Equal, NonEqual, Less, etc.)…
Browse files Browse the repository at this point in the history
… tests on all platforms (openvinotoolkit#26205)

**Details:** Stabilize binary comparison ops (Equal, NonEqual, Less,
etc.) tests on all platforms

**Ticket:** 145795

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Aug 23, 2024
1 parent 7a8ae55 commit 79c966b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 11 deletions.
96 changes: 96 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_BinaryComparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import mix_array_with_several_values

rng = np.random.default_rng(23356)


class TestBinaryComparison(CommonTFLayerTest):
def _generate_value(self, input_shape, input_type):
if np.issubdtype(input_type, np.floating):
gen_value = rng.uniform(-5.0, 5.0, input_shape).astype(input_type)
elif np.issubdtype(input_type, np.signedinteger):
gen_value = rng.integers(-8, 8, input_shape).astype(input_type)
else:
gen_value = rng.integers(8, 16, input_shape).astype(input_type)
return gen_value

def _prepare_input(self, inputs_info):
assert 'x:0' in inputs_info, "Test error: inputs_info must contain `x`"
x_shape = inputs_info['x:0']
input_type = self.input_type

inputs_data = {}
y_value = None
if self.is_const:
y_value = self.y_value
y_shape = y_value.shape
else:
assert 'y:0' in inputs_info, "Test error: inputs_info must contain `y`"
y_shape = inputs_info['y:0']
y_value = self._generate_value(y_shape, input_type)
inputs_data['y:0'] = y_value

# generate x value so that some elements will be equal, less, greater than y value element-wise
squeeze_dims = 0 if len(y_shape) <= len(x_shape) else len(y_shape) - len(x_shape)
zeros_list = [0] * squeeze_dims
y_value = y_value[tuple(zeros_list)]
y_value_minus_one = y_value - 1
y_value_plus_one = y_value + 1

x_value = self._generate_value(x_shape, input_type)
# mix input data with preferable values
x_value = mix_array_with_several_values(x_value, [y_value, y_value_plus_one, y_value_minus_one], rng)
inputs_data['x:0'] = x_value

return inputs_data

def create_binary_comparison_net(self, input_shape1, input_shape2, binary_op, is_const, input_type):
compare_ops_map = {
'Equal': tf.raw_ops.Equal,
'NotEqual': tf.raw_ops.NotEqual,
'Greater': tf.raw_ops.Greater,
'GreaterEqual': tf.raw_ops.GreaterEqual,
'Less': tf.raw_ops.Less,
'LessEqual': tf.raw_ops.LessEqual,
}

self.input_type = input_type
self.is_const = is_const
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(input_type, input_shape1, 'x')
y = tf.compat.v1.placeholder(input_type, input_shape2, 'y')
if is_const:
self.y_value = self._generate_value(input_shape2, input_type)
y = tf.constant(self.y_value, dtype=input_type)
compare_ops_map[binary_op](x=x, y=y)
tf.compat.v1.global_variables_initializer()

tf_net = sess.graph_def

return tf_net, None

@pytest.mark.parametrize('input_shape1', [[], [4], [3, 4], [2, 3, 4]])
@pytest.mark.parametrize('input_shape2', [[4], [3, 4]])
@pytest.mark.parametrize('binary_op', ['Equal', 'NotEqual', 'Greater', 'GreaterEqual', 'Less', 'LessEqual'])
@pytest.mark.parametrize('is_const', [False, True])
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16,
np.int32, np.int64,
np.float16, np.float32, np.float64])
@pytest.mark.precommit
@pytest.mark.nightly
def test_binary_comparison(self, input_shape1, input_shape2, binary_op, is_const, input_type,
ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
if ie_device == 'GPU' and input_type == np.int16 and is_const:
pytest.skip('150501: Accuracy error on GPU for int16 type and constant operand')
self._test(*self.create_binary_comparison_net(input_shape1, input_shape2, binary_op, is_const, input_type),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)
13 changes: 2 additions & 11 deletions tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,7 @@ def create_add_placeholder_const_net(self, x_shape, y_shape, op_type):
'Pow': tf.raw_ops.Pow,
'Maximum': tf.raw_ops.Maximum,
'Minimum': tf.raw_ops.Minimum,
'Equal': tf.raw_ops.Equal,
'NotEqual': tf.raw_ops.NotEqual,
'Mod': tf.raw_ops.Mod,
'Greater': tf.raw_ops.Greater,
'GreaterEqual': tf.raw_ops.GreaterEqual,
'Less': tf.raw_ops.Less,
'LessEqual': tf.raw_ops.LessEqual,
'LogicalAnd': tf.raw_ops.LogicalAnd,
'LogicalOr': tf.raw_ops.LogicalOr,
'FloorMod': tf.raw_ops.FloorMod,
Expand Down Expand Up @@ -95,8 +89,8 @@ def create_add_placeholder_const_net(self, x_shape, y_shape, op_type):
@pytest.mark.parametrize('y_shape', [[4], [2, 3, 4]])
@pytest.mark.parametrize("op_type",
['Add', 'AddV2', 'Sub', 'Mul', 'Div', 'RealDiv', 'SquaredDifference', 'Pow',
'Maximum', 'Minimum', 'Equal', 'NotEqual', 'Mod', 'Greater', 'GreaterEqual', 'Less',
'LessEqual', 'LogicalAnd', 'LogicalOr', 'FloorMod', 'FloorDiv', 'Xdivy'])
'Maximum', 'Minimum', 'Mod', 'LogicalAnd', 'LogicalOr', 'FloorMod',
'FloorDiv', 'Xdivy'])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
Expand All @@ -109,8 +103,5 @@ def test_binary_op(self, x_shape, y_shape, ie_device, precision, ir_version, tem
pytest.skip("For Mod and Pow GPU has inference mismatch")
if op_type in ['Mod', 'FloorDiv', 'FloorMod']:
pytest.skip("Inference mismatch for Mod and FloorDiv")
if ie_device == 'GPU' and precision == 'FP16' and op_type in ['Equal', 'NotEqual', 'Greater', 'GreaterEqual',
'Less', 'LessEqual']:
pytest.skip("Accuracy mismatch on GPU")
self._test(*self.create_add_placeholder_const_net(x_shape=x_shape, y_shape=y_shape, op_type=op_type), ie_device,
precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend)

0 comments on commit 79c966b

Please sign in to comment.