From 891a84e97b9cad9abd5c3fa047061dbfd7e2b2c1 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 3 Nov 2023 14:33:51 +0400 Subject: [PATCH 1/8] [TF FE] Support complex tensors Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/op_table.cpp | 9 +++ .../include/common_op_table.hpp | 4 ++ .../include/helper_ops/complex_type_mark.hpp | 44 +++++++++++++ .../tensorflow_common/include/utils.hpp | 7 ++- .../tensorflow_common/src/op/complex.cpp | 61 +++++++++++++++++++ .../tensorflow_common/src/op/irfft.cpp | 55 +++++++++++++++++ .../tensorflow_common/src/op/real_imag.cpp | 51 ++++++++++++++++ .../tensorflow_common/src/op/rfft.cpp | 48 +++++++++++++++ src/frontends/tensorflow_common/src/utils.cpp | 39 +++++++++++- tools/mo/openvino/tools/mo/convert_impl.py | 3 +- 10 files changed, 317 insertions(+), 4 deletions(-) create mode 100644 src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp create mode 100644 src/frontends/tensorflow_common/src/op/complex.cpp create mode 100644 src/frontends/tensorflow_common/src/op/irfft.cpp create mode 100644 src/frontends/tensorflow_common/src/op/real_imag.cpp create mode 100644 src/frontends/tensorflow_common/src/op/rfft.cpp diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index e5f25dad31270a..e7bfdff9c72fbc 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -154,6 +154,7 @@ const std::map get_supported_ops() { {"CheckNumerics", CreatorFunction(translate_identity_op)}, {"CheckNumericsV2", CreatorFunction(translate_identity_op)}, {"ClipByValue", CreatorFunction(translate_clip_by_value_op)}, + {"Complex", CreatorFunction(translate_complex_op)}, {"Concat", CreatorFunction(translate_concat_op)}, {"ConcatV2", CreatorFunction(translate_concat_op)}, {"Const", CreatorFunction(translate_const_op)}, @@ -196,7 +197,11 @@ const std::map get_supported_ops() { {"IdentityN", CreatorFunction(translate_identity_n_op)}, {"Inv", CreatorFunction(translate_inv_op)}, {"If", CreatorFunction(translate_if_op)}, + {"Imag", CreatorFunction(translate_real_imag_op)}, {"input_arg", CreatorFunction(translate_input_arg_op)}, + {"IRFFT", CreatorFunction(translate_irfft_op)}, + {"IRFFT2D", CreatorFunction(translate_irfft_op)}, + {"IRFFT3D", CreatorFunction(translate_irfft_op)}, {"Iterator", CreatorFunction(translate_iterator_op)}, {"IteratorGetNext", CreatorFunction(translate_iterator_get_next_op)}, {"IteratorV2", CreatorFunction(translate_iterator_op)}, @@ -248,6 +253,7 @@ const std::map get_supported_ops() { {"Rank", CreatorFunction(translate_rank_op)}, {"RandomUniform", CreatorFunction(translate_random_uniform_op)}, {"RandomUniformInt", CreatorFunction(translate_random_uniform_int_op)}, + {"Real", CreatorFunction(translate_real_imag_op)}, {"Reciprocal", CreatorFunction(translate_reciprocal_op)}, {"Relu6", CreatorFunction(translate_relu_6_op)}, {"Reshape", CreatorFunction(translate_reshape_op)}, @@ -257,6 +263,9 @@ const std::map get_supported_ops() { {"ResizeBilinear", CreatorFunction(translate_interpolate_op)}, {"ResizeNearestNeighbor", CreatorFunction(translate_interpolate_op)}, {"ResourceGather", CreatorFunction(translate_resource_gather_op)}, + {"RFFT", CreatorFunction(translate_rfft_op)}, + {"RFFT2D", CreatorFunction(translate_rfft_op)}, + {"RFFT3D", CreatorFunction(translate_rfft_op)}, {"Roll", CreatorFunction(translate_roll_op)}, {"Round", CreatorFunction(translate_round_op)}, {"Rsqrt", CreatorFunction(translate_rsqrt_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 6befa470761a45..5d05e579ffe93d 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -46,6 +46,7 @@ OP_CONVERTER(translate_broadcast_to_op); OP_CONVERTER(translate_bucketize_op); OP_CONVERTER(translate_cast_op); OP_CONVERTER(translate_clip_by_value_op); +OP_CONVERTER(translate_complex_op); OP_CONVERTER(translate_concat_op); OP_CONVERTER(translate_const_op); OP_CONVERTER(translate_conv_2d_op); @@ -80,6 +81,7 @@ OP_CONVERTER(translate_inv_op); OP_CONVERTER(translate_invert_permutation_op); OP_CONVERTER(translate_output_arg_op); OP_CONVERTER(translate_interpolate_op); +OP_CONVERTER(translate_irfft_op); OP_CONVERTER(translate_is_finite_op); OP_CONVERTER(translate_is_inf_op); OP_CONVERTER(translate_is_nan_op); @@ -109,6 +111,7 @@ OP_CONVERTER(translate_range_op); OP_CONVERTER(translate_rank_op); OP_CONVERTER(translate_random_uniform_op); OP_CONVERTER(translate_random_uniform_int_op); +OP_CONVERTER(translate_real_imag_op); OP_CONVERTER(translate_relu_6_op); OP_CONVERTER(translate_reciprocal_op); OP_CONVERTER(translate_reshape_op); @@ -116,6 +119,7 @@ OP_CONVERTER(translate_resource_gather_op); OP_CONVERTER(translate_reverse_op); OP_CONVERTER(translate_reverse_v2_op); OP_CONVERTER(translate_reverse_sequence_op); +OP_CONVERTER(translate_rfft_op); OP_CONVERTER(translate_roll_op); OP_CONVERTER(translate_round_op); OP_CONVERTER(translate_rsqrt_op); diff --git a/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp b/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp new file mode 100644 index 00000000000000..61191c1f951566 --- /dev/null +++ b/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/type/element_type.hpp" +#include "openvino/op/util/framework_node.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { + +class ComplexTypeMark : public ov::op::util::FrameworkNode { +public: + OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode); + + ComplexTypeMark(const ov::Output& input, const ov::element::Type& complex_part_type) + : ov::op::util::FrameworkNode(ov::OutputVector{input}, 1), + m_complex_part_type(complex_part_type) { + validate_and_infer_types(); + } + + void validate_and_infer_types() override { + set_output_type(0, ov::element::dynamic, PartialShape::dynamic()); + } + + std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override { + auto complex_type_mark = std::make_shared(inputs[0], m_complex_part_type); + complex_type_mark->set_attrs(get_attrs()); + return complex_type_mark; + } + + ov::element::Type get_complex_part_type() const { + return m_complex_part_type; + } + +private: + ov::element::Type m_complex_part_type; +}; + +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/include/utils.hpp b/src/frontends/tensorflow_common/include/utils.hpp index acca76aaab8dcc..1fa5d0083fde55 100644 --- a/src/frontends/tensorflow_common/include/utils.hpp +++ b/src/frontends/tensorflow_common/include/utils.hpp @@ -88,7 +88,10 @@ void fill_explicit_pads_vectors(const NodeContext& node, ov::CoordinateDiff& pads_begin, ov::CoordinateDiff& pads_end); -void default_op_checks(const NodeContext& node, size_t min_input_size, const std::vector& supported_ops); +void default_op_checks(const NodeContext& node, + size_t min_input_size, + const std::vector& supported_ops, + bool supported_complex = false); ov::Output get_elements_number_1d(const Output& output, ov::element::Type output_type, @@ -155,6 +158,8 @@ ov::Output get_data_slice(const ov::Output& data, const int64_t& stop, const int64_t& step); +ov::Output compute_broadcast_args(const ov::Output& shape1, const ov::Output& shape2); + } // namespace tensorflow } // namespace frontend } // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/complex.cpp b/src/frontends/tensorflow_common/src/op/complex.cpp new file mode 100644 index 00000000000000..db686e928d7785 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/complex.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_complex_op(const NodeContext& node) { + default_op_checks(node, 2, {"Complex"}, true); + auto real = node.get_input(0); + auto imag = node.get_input(1); + auto tout = node.get_attribute("Tout", "DT_COMPLEX64"); + element::Type complex_part_type = (tout == "DT_COMPLEX64" ? element::f32 : element::f64); + + // compute target shape to which real and imag parts must be broadcasted + // and broadcast them + auto real_shape = make_shared(real, element::i32); + auto imag_shape = make_shared(imag, element::i32); + auto target_shape = compute_broadcast_args(real_shape, imag_shape); + real = make_shared(real, target_shape); + imag = make_shared(imag, target_shape); + + // expand real and imaginary parts with one dimension in the end for further concatenation + // this way, complex tensor with real and imag of shapes [N1, N2, ..., Nk] will be represented as floating-point + // tensor of shape [N1, N2, ..., Nk, 2] + auto real_rank = compute_subgraph_scalar_rank(real, element::i32, false); + real = make_shared(real, real_rank); + imag = make_shared(imag, real_rank); + + // concatenate real and imaginary parts to have a complex tensor represented as a floating-point tensor of shape + // [N1, N2, ..., Nk, 2] + auto complex_tensor = make_shared(OutputVector{real, imag}, -1)->output(0); + complex_tensor = make_shared(complex_tensor, complex_part_type); + + // set node name and tensor + set_node_name(node.get_name(), complex_tensor.get_node_shared_ptr()); + + // create complex type mark operation for upcoming operations in a graph + auto complex_type_mark = make_shared(complex_tensor, complex_part_type); + return complex_type_mark->outputs(); +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/irfft.cpp b/src/frontends/tensorflow_common/src/op/irfft.cpp new file mode 100644 index 00000000000000..858bb4961a1b9e --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/irfft.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/irdft.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/subtract.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_irfft_op(const NodeContext& node) { + default_op_checks(node, 2, {"IRFFT", "IRFFT2D", "IRFFT3D"}, true); + auto op_type = node.get_op_type(); + auto input = node.get_input(0); + auto fft_length = node.get_input(1); + auto treal = node.get_attribute("Treal", element::f32); + + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + complex_type_mark, + "[TensorFlow Frontend] internal error: ComplexTypeMark is not created before " + op_type + " operation."); + + // compute axes along which to compute inverse RFFT + auto data = complex_type_mark->input_value(0); + auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); + auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_one = make_shared(element::i32, Shape{}, 1); + auto data_rank_minus_one = make_shared(data_rank, const_one); + auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + auto irdft = make_shared(complex_type_mark->input_value(0), axes, fft_length)->output(0); + + // no need to insert ComplexTypeMark because operation generates a floating-point tensor + irdft = make_shared(irdft, treal); + set_node_name(node.get_name(), irdft.get_node_shared_ptr()); + + return {irdft}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/real_imag.cpp b/src/frontends/tensorflow_common/src/op/real_imag.cpp new file mode 100644 index 00000000000000..a8b06a5113b65f --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/real_imag.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_real_imag_op(const NodeContext& node) { + default_op_checks(node, 1, {"Real", "Imag"}, true); + auto op_type = node.get_op_type(); + auto input = node.get_input(0); + auto tout = node.get_attribute("Tout", element::f32); + // Complex tensor is represented as a floating-point tensor of shape [N1, N2, ..., Nk, 2] + // where real part is placed in the slice by last dimension [..., 0] and + // imaginary part is placed by index [..., 1] + int32_t axis_value = (op_type == "Real") ? 0 : 1; + + // check that complex type mark is set at the input + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + complex_type_mark, + "[TensorFlow Frontend] internal error: ComplexTypeMark is not set at the input of " + op_type); + auto data = complex_type_mark->input_value(0); + + // gather the required slice corresponding to Real or Imaginary part + auto gather_index = make_shared(element::i32, Shape{}, axis_value); + auto gather_axis = make_shared(element::i32, Shape{1}, -1); + auto complex_part = make_shared(data, gather_index, gather_axis); + + set_node_name(node.get_name(), complex_part); + + return {complex_part}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/rfft.cpp b/src/frontends/tensorflow_common/src/op/rfft.cpp new file mode 100644 index 00000000000000..40beabc8e90cd5 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/rfft.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/core/any.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/rdft.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_rfft_op(const NodeContext& node) { + default_op_checks(node, 2, {"RFFT", "RFFT2D", "RFFT3D"}); + auto input = node.get_input(0); + auto fft_length = node.get_input(1); + auto tcomplex = node.get_attribute("Tcomplex", "DT_COMPLEX64"); + element::Type complex_part_type = (tcomplex == "DT_COMPLEX64" ? element::f32 : element::f64); + + // compute axes along which to compute inverse RFFT + auto input_rank = compute_subgraph_scalar_rank(input, element::i32, true); + auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_one = make_shared(element::i32, Shape{}, 1); + auto axes = make_shared(const_two, input_rank, const_one, element::i32); + + // compute real FFT and align its output type + auto rfft = make_shared(input, axes, fft_length)->output(0); + rfft = make_shared(rfft, complex_part_type); + set_node_name(node.get_name(), rfft.get_node_shared_ptr()); + + // insert ComplexTypeMark since RFFT generates output of complex type + auto complex_type_mark = make_shared(rfft, complex_part_type); + + return {complex_type_mark}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/utils.cpp b/src/frontends/tensorflow_common/src/utils.cpp index 5e65bb7dae2e1b..adf736d3b2cf84 100644 --- a/src/frontends/tensorflow_common/src/utils.cpp +++ b/src/frontends/tensorflow_common/src/utils.cpp @@ -7,6 +7,7 @@ #include #include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" #include "openvino/opsets/opset10.hpp" using namespace ov; @@ -242,7 +243,10 @@ OutputVector translate_convolution_op(const frontend::NodeContext& node, size_t return {conv}; } -void default_op_checks(const frontend::NodeContext& node, size_t min_input_size, const vector& supported_ops) { +void default_op_checks(const frontend::NodeContext& node, + size_t min_input_size, + const vector& supported_ops, + bool supported_complex) { auto op_type = node.get_op_type(); TENSORFLOW_OP_VALIDATION(node, find(supported_ops.begin(), supported_ops.end(), op_type) != supported_ops.end(), @@ -250,6 +254,21 @@ void default_op_checks(const frontend::NodeContext& node, size_t min_input_size, TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= min_input_size, op_type + " must have at least " + to_string(min_input_size) + " inputs."); + + // check if it supports complex type in case complex type input + bool has_input_complex_type = false; + auto input_size = static_cast(node.get_input_size()); + for (int input_ind = 0; input_ind < input_size; ++input_ind) { + auto node_input = node.get_input(input_ind); + if (as_type_ptr(node_input.get_node_shared_ptr())) { + has_input_complex_type = true; + break; + } + } + TENSORFLOW_OP_VALIDATION( + node, + !has_input_complex_type || supported_complex, + "[TensorFlow Frontend] internal error: translator for " + op_type + " does not support input complex type"); } bool is_conditional_edge(const string& input_tensor_name) { @@ -356,6 +375,24 @@ Output get_data_slice(const Output& data, const int64_t& start, cons return make_shared(data, start_const, stop_const, step_const)->output(0); } +Output compute_broadcast_args(const Output& shape1, const Output& shape2) { + // compute a number of shape elements to append for broadcasting + auto size0 = make_shared(shape1); + auto size1 = make_shared(shape2); + auto max_size = make_shared(size0, size1); + auto diff1 = make_shared(max_size, size0); + auto diff2 = make_shared(max_size, size1); + + // pad the shortest shape value with minus ones + // to take dynamic shapes into account + auto const_zero = create_same_type_const(diff1, std::vector{0}, Shape{1}); + auto const_one = create_same_type_const_scalar(shape1, 1); + auto padded_s0 = make_shared(shape1, diff1, const_zero, const_one, ov::op::PadMode::CONSTANT); + auto padded_s1 = make_shared(shape2, diff2, const_zero, const_one, ov::op::PadMode::CONSTANT); + + auto broadcasted_shape = make_shared(padded_s0, padded_s1); + return broadcasted_shape->output(0); +} } // namespace tensorflow } // namespace frontend } // namespace ov diff --git a/tools/mo/openvino/tools/mo/convert_impl.py b/tools/mo/openvino/tools/mo/convert_impl.py index 9d683f4b6ac977..6491f54d5acded 100644 --- a/tools/mo/openvino/tools/mo/convert_impl.py +++ b/tools/mo/openvino/tools/mo/convert_impl.py @@ -316,8 +316,7 @@ def update_fallback_with_conversion_error(use_new_frontend: bool, is_tf: bool, e "LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge", # corresponds to operations with complex tensors "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D", - "RFFT", "RFFT2D", "RFFT3D", "IRFFT", "IRFFT2D", "IRFFT3D", - "Complex", "ComplexAbs", "Real", "Imag", + "ComplexAbs", ] if len(conversion_error_match) < 1 or len(conversion_error_match[0]) != 4: # no match for the fallback by unsupported operation From 659476c4cfa34b16d44cf7592dfc257e11a2bdfe Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 3 Nov 2023 14:47:17 +0400 Subject: [PATCH 2/8] Align output type for Real and Imag operations Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow_common/src/op/real_imag.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/real_imag.cpp b/src/frontends/tensorflow_common/src/op/real_imag.cpp index a8b06a5113b65f..7b237ac4db443b 100644 --- a/src/frontends/tensorflow_common/src/op/real_imag.cpp +++ b/src/frontends/tensorflow_common/src/op/real_imag.cpp @@ -5,6 +5,7 @@ #include "common_op_table.hpp" #include "helper_ops/complex_type_mark.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/gather.hpp" #include "utils.hpp" @@ -38,9 +39,12 @@ OutputVector translate_real_imag_op(const NodeContext& node) { // gather the required slice corresponding to Real or Imaginary part auto gather_index = make_shared(element::i32, Shape{}, axis_value); auto gather_axis = make_shared(element::i32, Shape{1}, -1); - auto complex_part = make_shared(data, gather_index, gather_axis); + auto complex_part = make_shared(data, gather_index, gather_axis)->output(0); - set_node_name(node.get_name(), complex_part); + // align output type required by tout attribute + complex_part = make_shared(complex_part, tout); + + set_node_name(node.get_name(), complex_part.get_node_shared_ptr()); return {complex_part}; } From 005554ea6bff62fb63d18681f2809117fec5efbe Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 3 Nov 2023 16:04:58 +0400 Subject: [PATCH 3/8] Update decoding complex types --- src/frontends/tensorflow/src/decoder_proto.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/frontends/tensorflow/src/decoder_proto.cpp b/src/frontends/tensorflow/src/decoder_proto.cpp index 9e0a53efb6d09f..573e850850ce16 100644 --- a/src/frontends/tensorflow/src/decoder_proto.cpp +++ b/src/frontends/tensorflow/src/decoder_proto.cpp @@ -113,10 +113,14 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const { case ::tensorflow::AttrValue::ValueCase::kType: { auto atype = attrs[0].type(); - if (atype != ::tensorflow::DT_STRING) { - return get_ov_type(attrs[0].type()); - } else { + if (atype == ::tensorflow::DT_STRING) { return ov::Any("DT_STRING"); + } else if (atype == ::tensorflow::DT_COMPLEX64) { + return ov::Any("DT_COMPLEX64"); + } else if (atype == ::tensorflow::DT_COMPLEX128) { + return ov::Any("DT_COMPLEX128"); + } else { + return get_ov_type(atype); } } From 81d249de269504409bacc04b7587cb21f230df0d Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 3 Nov 2023 18:20:45 +0400 Subject: [PATCH 4/8] Add support for ComplexAbs, FFT and IFFT operations Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/op_table.cpp | 7 +++ .../include/common_op_table.hpp | 3 + .../tensorflow_common/src/op/complex_abs.cpp | 61 +++++++++++++++++++ .../tensorflow_common/src/op/fft.cpp | 55 +++++++++++++++++ .../tensorflow_common/src/op/ifft.cpp | 55 +++++++++++++++++ .../tensorflow_common/src/op/roll.cpp | 45 ++++++++++++-- tools/mo/openvino/tools/mo/convert_impl.py | 3 - 7 files changed, 220 insertions(+), 9 deletions(-) create mode 100644 src/frontends/tensorflow_common/src/op/complex_abs.cpp create mode 100644 src/frontends/tensorflow_common/src/op/fft.cpp create mode 100644 src/frontends/tensorflow_common/src/op/ifft.cpp diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index e7bfdff9c72fbc..176525fa7ef643 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -155,6 +155,7 @@ const std::map get_supported_ops() { {"CheckNumericsV2", CreatorFunction(translate_identity_op)}, {"ClipByValue", CreatorFunction(translate_clip_by_value_op)}, {"Complex", CreatorFunction(translate_complex_op)}, + {"ComplexAbs", CreatorFunction(translate_complex_abs_op)}, {"Concat", CreatorFunction(translate_concat_op)}, {"ConcatV2", CreatorFunction(translate_concat_op)}, {"Const", CreatorFunction(translate_const_op)}, @@ -179,6 +180,9 @@ const std::map get_supported_ops() { {"FakeQuantWithMinMaxVars", CreatorFunction(translate_fake_quant_op)}, {"FakeQuantWithMinMaxVarsPerChannel", CreatorFunction(translate_fake_quant_op)}, {"FakeQuantWithMinMaxArgs", CreatorFunction(translate_fake_quant_with_min_max_args)}, + {"FFT", CreatorFunction(translate_fft_op)}, + {"FFT2D", CreatorFunction(translate_fft_op)}, + {"FFT3D", CreatorFunction(translate_fft_op)}, {"FIFOQueue", CreatorFunction(translate_fifo_queue_op)}, {"FIFOQueueV2", CreatorFunction(translate_fifo_queue_op)}, {"Fill", CreatorFunction(translate_fill_op)}, @@ -197,6 +201,9 @@ const std::map get_supported_ops() { {"IdentityN", CreatorFunction(translate_identity_n_op)}, {"Inv", CreatorFunction(translate_inv_op)}, {"If", CreatorFunction(translate_if_op)}, + {"IFFT", CreatorFunction(translate_ifft_op)}, + {"IFFT2D", CreatorFunction(translate_ifft_op)}, + {"IFFT3D", CreatorFunction(translate_ifft_op)}, {"Imag", CreatorFunction(translate_real_imag_op)}, {"input_arg", CreatorFunction(translate_input_arg_op)}, {"IRFFT", CreatorFunction(translate_irfft_op)}, diff --git a/src/frontends/tensorflow_common/include/common_op_table.hpp b/src/frontends/tensorflow_common/include/common_op_table.hpp index 5d05e579ffe93d..6b5d83d4c2bb84 100644 --- a/src/frontends/tensorflow_common/include/common_op_table.hpp +++ b/src/frontends/tensorflow_common/include/common_op_table.hpp @@ -47,6 +47,7 @@ OP_CONVERTER(translate_bucketize_op); OP_CONVERTER(translate_cast_op); OP_CONVERTER(translate_clip_by_value_op); OP_CONVERTER(translate_complex_op); +OP_CONVERTER(translate_complex_abs_op); OP_CONVERTER(translate_concat_op); OP_CONVERTER(translate_const_op); OP_CONVERTER(translate_conv_2d_op); @@ -67,6 +68,7 @@ OP_CONVERTER(translate_expand_dims_op); OP_CONVERTER(translate_extract_image_patches_op); OP_CONVERTER(translate_fake_quant_op); OP_CONVERTER(translate_fake_quant_with_min_max_args); +OP_CONVERTER(translate_fft_op); OP_CONVERTER(translate_fill_op); OP_CONVERTER(translate_floor_div_op); OP_CONVERTER_NAMED(translate_fused_batch_norm_op); @@ -76,6 +78,7 @@ OP_CONVERTER(translate_gather_nd_op); OP_CONVERTER(translate_gather_tree_op); OP_CONVERTER(translate_identity_op); OP_CONVERTER(translate_identity_n_op); +OP_CONVERTER(translate_ifft_op); OP_CONVERTER(translate_input_arg_op); OP_CONVERTER(translate_inv_op); OP_CONVERTER(translate_invert_permutation_op); diff --git a/src/frontends/tensorflow_common/src/op/complex_abs.cpp b/src/frontends/tensorflow_common/src/op/complex_abs.cpp new file mode 100644 index 00000000000000..008bc369320ebe --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/complex_abs.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/power.hpp" +#include "openvino/op/reduce_sum.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_complex_abs_op(const NodeContext& node) { + default_op_checks(node, 1, {"ComplexAbs"}, true); + auto op_type = node.get_op_type(); + auto x = node.get_input(0); + auto tout = node.get_attribute("Tout", element::f32); + + // check that complex type mark is set to the input + auto complex_type_mark = as_type_ptr(x.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION(node, + complex_type_mark, + "[TensorFlow Frontend] internal error: ComplexTypeMark is not set to input of " + op_type); + auto complex_part_type = complex_type_mark->get_complex_part_type(); + // data is complex tensor representation in a form [N1, N2, ..., Nk, 2] + // where slice [N1, N2, ..., Nk, 0] contains real part of the complex tensor and + // slice [N1, N2, ..., Nk, 1] contains imaginary part of the complex tensor + auto data = complex_type_mark->input_value(0); + + // compute element-wise square for complex representation + auto const_two = make_shared(complex_part_type, Shape{}, 2); + auto squared_data = make_shared(data, const_two); + + // compute sum of squared real and imaginary parts + auto const_minus_one = make_shared(element::i32, Shape{}, -1); + auto complex_abs = make_shared(squared_data, const_minus_one, false)->output(0); + + // compute ComplexAbs by root-squared operation + auto const_half = make_shared(complex_part_type, Shape{}, 0.5f); + complex_abs = make_shared(complex_abs, const_half); + + // aling output type required by tout attribute + complex_abs = make_shared(complex_abs, tout); + + set_node_name(node.get_name(), complex_abs.get_node_shared_ptr()); + return {complex_abs}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/fft.cpp b/src/frontends/tensorflow_common/src/op/fft.cpp new file mode 100644 index 00000000000000..fac0ee7686fb1e --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/fft.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/dft.hpp" +#include "openvino/op/subtract.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_fft_op(const NodeContext& node) { + default_op_checks(node, 1, {"FFT", "FFT2D", "FFT3D"}, true); + auto op_type = node.get_op_type(); + auto input = node.get_input(0); + + // check that ComplexTypeMark is set + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + complex_type_mark, + "[TensorFlow Frontend] internal error: ComplexTypeMark is not set to input for " + op_type); + auto data = complex_type_mark->input_value(0); + auto complex_part_type = complex_type_mark->get_complex_part_type(); + + // compute axes along which to compute FFT + auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_one = make_shared(element::i32, Shape{}, 1); + auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); + // exclude the last dimension since it concatenated real and imaginary parts + auto data_rank_minus_one = make_shared(data_rank, const_one); + auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + + // compute FFT and align its output type + auto fft = make_shared(data, axes); + set_node_name(node.get_name(), fft); + + // insert ComplexTypeMark since FFT generates output of complex type + complex_type_mark = make_shared(fft, complex_part_type); + + return {complex_type_mark}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/ifft.cpp b/src/frontends/tensorflow_common/src/op/ifft.cpp new file mode 100644 index 00000000000000..1d6c525e225847 --- /dev/null +++ b/src/frontends/tensorflow_common/src/op/ifft.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/idft.hpp" +#include "openvino/op/subtract.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_ifft_op(const NodeContext& node) { + default_op_checks(node, 1, {"IFFT", "IFFT2D", "IFFT3D"}, true); + auto op_type = node.get_op_type(); + auto input = node.get_input(0); + + // check that ComplexTypeMark is set + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + TENSORFLOW_OP_VALIDATION( + node, + complex_type_mark, + "[TensorFlow Frontend] internal error: ComplexTypeMark is not set to input for " + op_type); + auto data = complex_type_mark->input_value(0); + auto complex_part_type = complex_type_mark->get_complex_part_type(); + + // compute axes along which to compute inverse FFT + auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_one = make_shared(element::i32, Shape{}, 1); + auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); + // exclude the last dimension since it concatenated real and imaginary parts + auto data_rank_minus_one = make_shared(data_rank, const_one); + auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + + // compute inverse FFT and align its output type + auto ifft = make_shared(data, axes); + set_node_name(node.get_name(), ifft); + + // insert ComplexTypeMark since IFFT generates output of complex type + complex_type_mark = make_shared(ifft, complex_part_type); + + return {complex_type_mark}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow_common/src/op/roll.cpp b/src/frontends/tensorflow_common/src/op/roll.cpp index 3f53e178a38572..6ed227c0eef52a 100644 --- a/src/frontends/tensorflow_common/src/op/roll.cpp +++ b/src/frontends/tensorflow_common/src/op/roll.cpp @@ -2,12 +2,19 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/roll.hpp" + #include "common_op_table.hpp" -#include "openvino/opsets/opset8.hpp" +#include "helper_ops/complex_type_mark.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/floor_mod.hpp" +#include "openvino/op/subtract.hpp" +#include "utils.hpp" using namespace std; using namespace ov; -using namespace ov::opset8; +using namespace ov::op; using namespace ov::frontend::tensorflow; namespace ov { @@ -15,12 +22,38 @@ namespace frontend { namespace tensorflow { namespace op { ov::OutputVector translate_roll_op(const NodeContext& node) { - auto data = node.get_input(0); + default_op_checks(node, 3, {"Roll"}, true); + auto input = node.get_input(0); auto shift = node.get_input(1); auto axis = node.get_input(2); - auto res = std::make_shared(data, shift, axis); - set_node_name(node.get_name(), res); - return res->outputs(); + + // check if complex type mark is set + // if yes, sinking it through Roll operation further + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + element::Type complex_part_type = element::dynamic; + if (complex_type_mark) { + input = complex_type_mark->input_value(0); + complex_part_type = complex_type_mark->get_complex_part_type(); + + // axes can be negative so we need to adjust them + // since the last dimension for complex type case is auxiliary (not real) + axis = make_shared(axis, element::i64); + auto input_rank = compute_subgraph_scalar_rank(input, element::i64, true); + auto const_one = make_shared(element::i64, Shape{}, 1); + auto input_rank_minus_one = make_shared(input_rank, const_one)->output(0); + + // adjust axis to make them non-negative + axis = make_shared(axis, input_rank_minus_one); + } + + auto roll = std::make_shared(input, shift, axis)->output(0); + set_node_name(node.get_name(), roll.get_node_shared_ptr()); + + if (complex_type_mark) { + roll = make_shared(roll, complex_part_type)->output(0); + } + + return {roll}; } } // namespace op } // namespace tensorflow diff --git a/tools/mo/openvino/tools/mo/convert_impl.py b/tools/mo/openvino/tools/mo/convert_impl.py index 6491f54d5acded..3a6df79daa69ab 100644 --- a/tools/mo/openvino/tools/mo/convert_impl.py +++ b/tools/mo/openvino/tools/mo/convert_impl.py @@ -314,9 +314,6 @@ def update_fallback_with_conversion_error(use_new_frontend: bool, is_tf: bool, e all_fallback_operations = [ # corresponds to TF1 While operation "LoopCond", "Enter", "NextIteration", "Exit", "Switch", "Merge", - # corresponds to operations with complex tensors - "FFT", "FFT2D", "FFT3D", "IFFT", "IFFT2D", "IFFT3D", - "ComplexAbs", ] if len(conversion_error_match) < 1 or len(conversion_error_match[0]) != 4: # no match for the fallback by unsupported operation From 8fa8b96451f8b512685d8a96235f59bbe79aba1e Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Fri, 3 Nov 2023 20:02:18 +0400 Subject: [PATCH 5/8] Correct axes based on a number of inner-most dimensions --- src/frontends/tensorflow_common/src/op/fft.cpp | 13 +++++++++++-- src/frontends/tensorflow_common/src/op/ifft.cpp | 13 +++++++++++-- src/frontends/tensorflow_common/src/op/irfft.cpp | 13 +++++++++++-- src/frontends/tensorflow_common/src/op/rfft.cpp | 15 +++++++++++++-- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/frontends/tensorflow_common/src/op/fft.cpp b/src/frontends/tensorflow_common/src/op/fft.cpp index fac0ee7686fb1e..b46a7ce91757cf 100644 --- a/src/frontends/tensorflow_common/src/op/fft.cpp +++ b/src/frontends/tensorflow_common/src/op/fft.cpp @@ -31,13 +31,22 @@ OutputVector translate_fft_op(const NodeContext& node) { auto data = complex_type_mark->input_value(0); auto complex_part_type = complex_type_mark->get_complex_part_type(); + // compute a number of inner-most dimensions + int32_t num_axes = 1; + if (op_type == "FFT2D") { + num_axes = 2; + } else if (op_type == "FFT3D") { + num_axes = 3; + } + // compute axes along which to compute FFT - auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_num_axes = make_shared(element::i32, Shape{}, num_axes); auto const_one = make_shared(element::i32, Shape{}, 1); auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); // exclude the last dimension since it concatenated real and imaginary parts auto data_rank_minus_one = make_shared(data_rank, const_one); - auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + auto start = make_shared(data_rank_minus_one, const_num_axes); + auto axes = make_shared(start, data_rank_minus_one, const_one, element::i32); // compute FFT and align its output type auto fft = make_shared(data, axes); diff --git a/src/frontends/tensorflow_common/src/op/ifft.cpp b/src/frontends/tensorflow_common/src/op/ifft.cpp index 1d6c525e225847..927d7934f549ee 100644 --- a/src/frontends/tensorflow_common/src/op/ifft.cpp +++ b/src/frontends/tensorflow_common/src/op/ifft.cpp @@ -31,13 +31,22 @@ OutputVector translate_ifft_op(const NodeContext& node) { auto data = complex_type_mark->input_value(0); auto complex_part_type = complex_type_mark->get_complex_part_type(); + // compute a number of inner-most dimensions + int32_t num_axes = 1; + if (op_type == "IFFT2D") { + num_axes = 2; + } else if (op_type == "IFFT3D") { + num_axes = 3; + } + // compute axes along which to compute inverse FFT - auto const_two = make_shared(element::i32, Shape{}, 2); + auto const_num_axes = make_shared(element::i32, Shape{}, num_axes); auto const_one = make_shared(element::i32, Shape{}, 1); auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); // exclude the last dimension since it concatenated real and imaginary parts auto data_rank_minus_one = make_shared(data_rank, const_one); - auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + auto start = make_shared(data_rank_minus_one, const_num_axes); + auto axes = make_shared(start, data_rank_minus_one, const_one, element::i32); // compute inverse FFT and align its output type auto ifft = make_shared(data, axes); diff --git a/src/frontends/tensorflow_common/src/op/irfft.cpp b/src/frontends/tensorflow_common/src/op/irfft.cpp index 858bb4961a1b9e..5fc1f08d35c158 100644 --- a/src/frontends/tensorflow_common/src/op/irfft.cpp +++ b/src/frontends/tensorflow_common/src/op/irfft.cpp @@ -33,13 +33,22 @@ OutputVector translate_irfft_op(const NodeContext& node) { complex_type_mark, "[TensorFlow Frontend] internal error: ComplexTypeMark is not created before " + op_type + " operation."); + // compute a number of inner-most dimensions + int32_t num_axes = 1; + if (op_type == "IRFFT2D") { + num_axes = 2; + } else if (op_type == "IRFFT3D") { + num_axes = 3; + } + // compute axes along which to compute inverse RFFT + auto const_num_axes = make_shared(element::i32, Shape{}, num_axes); auto data = complex_type_mark->input_value(0); auto data_rank = compute_subgraph_scalar_rank(data, element::i32, true); - auto const_two = make_shared(element::i32, Shape{}, 2); auto const_one = make_shared(element::i32, Shape{}, 1); auto data_rank_minus_one = make_shared(data_rank, const_one); - auto axes = make_shared(const_two, data_rank_minus_one, const_one, element::i32); + auto start = make_shared(data_rank_minus_one, const_num_axes); + auto axes = make_shared(start, data_rank_minus_one, const_one, element::i32); auto irdft = make_shared(complex_type_mark->input_value(0), axes, fft_length)->output(0); // no need to insert ComplexTypeMark because operation generates a floating-point tensor diff --git a/src/frontends/tensorflow_common/src/op/rfft.cpp b/src/frontends/tensorflow_common/src/op/rfft.cpp index 40beabc8e90cd5..7e38c8651a9058 100644 --- a/src/frontends/tensorflow_common/src/op/rfft.cpp +++ b/src/frontends/tensorflow_common/src/op/rfft.cpp @@ -7,6 +7,7 @@ #include "openvino/core/any.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/rdft.hpp" +#include "openvino/op/subtract.hpp" #include "utils.hpp" using namespace std; @@ -20,16 +21,26 @@ namespace op { OutputVector translate_rfft_op(const NodeContext& node) { default_op_checks(node, 2, {"RFFT", "RFFT2D", "RFFT3D"}); + auto op_type = node.get_op_type(); auto input = node.get_input(0); auto fft_length = node.get_input(1); auto tcomplex = node.get_attribute("Tcomplex", "DT_COMPLEX64"); element::Type complex_part_type = (tcomplex == "DT_COMPLEX64" ? element::f32 : element::f64); + // compute a number of inner-most dimension of the input signal + int32_t num_axes = 1; + if (op_type == "RFFT2D") { + num_axes = 2; + } else if (op_type == "RFFT3D") { + num_axes = 3; + } + // compute axes along which to compute inverse RFFT + auto const_num_axes = make_shared(element::i32, Shape{}, num_axes); auto input_rank = compute_subgraph_scalar_rank(input, element::i32, true); - auto const_two = make_shared(element::i32, Shape{}, 2); + auto start = make_shared(input_rank, const_num_axes); auto const_one = make_shared(element::i32, Shape{}, 1); - auto axes = make_shared(const_two, input_rank, const_one, element::i32); + auto axes = make_shared(start, input_rank, const_one, element::i32); // compute real FFT and align its output type auto rfft = make_shared(input, axes, fft_length)->output(0); From 41a1f16ff23edd523d9151df62802f7ca8c41e99 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 6 Nov 2023 00:33:53 +0400 Subject: [PATCH 6/8] Add layer tests Signed-off-by: Kazantsev, Roman --- .../tensorflow_tests/test_tf_ComplexFFT.py | 196 ++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_ComplexFFT.py diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ComplexFFT.py b/tests/layer_tests/tensorflow_tests/test_tf_ComplexFFT.py new file mode 100644 index 00000000000000..14586b3bba3805 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_ComplexFFT.py @@ -0,0 +1,196 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestComplexFFT(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param_real' in inputs_info + assert 'param_imag' in inputs_info + param_real_shape = inputs_info['param_real'] + param_imag_shape = inputs_info['param_imag'] + inputs_data = {} + inputs_data['param_real'] = 4 * rng.random(param_real_shape).astype(np.float32) - 2 + inputs_data['param_imag'] = 4 * rng.random(param_imag_shape).astype(np.float32) - 2 + return inputs_data + + def create_complex_fft_net(self, input_shape, shift_roll, axis_roll, fft_op): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real') + param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag') + shift = tf.constant(shift_roll, dtype=tf.int32) + axis = tf.constant(axis_roll, dtype=tf.int32) + complex = tf.raw_ops.Complex(real=param_real, imag=param_imag) + roll = tf.raw_ops.Roll(input=complex, shift=shift, axis=axis) + fft = fft_op(input=roll) + real = tf.raw_ops.Real(input=fft) + imag = tf.raw_ops.Imag(input=fft) + tf.raw_ops.Pack(values=[real, imag], axis=-1) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[1, 50, 2], shift_roll=[10, 1], axis_roll=[-2, -1]), + dict(input_shape=[4, 20, 3], shift_roll=[2, 10], axis_roll=[0, 1]), + dict(input_shape=[1, 50, 50, 2], shift_roll=[10, 20], axis_roll=[-2, -1]), + dict(input_shape=[4, 20, 30, 3], shift_roll=[2, 10], axis_roll=[0, 1]), + dict(input_shape=[1, 50, 50, 30, 2], shift_roll=[10, 20, 4], axis_roll=[-3, -2, -1]), + dict(input_shape=[4, 20, 30, 10, 3], shift_roll=[2, 10], axis_roll=[1, 2]), + ] + + @pytest.mark.parametrize("fft_op", [ + tf.raw_ops.FFT, tf.raw_ops.FFT2D, tf.raw_ops.FFT3D, + tf.raw_ops.IFFT, tf.raw_ops.IFFT2D, tf.raw_ops.IFFT3D + ]) + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_fft_basic(self, params, fft_op, + ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test( + *self.create_complex_fft_net(**params, fft_op=fft_op), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2) + + +class TestComplexAbs(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param_real' in inputs_info + assert 'param_imag' in inputs_info + param_real_shape = inputs_info['param_real'] + param_imag_shape = inputs_info['param_imag'] + inputs_data = {} + inputs_data['param_real'] = 4 * rng.random(param_real_shape).astype(np.float32) - 2 + inputs_data['param_imag'] = 4 * rng.random(param_imag_shape).astype(np.float32) - 2 + return inputs_data + + def create_complex_abs_net(self, input_shape): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real') + param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag') + complex = tf.raw_ops.Complex(real=param_real, imag=param_imag) + tf.raw_ops.ComplexAbs(x=complex) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[]), + dict(input_shape=[2]), + dict(input_shape=[1, 3]), + dict(input_shape=[2, 3, 4]), + dict(input_shape=[3, 4, 5, 6]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_abs_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test( + *self.create_complex_abs_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestComplexRFFT(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param' in inputs_info + param_shape = inputs_info['param'] + inputs_data = {} + inputs_data['param'] = 4 * rng.random(param_shape).astype(np.float32) - 2 + return inputs_data + + def create_complex_rfft_net(self, input_shape, fft_length, rfft_op): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param = tf.compat.v1.placeholder(np.float32, input_shape, 'param') + fft_length_const = tf.constant(fft_length, dtype=tf.int32) + rfft = rfft_op(input=param, fft_length=fft_length_const) + real = tf.raw_ops.Real(input=rfft) + imag = tf.raw_ops.Imag(input=rfft) + tf.raw_ops.Pack(values=[real, imag], axis=-1) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[1, 3, 20], fft_length=[10], rfft_op=tf.raw_ops.RFFT), + dict(input_shape=[1, 3, 20], fft_length=[20], rfft_op=tf.raw_ops.RFFT), + dict(input_shape=[1, 3, 20, 40], fft_length=[20, 10], rfft_op=tf.raw_ops.RFFT2D), + dict(input_shape=[1, 3, 20, 40], fft_length=[10, 40], rfft_op=tf.raw_ops.RFFT2D), + dict(input_shape=[1, 2, 10, 20, 5], fft_length=[2, 5, 3], rfft_op=tf.raw_ops.RFFT3D), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_rfft_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test( + *self.create_complex_rfft_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + +class TestComplexIRFFT(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'param_real' in inputs_info + assert 'param_imag' in inputs_info + param_real_shape = inputs_info['param_real'] + param_imag_shape = inputs_info['param_imag'] + inputs_data = {} + inputs_data['param_real'] = 4 * rng.random(param_real_shape).astype(np.float32) - 2 + inputs_data['param_imag'] = 4 * rng.random(param_imag_shape).astype(np.float32) - 2 + return inputs_data + + def create_complex_irfft_net(self, input_shape, fft_length, irfft_op): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real') + param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag') + fft_length_const = tf.constant(fft_length, dtype=tf.int32) + complex = tf.raw_ops.Complex(real=param_real, imag=param_imag) + irfft_op(input=complex, fft_length=fft_length_const) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[1, 3, 20], fft_length=[10], irfft_op=tf.raw_ops.IRFFT), + dict(input_shape=[1, 3, 20], fft_length=[20], irfft_op=tf.raw_ops.IRFFT), + dict(input_shape=[1, 3, 20, 40], fft_length=[20, 10], irfft_op=tf.raw_ops.IRFFT2D), + dict(input_shape=[1, 3, 20, 40], fft_length=[10, 40], irfft_op=tf.raw_ops.IRFFT2D), + pytest.param(dict(input_shape=[1, 10, 20, 30, 5], fft_length=[2, 3, 4], irfft_op=tf.raw_ops.IRFFT3D), + marks=pytest.mark.xfail(reason="accuracy-issue-TBD")) + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_complex_irfft_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test( + *self.create_complex_irfft_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) From 115d0822fc627ba44178bfc1103753de39a90e14 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 6 Nov 2023 00:40:49 +0400 Subject: [PATCH 7/8] Update supported ops documentation Signed-off-by: Kazantsev, Roman --- .../tensorflow/docs/supported_ops.md | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/frontends/tensorflow/docs/supported_ops.md b/src/frontends/tensorflow/docs/supported_ops.md index e9b9a499f55a76..32bd4caef5bc37 100644 --- a/src/frontends/tensorflow/docs/supported_ops.md +++ b/src/frontends/tensorflow/docs/supported_ops.md @@ -218,8 +218,8 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | CollectiveReduceV2 | NO | | | CollectiveReduceV3 | NO | | | CombinedNonMaxSuppression | NO | | -| Complex | NO | | -| ComplexAbs | NO | | +| Complex | YES | | +| ComplexAbs | YES | | | CompositeTensorVariantFromComponents | NO | | | CompositeTensorVariantToComponents | NO | | | CompressElement | NO | | @@ -425,9 +425,9 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | ExtractImagePatches | YES | | | ExtractJpegShape | NO | | | ExtractVolumePatches | NO | | -| FFT | NO | | -| FFT2D | NO | | -| FFT3D | NO | | +| FFT | YES | | +| FFT2D | YES | | +| FFT3D | YES | | | FIFOQueue | YES | | | FIFOQueueV2 | YES | | | Fact | NO | | @@ -492,12 +492,12 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | HashTableV2 | YES | | | HistogramFixedWidth | NO | | | HistogramSummary | NO | | -| IFFT | NO | | -| IFFT2D | NO | | -| IFFT3D | NO | | -| IRFFT | NO | | -| IRFFT2D | NO | | -| IRFFT3D | NO | | +| IFFT | YES | | +| IFFT2D | YES | | +| IFFT3D | YES | | +| IRFFT | YES | | +| IRFFT2D | YES | | +| IRFFT3D | YES | | | Identity | YES | | | IdentityN | YES | | | IdentityReader | NO | | @@ -507,7 +507,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | IgammaGradA | NO | | | Igammac | NO | | | IgnoreErrorsDataset | NO | | -| Imag | NO | | +| Imag | YES | | | ImageProjectiveTransformV2 | NO | | | ImageProjectiveTransformV3 | NO | | | ImageSummary | NO | | @@ -826,9 +826,9 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | QueueIsClosedV2 | NO | | | QueueSize | NO | | | QueueSizeV2 | NO | | -| RFFT | NO | | -| RFFT2D | NO | | -| RFFT3D | NO | | +| RFFT | YES | | +| RFFT2D | YES | | +| RFFT3D | YES | | | RGBToHSV | NO | | | RaggedBincount | NO | | | RaggedCountSparseOutput | NO | | @@ -876,7 +876,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV | ReaderRestoreStateV2 | NO | | | ReaderSerializeState | NO | | | ReaderSerializeStateV2 | NO | | -| Real | NO | | +| Real | YES | | | RealDiv | YES | | | RebatchDataset | NO | | | RebatchDatasetV2 | NO | | From b9ba78ad8f2f883411b41b3f06750b260591bcf6 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 6 Nov 2023 00:54:10 +0400 Subject: [PATCH 8/8] Add a comment for ComplexTypeMark Signed-off-by: Kazantsev, Roman --- .../include/helper_ops/complex_type_mark.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp b/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp index 61191c1f951566..5d0e2d5eb2d140 100644 --- a/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp +++ b/src/frontends/tensorflow_common/include/helper_ops/complex_type_mark.hpp @@ -11,6 +11,13 @@ namespace ov { namespace frontend { namespace tensorflow { +// ComplexTypeMark serves to mark places that require complex type propagation +// that means to represent native complex type with simulating floating-point tensor +// that has one extra dimension to concatenate real and imaginary parts of complex tensor. +// For example, a tensor of complex type with shape [N1, N2, ..., Nk] will be transformed +// into a floating-point tensor [N1, N2, ..., Nk, 2] +// where a slice with index [..., 0] represents a real part and +// a slice with index [..., 1] represents a imaginary part. class ComplexTypeMark : public ov::op::util::FrameworkNode { public: OPENVINO_OP("ComplexTypeMark", "util", ov::op::util::FrameworkNode);