Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support Div operation for TensorFlow #21730

Merged
merged 23 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
09f86a6
Support Div operation for TensorFlow
sami0i Dec 18, 2023
ec3191d
Update test_tf_Div.py
sami0i Dec 22, 2023
537abc5
Update div.cpp
sami0i Dec 22, 2023
131ebdf
Update op_table and common_op_table
sami0i Dec 22, 2023
f20cae7
update translate_div_op
sami0i Dec 25, 2023
ac23640
Merge branch 'master' into div_21464
rkazants Dec 26, 2023
778c03f
print inputs
sami0i Dec 28, 2023
35e71eb
update div.cpp
sami0i Dec 28, 2023
1d79b1a
Merge branch 'div_21464' of github.com:sami0i/openvino into div_21464
sami0i Dec 28, 2023
5ff2e0d
set m_pythondiv to false
sami0i Dec 29, 2023
fdf55c7
update div.cpp
sami0i Jan 4, 2024
c697ea1
Merge branch 'master' into div_21464
rkazants Jan 5, 2024
9fd4de7
update div.cpp
sami0i Jan 5, 2024
cded7d1
Merge branch 'div_21464' of github.com:sami0i/openvino into div_21464
sami0i Jan 5, 2024
4097922
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
b5db386
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
feebad5
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
022e62c
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
6d82e8e
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
587e030
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
eeb02d3
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
89eb158
Update tests/layer_tests/tensorflow_tests/test_tf_Div.py
rkazants Jan 6, 2024
f2780f6
Merge branch 'master' into div_21464
rkazants Jan 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"BitwiseAnd", CreatorFunction(translate_binary_op<opset13::BitwiseAnd>)},
{"BitwiseOr", CreatorFunction(translate_binary_op<opset13::BitwiseOr>)},
{"BitwiseXor", CreatorFunction(translate_binary_op<opset13::BitwiseXor>)},
{"Div", CreatorFunction(translate_div_op)},
{"Equal", CreatorFunction(translate_binary_op<opset8::Equal>)},
{"FloorMod", CreatorFunction(translate_binary_op<opset8::FloorMod>)},
{"Greater", CreatorFunction(translate_binary_op<opset8::Greater>)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ OP_CONVERTER(translate_crop_and_resize_op);
OP_CONVERTER(translate_depth_to_space_op);
OP_CONVERTER(translate_depthwise_conv_2d_native_op);
OP_CONVERTER(translate_div_no_nan_op);
OP_CONVERTER(translate_div_op);
OP_CONVERTER(translate_mul_op);
OP_CONVERTER(translate_dynamic_partition_op);
OP_CONVERTER(translate_einsum_op);
Expand Down
59 changes: 59 additions & 0 deletions src/frontends/tensorflow_common/src/op/div.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_op_table.hpp"
#include "openvino/op/divide.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_div_op(const NodeContext& node) {
rkazants marked this conversation as resolved.
Show resolved Hide resolved
default_op_checks(node, 2, {"Div"});
auto x = node.get_input(0);
auto y = node.get_input(1);
// Check if the element type is a signed integer
if (x.get_element_type().is_integral_number() && x.get_element_type().is_signed()) {
// prepare auxiliary zero constants of the same type as the inputs
auto const_zero = create_same_type_const_scalar<int32_t>(x, 0);

// compute the modulus of x and y
auto mod_result = make_shared<v1::Mod>(x, y);
// compute a mask to get positions of non-zero values of mod result
auto mod_non_zero = make_shared<v1::NotEqual>(mod_result, const_zero);

// compute the division of x and y
auto divide = make_shared<v1::Divide>(x, y);
// compute a mask to get positions of negative values of division result
auto div_is_neg = make_shared<v1::Less>(divide, const_zero);

// compute a boolean mask of elements for non-zero values of Mod result and negative values of Divide result
auto mask = make_shared<v1::LogicalAnd>(mod_non_zero, div_is_neg);

// prepare auxiliary one constants of the same type as the inputs
auto const_one = create_same_type_const_scalar<int32_t>(x, 1);
// add 1 to the divide result
auto add_result = make_shared<v1::Add>(divide, const_one);

// select division results based on the mask
// - perform floor division for non-negative values.
// - round negative values to the nearest zero.
auto div = make_shared<v1::Select>(mask, add_result, divide);
set_node_name(node.get_name(), div);
return div->outputs();
} else {
// for other cases (non-signed-integer types)
// compute regular division of x and y
auto div = make_shared<v1::Divide>(x, y);
set_node_name(node.get_name(), div);
return div->outputs();
}
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
52 changes: 52 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Div.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
rkazants marked this conversation as resolved.
Show resolved Hide resolved
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


class TestDiv(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'x' in inputs_info
assert 'y' in inputs_info
x_shape = inputs_info['x']
y_shape = inputs_info['y']
inputs_data = {}
# generate x and y
inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type)
inputs_data['y'] = np.random.randint(1, 10, y_shape)*np.random.choice([-1,1], y_shape)
rkazants marked this conversation as resolved.
Show resolved Hide resolved
print("input x: \n{}".format(inputs_data['x']))
print("input y: \n{}".format(inputs_data['y']))
rkazants marked this conversation as resolved.
Show resolved Hide resolved
return inputs_data
rkazants marked this conversation as resolved.
Show resolved Hide resolved

def create_div_net(self, input_shape, input_type):
self.input_type = input_type
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
y = tf.compat.v1.placeholder(input_type, input_shape, 'y')
tf.raw_ops.Div(x=x, y=y)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

return tf_net, None

test_data_basic = [
dict(input_shape=[10, 20], input_type=np.float32),
rkazants marked this conversation as resolved.
Show resolved Hide resolved
dict(input_shape=[2, 3, 4], input_type=np.float32),
pytest.param(dict(input_shape=[8, 5], input_type=np.int32),
marks=pytest.mark.xfail(condition=platform.system() == 'Windows', reason='Ticket - TBD'))
rkazants marked this conversation as resolved.
Show resolved Hide resolved
dict(input_shape=[], input_type=np.float32),
]

@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_div_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_div_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
Loading