diff --git a/ngraph/frontend/onnx_import/src/op/add.cpp b/ngraph/frontend/onnx_import/src/op/add.cpp index a88308956f0e23..a09253f70abce8 100644 --- a/ngraph/frontend/onnx_import/src/op/add.cpp +++ b/ngraph/frontend/onnx_import/src/op/add.cpp @@ -6,6 +6,7 @@ #include "default_opset.hpp" #include "ngraph/builder/autobroadcast.hpp" #include "ngraph/shape.hpp" +#include "utils/common.hpp" namespace ngraph { @@ -17,31 +18,7 @@ namespace ngraph { OutputVector add(const Node& node) { - const Output lhs_node = node.get_ng_inputs().at(0); - Output rhs_node = node.get_ng_inputs().at(1); - const bool broadcast = node.get_attribute_value("broadcast", 0); - if (broadcast) - { - if (node.has_attribute("axis")) - { - // Unidirectional broadcast right node to left shape. - const auto axis = node.get_attribute_value("axis"); - const auto axes_mapping = builder::opset1::get_axes_mapping_output( - lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis); - rhs_node = std::make_shared( - rhs_node, - std::make_shared(lhs_node), - axes_mapping); - } - else - { - rhs_node = std::make_shared( - rhs_node, std::make_shared(lhs_node)); - } - return {std::make_shared( - lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; - } - return {std::make_shared(lhs_node, rhs_node)}; + return common::handle_opset6_binary_op(node); } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/op/div.hpp b/ngraph/frontend/onnx_import/src/op/div.hpp index 8a55d8be3e17ab..fb03ee50e54956 100644 --- a/ngraph/frontend/onnx_import/src/op/div.hpp +++ b/ngraph/frontend/onnx_import/src/op/div.hpp @@ -22,31 +22,7 @@ namespace ngraph { inline OutputVector div(const Node& node) { - const Output lhs_node = node.get_ng_inputs().at(0); - Output rhs_node = node.get_ng_inputs().at(1); - const bool broadcast = node.get_attribute_value("broadcast", 0); - if (broadcast) - { - if (node.has_attribute("axis")) - { - // Unidirectional broadcast right node to left shape. - const auto axis = node.get_attribute_value("axis"); - const auto axes_mapping = builder::opset1::get_axes_mapping_output( - lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis); - rhs_node = std::make_shared( - rhs_node, - std::make_shared(lhs_node), - axes_mapping); - } - else - { - rhs_node = std::make_shared( - rhs_node, std::make_shared(lhs_node)); - } - return {std::make_shared( - lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; - } - return {std::make_shared(lhs_node, rhs_node)}; + return common::handle_opset6_binary_op(node); } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/op/mul.hpp b/ngraph/frontend/onnx_import/src/op/mul.hpp index 083bbe1b3f45b8..78e8e8ecb872ce 100644 --- a/ngraph/frontend/onnx_import/src/op/mul.hpp +++ b/ngraph/frontend/onnx_import/src/op/mul.hpp @@ -23,31 +23,7 @@ namespace ngraph { inline OutputVector mul(const Node& node) { - const Output lhs_node = node.get_ng_inputs().at(0); - Output rhs_node = node.get_ng_inputs().at(1); - const bool broadcast = node.get_attribute_value("broadcast", 0); - if (broadcast) - { - if (node.has_attribute("axis")) - { - // Unidirectional broadcast right node to left shape. - const auto axis = node.get_attribute_value("axis"); - const auto axes_mapping = builder::opset1::get_axes_mapping_output( - lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis); - rhs_node = std::make_shared( - rhs_node, - std::make_shared(lhs_node), - axes_mapping); - } - else - { - rhs_node = std::make_shared( - rhs_node, std::make_shared(lhs_node)); - } - return {std::make_shared( - lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; - } - return {std::make_shared(lhs_node, rhs_node)}; + return common::handle_opset6_binary_op(node); } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/op/sub.hpp b/ngraph/frontend/onnx_import/src/op/sub.hpp index eba39f8c6bca99..06183490b27e18 100644 --- a/ngraph/frontend/onnx_import/src/op/sub.hpp +++ b/ngraph/frontend/onnx_import/src/op/sub.hpp @@ -19,31 +19,7 @@ namespace ngraph { inline OutputVector sub(const Node& node) { - const Output lhs_node = node.get_ng_inputs().at(0); - Output rhs_node = node.get_ng_inputs().at(1); - const bool broadcast = node.get_attribute_value("broadcast", 0); - if (broadcast) - { - if (node.has_attribute("axis")) - { - // Unidirectional broadcast right node to left shape. - const auto axis = node.get_attribute_value("axis"); - const auto axes_mapping = builder::opset1::get_axes_mapping_output( - lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis); - rhs_node = std::make_shared( - rhs_node, - std::make_shared(lhs_node), - axes_mapping); - } - else - { - rhs_node = std::make_shared( - rhs_node, std::make_shared(lhs_node)); - } - return {std::make_shared( - lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)}; - } - return {std::make_shared(lhs_node, rhs_node)}; + return common::handle_opset6_binary_op(node); } } // namespace set_1 diff --git a/ngraph/frontend/onnx_import/src/utils/common.cpp b/ngraph/frontend/onnx_import/src/utils/common.cpp index 67431ac8c52887..9465894f0d9387 100644 --- a/ngraph/frontend/onnx_import/src/utils/common.cpp +++ b/ngraph/frontend/onnx_import/src/utils/common.cpp @@ -86,6 +86,53 @@ namespace ngraph } } + template + OutputVector handle_opset6_binary_op(const Node& node) + { + const Output lhs_node = node.get_ng_inputs().at(0); + Output rhs_node = node.get_ng_inputs().at(1); + const bool broadcast = node.get_attribute_value("broadcast", 0); + if (broadcast) + { + if (node.has_attribute("axis")) + { + NGRAPH_CHECK(lhs_node.get_partial_shape().rank().is_static() && + rhs_node.get_partial_shape().rank().is_static(), + "Input's rank has to be static."); + auto axis = node.get_attribute_value("axis"); + auto lhs_rank = lhs_node.get_partial_shape().rank().get_length(); + auto rhs_rank = rhs_node.get_partial_shape().rank().get_length(); + if (axis < 0) + axis += lhs_rank; + if (lhs_rank > axis + rhs_rank) + { + auto ones = default_opset::Constant::create( + element::i64, + Shape{static_cast(lhs_rank - axis - rhs_rank)}, + std::vector(lhs_rank - axis - rhs_rank, 1)); + auto rhs_shape = std::make_shared(rhs_node); + auto new_shape = std::make_shared( + OutputVector{rhs_shape, ones}, 0); + rhs_node = std::make_shared( + rhs_node, new_shape, false); + } + } + else + { + rhs_node = std::make_shared( + rhs_node, std::make_shared(lhs_node)); + } + } + return {std::make_shared(lhs_node, rhs_node)}; + } + + template OutputVector handle_opset6_binary_op(const Node& node); + template OutputVector handle_opset6_binary_op(const Node& node); + template OutputVector + handle_opset6_binary_op(const Node& node); + template OutputVector + handle_opset6_binary_op(const Node& node); + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/utils/common.hpp b/ngraph/frontend/onnx_import/src/utils/common.hpp index b386c002b07eab..b750f70de63a25 100644 --- a/ngraph/frontend/onnx_import/src/utils/common.hpp +++ b/ngraph/frontend/onnx_import/src/utils/common.hpp @@ -137,6 +137,14 @@ namespace ngraph return std::unique_ptr(new T(std::forward(args)...)); } + /// \brief Function that handles following ONNX operators: Add, Div, Mul, Sub + /// from opset 6. + /// + /// \param node ONNX node + /// + /// \return OutputVector with binary op + template + OutputVector handle_opset6_binary_op(const Node& node); } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/test/onnx/onnx_import_controlflow.in.cpp b/ngraph/test/onnx/onnx_import_controlflow.in.cpp index 1eb5ab73f9774b..ffcf4bacd7dbcd 100644 --- a/ngraph/test/onnx/onnx_import_controlflow.in.cpp +++ b/ngraph/test/onnx/onnx_import_controlflow.in.cpp @@ -330,9 +330,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_the_proper_opset_in_subgraph) }); const auto body_mul_node = ngraph::as_type_ptr(*body_mul_node_it); EXPECT_TRUE(body_mul_node); - EXPECT_EQ( - body_mul_node->get_autob().m_type, - ngraph::op::AutoBroadcastType::NONE); // legacy Mul from ONNX opset1 has NONE broadcasting } // ~~~~~~~~STATIC/DYNAMIC/CONSTANT INPUTS TESTS:~~~~~~~~