Skip to content

Commit

Permalink
[TF FE] Fix centernet and correct FloorDiv translator for signed inte…
Browse files Browse the repository at this point in the history
…ger type (openvinotoolkit#22684)

### Details:
- Centernet's topk operation returns large int32 values (greater than
1000000), even though they are integer `FloorDiv`/`Div(inp_1, inp_2) +
Floor` operation is performed in float16 and because of that it causes
accuracy problems.

- To solve this need to performs FloorDiv operation in integer with a
subgraph:
```
res = x / y; if x > 0 and y > 0
res = x / y - 1; if (x < 0 xor y < 0) and (x mod y != 0)
```
 - checked on separate bus: no degradations caused.

### Tickets:
 - CVS-130526

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
pavel-esir and rkazants authored Feb 14, 2024
1 parent b53fa91 commit 88b792e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 11 deletions.
29 changes: 24 additions & 5 deletions src/frontends/tensorflow_common/src/op/binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
Expand All @@ -26,9 +27,11 @@
#include "openvino/op/minimum.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/negative.hpp"
#include "openvino/op/not_equal.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/squared_difference.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
Expand All @@ -54,11 +57,27 @@ OutputVector translate_binary_op(const NodeContext& node,
OutputVector translate_floor_div_op(const NodeContext& node) {
auto floordiv_fn = [](const Output<Node>& x, const Output<Node>& y) -> shared_ptr<Node> {
auto out_type = x.get_element_type();
if (out_type.is_integral()) {
auto float_x = make_shared<v0::Convert>(x, element::f32);
auto float_y = make_shared<v0::Convert>(y, element::f32);
return make_shared<v0::Convert>(make_shared<v0::Floor>(make_shared<v1::Divide>(float_x, float_y)),
out_type);
if (out_type.is_integral() && out_type.is_signed()) {
// when integer inputs have different signs remainder should be taken into account
// res = x / y; if x > 0 and y > 0
// res = x / y - 1; if (x < 0 xor y < 0) and (x mod y != 0)

auto zero_const = make_shared<v0::Constant>(out_type, Shape{}, 0);
auto minus_one_const = make_shared<v0::Constant>(out_type, Shape{}, -1);

auto x_less_cond = make_shared<v1::Less>(x, zero_const);
auto y_less_cond = make_shared<v1::Less>(y, zero_const);
auto xor_cond = make_shared<v1::LogicalXor>(x_less_cond, y_less_cond);

auto div = make_shared<v1::Divide>(x, y, false);
auto mod_xy = make_shared<v1::Mod>(x, y);
auto cond_mod = make_shared<v1::NotEqual>(mod_xy, zero_const);

auto cond = make_shared<v1::LogicalAnd>(cond_mod, xor_cond);
auto reminder = make_shared<v1::Select>(cond, minus_one_const, zero_const);
return make_shared<v1::Add>(div, reminder);
} else if (out_type.is_integral() && !out_type.is_signed()) {
return make_shared<v1::Divide>(x, y);
} else {
return make_shared<v0::Floor>(make_shared<v1::Divide>(x, y));
}
Expand Down
2 changes: 1 addition & 1 deletion tests/layer_tests/common/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_ir_python_api(coverage=False, **kwargs):

out_dir = kwargs['output_dir'] + os.sep + kwargs['model_name'] + ".xml"

# TODO: Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc
# TODO: CVS-132151 Remove usage of legacy params from layer tests and switch to convert_model from tools.ovc
ov_model = convert_model(**kwargs)
serialize(ov_model, out_dir)

Expand Down
4 changes: 2 additions & 2 deletions tests/layer_tests/tensorflow_tests/test_tf_Div.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def create_div_net(self, input_shape, input_type):
dict(input_shape=[10, 20], input_type=np.float32),
dict(input_shape=[2, 3, 4], input_type=np.float32),
pytest.param(dict(input_shape=[8, 5], input_type=np.int32),
marks=pytest.mark.xfail(reason='Ticket TBD - Divide inconsistent behavior on different systems')),
marks=pytest.mark.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')),
dict(input_shape=[], input_type=np.float32),
]

Expand All @@ -47,4 +47,4 @@ def test_div_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend):
self._test(*self.create_div_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend)
use_new_frontend=use_new_frontend)
84 changes: 81 additions & 3 deletions tests/layer_tests/tensorflow_tests/test_tf_FloorDiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@

import numpy as np
import pytest
import platform

from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import permute_nchw_to_nhwc

rng = np.random.default_rng()

def list_arm_platforms():
return ['arm', 'armv7l', 'aarch64', 'arm64', 'ARM64']

class TestFloorDiv(CommonTFLayerTest):
def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_frontend):
import tensorflow as tf

self.dtype = dtype
tf.compat.v1.reset_default_graph()

# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(dtype, x_shape, 'Input')
constant_value = np.array(-10).astype(dtype)
y = tf.constant(constant_value)
x = tf.raw_ops.Abs(x=x)
res = tf.raw_ops.FloorDiv(x=x, y=y)

tf.compat.v1.global_variables_initializer()
Expand All @@ -29,12 +32,28 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f

return tf_net, ref_net

def _prepare_input(self, inputs_info):
tensor_name = list(inputs_info.keys())[0]
x_shape = inputs_info[tensor_name]
inputs_data = {}
if np.issubdtype(self.dtype, np.floating):
inputs_data[tensor_name] = rng.uniform(-5.0, 5.0, x_shape).astype(self.dtype)
elif np.issubdtype(self.dtype, np.signedinteger):
inputs_data[tensor_name] = rng.integers(-8, 8, x_shape).astype(self.dtype)
else:
inputs_data[tensor_name] = rng.integers(0, 8, x_shape).astype(self.dtype)
return inputs_data

# TODO: implement tests for 2 Consts + Add


test_data_1D = [
dict(x_shape=[], dtype=np.int32),
dict(x_shape=[2], dtype=np.int64),
dict(x_shape=[2, 4, 5], dtype=np.int32),
dict(x_shape=[2, 4, 5], dtype=np.uint32),
dict(x_shape=[2, 4, 5], dtype=np.uint64),

dict(x_shape=[], dtype=np.float32),
dict(x_shape=[2], dtype=np.float64),
dict(x_shape=[2, 4, 5], dtype=np.float32),
Expand All @@ -45,7 +64,66 @@ def create_add_placeholder_const_net(self, x_shape, dtype, ir_version, use_new_f
@pytest.mark.precommit_tf_fe
def test_add_placeholder_const_1D(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend):
if platform.system() == 'Linux' and platform.machine() in list_arm_platforms() and np.issubdtype(params['dtype'], np.signedinteger):
pytest.xfail(reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')

self._test(*self.create_add_placeholder_const_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend)


class TestFloorDivStaticInput(CommonTFLayerTest):
min = -100
max = 200
step = 1
dtype = np.int32

def create_flordiv_tf_net(self, min, max, step, y, dtype, ir_version, use_new_frontend):
import tensorflow as tf
x = np.arange(min, max, step, dtype=dtype)

self.min = min
self.max = max
self.step = step
self.dtype = dtype

tf.compat.v1.reset_default_graph()

with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(dtype, x.shape, 'Input')
y = tf.constant(np.array(y).astype(dtype))
res = tf.raw_ops.FloorDiv(x=x, y=y)

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

ref_net = None

return tf_net, ref_net

def _prepare_input(self, inputs_dict):
for input in inputs_dict.keys():
inputs_dict[input] = np.arange(self.min, self.max, self.step, dtype=self.dtype)
return inputs_dict

test_inputs = [
dict(min=-20, max=20, step=1, y=[10]),
dict(min=-20, max=20, step=1, y=[5]),
dict(min=-20, max=20, step=1, y=[6]),
dict(min=-20, max=20, step=1, y=[-5]),
dict(min=-20, max=20, step=1, y=[-6]),
dict(min=-1e5, max=1e5, step=100, y=[1e5]),
]
@pytest.mark.parametrize("params", test_inputs)
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
@pytest.mark.nightly
@pytest.mark.precommit_tf_fe
@pytest.mark.xfail(condition=platform.system() == 'Linux' and platform.machine() in list_arm_platforms(),
reason='Ticket CVS-132377 - Divide inconsistent behavior on different systems')
def test_floordiv(self, params, dtype, ie_device, precision, ir_version, temp_dir,
use_new_frontend):
self._test(*self.create_flordiv_tf_net(**params, dtype=dtype, ir_version=ir_version,
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend)

0 comments on commit 88b792e

Please sign in to comment.