-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TF FE] Support Div operation for TensorFlow (#21730)
* Support Div operation for TensorFlow * Update test_tf_Div.py * Update div.cpp * Update op_table and common_op_table * update translate_div_op * print inputs * update div.cpp * set m_pythondiv to false * update div.cpp * update div.cpp * Update tests/layer_tests/tensorflow_tests/test_tf_Div.py * Update tests/layer_tests/tensorflow_tests/test_tf_Div.py --------- Co-authored-by: Roman Kazantsev <[email protected]>
- Loading branch information
Showing
4 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "common_op_table.hpp" | ||
#include "openvino/op/divide.hpp" | ||
|
||
using namespace std; | ||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace tensorflow { | ||
namespace op { | ||
OutputVector translate_div_op(const NodeContext& node) { | ||
default_op_checks(node, 2, {"Div"}); | ||
auto x = node.get_input(0); | ||
auto y = node.get_input(1); | ||
// Check if the element type is a signed integer | ||
if (x.get_element_type().is_integral_number() && x.get_element_type().is_signed()) { | ||
// prepare auxiliary zero constants of the same type as the inputs | ||
auto const_zero = create_same_type_const_scalar<int32_t>(x, 0); | ||
|
||
// compute the modulus of x and y | ||
auto mod_result = make_shared<v1::Mod>(x, y); | ||
// compute a mask to get positions of non-zero values of mod result | ||
auto mod_non_zero = make_shared<v1::NotEqual>(mod_result, const_zero); | ||
|
||
// compute the division of x and y | ||
auto divide = make_shared<v1::Divide>(x, y); | ||
// compute a mask to get positions of negative values of division result | ||
auto div_is_neg = make_shared<v1::Less>(divide, const_zero); | ||
|
||
// compute a boolean mask of elements for non-zero values of Mod result and negative values of Divide result | ||
auto mask = make_shared<v1::LogicalAnd>(mod_non_zero, div_is_neg); | ||
|
||
// prepare auxiliary one constants of the same type as the inputs | ||
auto const_one = create_same_type_const_scalar<int32_t>(x, 1); | ||
// add 1 to the divide result | ||
auto add_result = make_shared<v1::Add>(divide, const_one); | ||
|
||
// select division results based on the mask | ||
// - perform floor division for non-negative values. | ||
// - round negative values to the nearest zero. | ||
auto div = make_shared<v1::Select>(mask, add_result, divide); | ||
set_node_name(node.get_name(), div); | ||
return div->outputs(); | ||
} else { | ||
// for other cases (non-signed-integer types) | ||
// compute regular division of x and y | ||
auto div = make_shared<v1::Divide>(x, y); | ||
set_node_name(node.get_name(), div); | ||
return div->outputs(); | ||
} | ||
} | ||
} // namespace op | ||
} // namespace tensorflow | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
from common.tf_layer_test_class import CommonTFLayerTest | ||
|
||
|
||
class TestDiv(CommonTFLayerTest): | ||
def _prepare_input(self, inputs_info): | ||
assert 'x' in inputs_info | ||
assert 'y' in inputs_info | ||
x_shape = inputs_info['x'] | ||
y_shape = inputs_info['y'] | ||
inputs_data = {} | ||
# generate x and y | ||
inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type) | ||
inputs_data['y'] = np.random.randint(1, 10, y_shape)*np.random.choice([-1,1], y_shape) | ||
return inputs_data | ||
|
||
def create_div_net(self, input_shape, input_type): | ||
self.input_type = input_type | ||
tf.compat.v1.reset_default_graph() | ||
# Create the graph and model | ||
with tf.compat.v1.Session() as sess: | ||
x = tf.compat.v1.placeholder(input_type, input_shape, 'x') | ||
y = tf.compat.v1.placeholder(input_type, input_shape, 'y') | ||
tf.raw_ops.Div(x=x, y=y) | ||
tf.compat.v1.global_variables_initializer() | ||
tf_net = sess.graph_def | ||
|
||
return tf_net, None | ||
|
||
test_data_basic = [ | ||
dict(input_shape=[10, 20], input_type=np.float32), | ||
dict(input_shape=[2, 3, 4], input_type=np.float32), | ||
pytest.param(dict(input_shape=[8, 5], input_type=np.int32), | ||
marks=pytest.mark.xfail(reason='Ticket TBD - Divide inconsistent behavior on different systems')), | ||
dict(input_shape=[], input_type=np.float32), | ||
] | ||
|
||
@pytest.mark.parametrize("params", test_data_basic) | ||
@pytest.mark.precommit_tf_fe | ||
@pytest.mark.nightly | ||
def test_div_basic(self, params, ie_device, precision, ir_version, temp_dir, | ||
use_new_frontend, use_old_api): | ||
self._test(*self.create_div_net(**params), | ||
ie_device, precision, ir_version, temp_dir=temp_dir, | ||
use_new_frontend=use_new_frontend, use_old_api=use_old_api) |