diff --git a/src/frontends/tensorflow_common/src/op/binary_op.cpp b/src/frontends/tensorflow_common/src/op/binary_op.cpp index 07d047251ab9b2..e992c1bfc0b760 100644 --- a/src/frontends/tensorflow_common/src/op/binary_op.cpp +++ b/src/frontends/tensorflow_common/src/op/binary_op.cpp @@ -8,6 +8,7 @@ #include "openvino/op/bitwise_and.hpp" #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_xor.hpp" +#include "openvino/op/ceiling.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" @@ -26,9 +27,11 @@ #include "openvino/op/minimum.hpp" #include "openvino/op/mod.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/negative.hpp" #include "openvino/op/not_equal.hpp" #include "openvino/op/power.hpp" #include "openvino/op/prelu.hpp" +#include "openvino/op/select.hpp" #include "openvino/op/squared_difference.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/unsqueeze.hpp" @@ -54,11 +57,27 @@ OutputVector translate_binary_op(const NodeContext& node, OutputVector translate_floor_div_op(const NodeContext& node) { auto floordiv_fn = [](const Output& x, const Output& y) -> shared_ptr { auto out_type = x.get_element_type(); - if (out_type.is_integral()) { - auto float_x = make_shared(x, element::f32); - auto float_y = make_shared(y, element::f32); - return make_shared(make_shared(make_shared(float_x, float_y)), - out_type); + if (out_type.is_integral() && out_type.is_signed()) { + // when integer inputs have different signs remainder should be taken into account + // res = x / y; if x > 0 and y > 0 + // res = x / y - 1; if (x < 0 xor y < 0) and (x mod y != 0) + + auto zero_const = make_shared(out_type, Shape{}, 0); + auto minus_one_const = make_shared(out_type, Shape{}, -1); + + auto x_less_cond = make_shared(x, zero_const); + auto y_less_cond = make_shared(y, zero_const); + auto xor_cond = make_shared(x_less_cond, y_less_cond); + + auto div = make_shared(x, y, false); + auto mod_xy = make_shared(x, y); + auto cond_mod = make_shared(mod_xy, zero_const); + + auto cond = make_shared(cond_mod, xor_cond); + auto reminder = make_shared(cond, minus_one_const, zero_const); + return make_shared(div, reminder); + } else if (out_type.is_integral() && !out_type.is_signed()) { + return make_shared(x, y); } else { return make_shared(make_shared(x, y)); } diff --git a/tests/layer_tests/common/utils/common_utils.py b/tests/layer_tests/common/utils/common_utils.py index fab02b88a66bdf..02483636fbb8fe 100644 --- a/tests/layer_tests/common/utils/common_utils.py +++ b/tests/layer_tests/common/utils/common_utils.py @@ -53,7 +53,7 @@ def generate_ir_python_api(coverage=False, **kwargs): out_dir = kwargs['output_dir'] + os.sep + kwargs['model_name'] + ".xml" - # TODO: Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc + # TODO: CVS-132151 Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc ov_model = convert_model(**kwargs) serialize(ov_model, out_dir) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Div.py b/tests/layer_tests/tensorflow_tests/test_tf_Div.py index 9d129e2280e476..cc14427f2ce920 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Div.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Div.py @@ -36,7 +36,7 @@ def create_div_net(self, input_shape, input_type): dict(input_shape=[10, 20], input_type=np.float32), dict(input_shape=[2, 3, 4], input_type=np.float32), pytest.param(dict(input_shape=[8, 5], input_type=np.int32), - marks=pytest.mark.xfail(reason='Ticket TBD - Divide inconsistent behavior on different systems')), + marks=pytest.mark.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')), dict(input_shape=[], input_type=np.float32), ] @@ -47,4 +47,4 @@ def test_div_basic(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend): self._test(*self.create_div_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, - use_new_frontend=use_new_frontend) \ No newline at end of file + use_new_frontend=use_new_frontend) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py b/tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py index b49468fa8c7319..2b9dc2b1d9b061 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py @@ -3,15 +3,19 @@ import numpy as np import pytest +import platform from common.tf_layer_test_class import CommonTFLayerTest -from common.utils.tf_utils import permute_nchw_to_nhwc +rng = np.random.default_rng() + +def list_arm_platforms(): + return ['arm', 'armv7l', 'aarch64', 'arm64', 'ARM64'] class TestFloorDiv(CommonTFLayerTest): def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_frontend): import tensorflow as tf - + self.dtype = dtype tf.compat.v1.reset_default_graph() # Create the graph and model @@ -19,7 +23,6 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f x = tf.compat.v1.placeholder(dtype, x_shape, 'Input') constant_value = np.array(-10).astype(dtype) y = tf.constant(constant_value) - x = tf.raw_ops.Abs(x=x) res = tf.raw_ops.FloorDiv(x=x, y=y) tf.compat.v1.global_variables_initializer() @@ -29,12 +32,28 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f return tf_net, ref_net + def _prepare_input(self, inputs_info): + tensor_name = list(inputs_info.keys())[0] + x_shape = inputs_info[tensor_name] + inputs_data = {} + if np.issubdtype(self.dtype, np.floating): + inputs_data[tensor_name] = rng.uniform(-5.0, 5.0, x_shape).astype(self.dtype) + elif np.issubdtype(self.dtype, np.signedinteger): + inputs_data[tensor_name] = rng.integers(-8, 8, x_shape).astype(self.dtype) + else: + inputs_data[tensor_name] = rng.integers(0, 8, x_shape).astype(self.dtype) + return inputs_data + # TODO: implement tests for 2 Consts + Add + test_data_1D = [ dict(x_shape=[], dtype=np.int32), dict(x_shape=[2], dtype=np.int64), dict(x_shape=[2, 4, 5], dtype=np.int32), + dict(x_shape=[2, 4, 5], dtype=np.uint32), + dict(x_shape=[2, 4, 5], dtype=np.uint64), + dict(x_shape=[], dtype=np.float32), dict(x_shape=[2], dtype=np.float64), dict(x_shape=[2, 4, 5], dtype=np.float32), @@ -45,7 +64,66 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f @pytest.mark.precommit_tf_fe def test_add_placeholder_const_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend): + if platform.system() == 'Linux' and platform.machine() in list_arm_platforms() and np.issubdtype(params['dtype'], np.signedinteger): + pytest.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems') + self._test(*self.create_add_placeholder_const_net(**params, ir_version=ir_version, use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend) + + +class TestFloorDivStaticInput(CommonTFLayerTest): + min = -100 + max = 200 + step = 1 + dtype = np.int32 + + def create_flordiv_tf_net(self, min, max, step, y, dtype, ir_version, use_new_frontend): + import tensorflow as tf + x = np.arange(min, max, step, dtype=dtype) + + self.min = min + self.max = max + self.step = step + self.dtype = dtype + + tf.compat.v1.reset_default_graph() + + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(dtype, x.shape, 'Input') + y = tf.constant(np.array(y).astype(dtype)) + res = tf.raw_ops.FloorDiv(x=x, y=y) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + def _prepare_input(self, inputs_dict): + for input in inputs_dict.keys(): + inputs_dict[input] = np.arange(self.min, self.max, self.step, dtype=self.dtype) + return inputs_dict + + test_inputs = [ + dict(min=-20, max=20, step=1, y=[10]), + dict(min=-20, max=20, step=1, y=[5]), + dict(min=-20, max=20, step=1, y=[6]), + dict(min=-20, max=20, step=1, y=[-5]), + dict(min=-20, max=20, step=1, y=[-6]), + dict(min=-1e5, max=1e5, step=100, y=[1e5]), + ] + @pytest.mark.parametrize("params", test_inputs) + @pytest.mark.parametrize("dtype", [np.int32, np.int64]) + @pytest.mark.nightly + @pytest.mark.precommit_tf_fe + @pytest.mark.xfail(condition=platform.system() == 'Linux' and platform.machine() in list_arm_platforms(), + reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems') + def test_floordiv(self, params, dtype, ie_device, precision, ir_version, temp_dir, + use_new_frontend): + self._test(*self.create_flordiv_tf_net(**params, dtype=dtype, ir_version=ir_version, + use_new_frontend=use_new_frontend), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend)