Skip to content

Commit

Permalink
[TF FE] Supported Complex tensor for AddN op (openvinotoolkit#23614)
Browse files Browse the repository at this point in the history
### Details:
- extended addN.cpp file to support complex tensor
- extended test_tf_addN.py with test for complex input

### Tickets:
 - openvinotoolkit#22944

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
2 people authored and alvoron committed Apr 29, 2024
1 parent 9eb07ed commit 456fbc3
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
19 changes: 17 additions & 2 deletions src/frontends/tensorflow_common/src/op/addN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"

using namespace std;
Expand All @@ -15,10 +16,24 @@ namespace tensorflow {
namespace op {

OutputVector translate_add_n_op(const NodeContext& node) {
default_op_checks(node, 1, {"AddN", "ADD_N"});
default_op_checks(node, 1, {"AddN", "ADD_N"}, true);
int num_size = static_cast<int>(node.get_input_size());
auto result = node.get_input(0);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(result.get_node_shared_ptr());
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
result = complex_type_mark->input_value(0);

// converting all the inputs to complex type (simulating complex type) and adding them
for (int ind = 1; ind < num_size; ++ind) {
auto complex_type_mark_ind = as_type_ptr<ComplexTypeMark>(node.get_input(ind).get_node_shared_ptr());
result = make_shared<v1::Add>(result, complex_type_mark_ind->input_value(0));
}
auto complex_add_n = make_shared<ComplexTypeMark>(result, complex_part_type);
set_node_name(node.get_name(), result.get_node_shared_ptr());
return {complex_add_n->output(0)};
}

Output<Node> result = node.get_input(0);
for (int ind = 1; ind < num_size; ++ind) {
result = make_shared<v1::Add>(result, node.get_input(ind));
}
Expand Down
43 changes: 40 additions & 3 deletions tests/layer_tests/tensorflow_tests/test_tf_AddN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import numpy as np
import pytest

import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

import logging

# Testing operation AddN
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/AddN
Expand All @@ -30,7 +29,6 @@ def create_addn_placeholder_const_net(self, input_shapes, ir_version, use_legacy

if len(input_shapes) == 1 and not use_legacy_frontend:
pytest.xfail(reason="96687")
import tensorflow as tf

tf.compat.v1.reset_default_graph()

Expand Down Expand Up @@ -66,3 +64,42 @@ def test_addn_placeholder_const(self, params, ie_device, precision, ir_version,
use_legacy_frontend=use_legacy_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

class TestComplexAddN(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
rng = np.random.default_rng()
inputs_data = {}
for idx, key in enumerate(inputs_info):
assert key in inputs_info
inputs_data[key] = 4 * rng.random(inputs_info[key]).astype(np.float32) - 2
return inputs_data

def create_complex_addn_net(self, input_shapes):
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
complex_tensors = []
for idx, input_shape in enumerate(input_shapes):
real = tf.compat.v1.placeholder(np.float32, input_shape, f'param_real_{idx+1}')
imag = tf.compat.v1.placeholder(np.float32, input_shape, f'param_imag_{idx+1}')
complex_tensors.append(tf.raw_ops.Complex(real=real, imag=imag))
addn = tf.raw_ops.AddN(inputs=complex_tensors, name='complex_AddN')
real = tf.raw_ops.Real(input=addn)
imag = tf.raw_ops.Imag(input=addn)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
return tf_net, None

test_data = [
dict(input_shapes=[[1], [1]]),
dict(input_shapes=[[2, 3], [2, 3], [2, 3], [2, 3]]),
dict(input_shapes=[[3, 4, 5], [3, 4, 5], [3, 4, 5], [3, 4, 5], [3, 4, 5], [3, 4, 5]]),
]

@pytest.mark.parametrize("params", test_data)
@pytest.mark.precommit
@pytest.mark.nightly
def test_complex_addn(self, params, ie_device, precision, ir_version, temp_dir,
use_legacy_frontend):
self._test(*self.create_complex_addn_net(**params),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 456fbc3

Please sign in to comment.