Skip to content

Commit

Permalink
[ONNX] Use Reshape instead of Broadcast in v6 operators: Add, Div, Mu… (
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztabaka authored May 25, 2021
1 parent 7d429e2 commit d509fe6
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 103 deletions.
27 changes: 2 additions & 25 deletions ngraph/frontend/onnx_import/src/op/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/shape.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -17,31 +18,7 @@ namespace ngraph
{
OutputVector add(const Node& node)
{
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
const bool broadcast = node.get_attribute_value<std::int64_t>("broadcast", 0);
if (broadcast)
{
if (node.has_attribute("axis"))
{
// Unidirectional broadcast right node to left shape.
const auto axis = node.get_attribute_value<std::int64_t>("axis");
const auto axes_mapping = builder::opset1::get_axes_mapping_output(
lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis);
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node,
std::make_shared<default_opset::ShapeOf>(lhs_node),
axes_mapping);
}
else
{
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node, std::make_shared<default_opset::ShapeOf>(lhs_node));
}
return {std::make_shared<default_opset::Add>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
return {std::make_shared<default_opset::Add>(lhs_node, rhs_node)};
return common::handle_opset6_binary_op<default_opset::Add>(node);
}

} // namespace set_1
Expand Down
26 changes: 1 addition & 25 deletions ngraph/frontend/onnx_import/src/op/div.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,7 @@ namespace ngraph
{
inline OutputVector div(const Node& node)
{
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
const bool broadcast = node.get_attribute_value<std::int64_t>("broadcast", 0);
if (broadcast)
{
if (node.has_attribute("axis"))
{
// Unidirectional broadcast right node to left shape.
const auto axis = node.get_attribute_value<std::int64_t>("axis");
const auto axes_mapping = builder::opset1::get_axes_mapping_output(
lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis);
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node,
std::make_shared<default_opset::ShapeOf>(lhs_node),
axes_mapping);
}
else
{
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node, std::make_shared<default_opset::ShapeOf>(lhs_node));
}
return {std::make_shared<default_opset::Divide>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
return {std::make_shared<default_opset::Divide>(lhs_node, rhs_node)};
return common::handle_opset6_binary_op<default_opset::Divide>(node);
}

} // namespace set_1
Expand Down
26 changes: 1 addition & 25 deletions ngraph/frontend/onnx_import/src/op/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,7 @@ namespace ngraph
{
inline OutputVector mul(const Node& node)
{
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
const bool broadcast = node.get_attribute_value<std::int64_t>("broadcast", 0);
if (broadcast)
{
if (node.has_attribute("axis"))
{
// Unidirectional broadcast right node to left shape.
const auto axis = node.get_attribute_value<std::int64_t>("axis");
const auto axes_mapping = builder::opset1::get_axes_mapping_output(
lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis);
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node,
std::make_shared<default_opset::ShapeOf>(lhs_node),
axes_mapping);
}
else
{
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node, std::make_shared<default_opset::ShapeOf>(lhs_node));
}
return {std::make_shared<default_opset::Multiply>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
return {std::make_shared<default_opset::Multiply>(lhs_node, rhs_node)};
return common::handle_opset6_binary_op<default_opset::Multiply>(node);
}

} // namespace set_1
Expand Down
26 changes: 1 addition & 25 deletions ngraph/frontend/onnx_import/src/op/sub.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,7 @@ namespace ngraph
{
inline OutputVector sub(const Node& node)
{
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
const bool broadcast = node.get_attribute_value<std::int64_t>("broadcast", 0);
if (broadcast)
{
if (node.has_attribute("axis"))
{
// Unidirectional broadcast right node to left shape.
const auto axis = node.get_attribute_value<std::int64_t>("axis");
const auto axes_mapping = builder::opset1::get_axes_mapping_output(
lhs_node.get_partial_shape(), rhs_node.get_partial_shape(), axis);
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node,
std::make_shared<default_opset::ShapeOf>(lhs_node),
axes_mapping);
}
else
{
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node, std::make_shared<default_opset::ShapeOf>(lhs_node));
}
return {std::make_shared<default_opset::Subtract>(
lhs_node, rhs_node, ngraph::op::AutoBroadcastSpec::NONE)};
}
return {std::make_shared<default_opset::Subtract>(lhs_node, rhs_node)};
return common::handle_opset6_binary_op<default_opset::Subtract>(node);
}

} // namespace set_1
Expand Down
47 changes: 47 additions & 0 deletions ngraph/frontend/onnx_import/src/utils/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,53 @@ namespace ngraph
}
}

template <typename T>
OutputVector handle_opset6_binary_op(const Node& node)
{
const Output<ngraph::Node> lhs_node = node.get_ng_inputs().at(0);
Output<ngraph::Node> rhs_node = node.get_ng_inputs().at(1);
const bool broadcast = node.get_attribute_value<std::int64_t>("broadcast", 0);
if (broadcast)
{
if (node.has_attribute("axis"))
{
NGRAPH_CHECK(lhs_node.get_partial_shape().rank().is_static() &&
rhs_node.get_partial_shape().rank().is_static(),
"Input's rank has to be static.");
auto axis = node.get_attribute_value<std::int64_t>("axis");
auto lhs_rank = lhs_node.get_partial_shape().rank().get_length();
auto rhs_rank = rhs_node.get_partial_shape().rank().get_length();
if (axis < 0)
axis += lhs_rank;
if (lhs_rank > axis + rhs_rank)
{
auto ones = default_opset::Constant::create(
element::i64,
Shape{static_cast<size_t>(lhs_rank - axis - rhs_rank)},
std::vector<int64_t>(lhs_rank - axis - rhs_rank, 1));
auto rhs_shape = std::make_shared<default_opset::ShapeOf>(rhs_node);
auto new_shape = std::make_shared<default_opset::Concat>(
OutputVector{rhs_shape, ones}, 0);
rhs_node = std::make_shared<default_opset::Reshape>(
rhs_node, new_shape, false);
}
}
else
{
rhs_node = std::make_shared<default_opset::Broadcast>(
rhs_node, std::make_shared<default_opset::ShapeOf>(lhs_node));
}
}
return {std::make_shared<T>(lhs_node, rhs_node)};
}

template OutputVector handle_opset6_binary_op<default_opset::Add>(const Node& node);
template OutputVector handle_opset6_binary_op<default_opset::Divide>(const Node& node);
template OutputVector
handle_opset6_binary_op<default_opset::Multiply>(const Node& node);
template OutputVector
handle_opset6_binary_op<default_opset::Subtract>(const Node& node);

} // namespace common
} // namespace onnx_import
} // namespace ngraph
8 changes: 8 additions & 0 deletions ngraph/frontend/onnx_import/src/utils/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ namespace ngraph
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

/// \brief Function that handles following ONNX operators: Add, Div, Mul, Sub
/// from opset 6.
///
/// \param node ONNX node
///
/// \return OutputVector with binary op
template <typename T>
OutputVector handle_opset6_binary_op(const Node& node);
} // namespace common
} // namespace onnx_import
} // namespace ngraph
3 changes: 0 additions & 3 deletions ngraph/test/onnx/onnx_import_controlflow.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_controlflow_loop_the_proper_opset_in_subgraph)
});
const auto body_mul_node = ngraph::as_type_ptr<default_opset::Multiply>(*body_mul_node_it);
EXPECT_TRUE(body_mul_node);
EXPECT_EQ(
body_mul_node->get_autob().m_type,
ngraph::op::AutoBroadcastType::NONE); // legacy Mul from ONNX opset1 has NONE broadcasting
}

// ~~~~~~~~STATIC/DYNAMIC/CONSTANT INPUTS TESTS:~~~~~~~~
Expand Down

0 comments on commit d509fe6

Please sign in to comment.