[ONNX] Reduce* refactoring (openvinotoolkit#23429)
### Details:
 - Refactored the source code so it is able to co-work with others
 - Removed unnecessary comments
 - Added a check for unsupported input types (a standalone sketch of this check follows the ticket list)
 - Added tests verifying that an exception is thrown for an unsupported input type

### Tickets:
 - 125493
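
The added check passes a per-opset set of supported element types into the shared reduction helper (`make_ov_reduction_op` in the diff below) and rejects anything outside that set before a reduction node is built. The snippet below is a minimal, standalone sketch of that pattern: it does not include OpenVINO headers; `ElementType`, `type_name`, and `make_reduction` are illustrative stand-ins, and only the supported-type lists and the error text mirror the diff.

```cpp
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>

// Stand-in for ov::element::Type; the real code uses ov::element::f32, ov::element::bf16, etc.
enum class ElementType { u32, u64, i32, i64, f16, f32, f64, bf16 };

std::string type_name(ElementType t) {
    switch (t) {
    case ElementType::bf16:
        return "bf16";
    case ElementType::f32:
        return "f32";
    default:
        return "other";
    }
}

// Per-opset whitelists mirroring supported_types_v1 / supported_types_v2 in reduce.cpp:
// the opset 1 handlers do not accept bf16, the opset 13 ReduceSum handler does.
const std::set<ElementType> supported_types_v1 = {ElementType::u32, ElementType::u64, ElementType::i32,
                                                  ElementType::i64, ElementType::f16, ElementType::f32,
                                                  ElementType::f64};
const std::set<ElementType> supported_types_v2 = {ElementType::u32, ElementType::u64, ElementType::i32,
                                                  ElementType::i64, ElementType::f16, ElementType::f32,
                                                  ElementType::f64, ElementType::bf16};

// Simplified counterpart of make_ov_reduction_op: validate the input element type against the
// set handed in by the opset-specific translator, then (in the real code) build the reduction node.
void make_reduction(ElementType input_type, const std::set<ElementType>& supported_types) {
    if (supported_types.count(input_type) == 0) {
        // CHECK_VALID_NODE in the frontend reports a similar "Unsupported input type" error.
        throw std::runtime_error("Unsupported input type " + type_name(input_type));
    }
    std::cout << "accepted: " << type_name(input_type) << "\n";
}

int main() {
    make_reduction(ElementType::f32, supported_types_v1);   // accepted by the opset 1 handlers
    make_reduction(ElementType::bf16, supported_types_v2);  // accepted by the opset 13 ReduceSum handler
    try {
        make_reduction(ElementType::bf16, supported_types_v1);  // rejected, as in the new test model
    } catch (const std::exception& e) {
        std::cout << "rejected: " << e.what() << "\n";
    }
    return 0;
}
```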
gkrivor authored Mar 14, 2024
1 parent ca1153a commit 149381e
Showing 5 changed files with 195 additions and 156 deletions.
64 changes: 44 additions & 20 deletions src/frontends/onnx/frontend/src/op/reduce.cpp
@@ -90,74 +90,98 @@ std::shared_ptr<ov::Node> get_reduction_axes_from_attr(const Node& node) {
return v0::Constant::create(ov::element::i64, ov::Shape{reduction_axes.size()}, reduction_axes);
}

const std::set<element::Type> supported_types_v1 =
{element::u32, element::u64, element::i32, element::i64, element::f16, element::f32, element::f64};
const std::set<element::Type> supported_types_v2 =
{element::u32, element::u64, element::i32, element::i64, element::f16, element::f32, element::f64, element::bf16};

template <typename OpType>
std::shared_ptr<ov::Node> make_ng_reduction_op(const Node& node,
const ov::Output<ov::Node>& ng_input,
std::shared_ptr<ov::Node> make_ov_reduction_op(const Node& node,
const ov::Output<ov::Node>& ov_input,
const std::set<element::Type>& supported_types,
bool axes_as_attr = true) {
const std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);

CHECK_VALID_NODE(node,
supported_types.find(ov_input.get_element_type()) != supported_types.end(),
"Unsupported input type ",
ov_input.get_element_type().get_type_name());

const auto reduction_axes = axes_as_attr ? get_reduction_axes_from_attr(node) : get_reduction_axes_from_input(node);
if (reduction_axes != nullptr) {
return std::make_shared<OpType>(ng_input, reduction_axes, static_cast<bool>(keepdims));
return std::make_shared<OpType>(ov_input, reduction_axes, static_cast<bool>(keepdims));
} else {
return set_1::identity(node).at(0).get_node_shared_ptr();
}
}
} // namespace

namespace set_13 {
ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0), false)};
}
} // namespace set_13

namespace set_1 {
ov::OutputVector reduce_log_sum(const ov::frontend::onnx::Node& node) {
const ov::Output<ov::Node> sum_node = make_ng_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0));
const ov::Output<ov::Node> sum_node =
make_ov_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0), supported_types_v1);
return {std::make_shared<v0::Log>(sum_node)};
}

ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {
const auto exp_node = std::make_shared<v0::Exp>(node.get_ov_inputs().at(0));
const ov::Output<ov::Node> sum_node = make_ng_reduction_op<v1::ReduceSum>(node, exp_node);
const ov::Output<ov::Node> sum_node = make_ov_reduction_op<v1::ReduceSum>(node, exp_node, supported_types_v1);
return {std::make_shared<v0::Log>(sum_node)};
}

ov::OutputVector reduce_l1(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v4::ReduceL1>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v4::ReduceL1>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_l2(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v4::ReduceL2>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v4::ReduceL2>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_max(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceMax>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v1::ReduceMax>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_mean(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceMean>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v1::ReduceMean>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_min(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceMin>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v1::ReduceMin>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_prod(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceProd>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v1::ReduceProd>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node) {
return {make_ng_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0))};
return {make_ov_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0), supported_types_v1)};
}

ov::OutputVector reduce_sum_square(const ov::frontend::onnx::Node& node) {
const auto input = ov::Output<ov::Node>{node.get_ov_inputs().at(0)};
const auto square_node = std::make_shared<v1::Multiply>(input, input);
return {make_ng_reduction_op<v1::ReduceSum>(node, square_node)};
return {make_ov_reduction_op<v1::ReduceSum>(node, square_node, supported_types_v1)};
}

} // namespace set_1

/*
Opset 11 is skipped because there are no significant differences between opset 1 and opset 11.
The differences found are:
1. Operations (except ReduceMin and ReduceMax) no longer mention zero-rank input behavior in
   their descriptions. We assume the behavior is not worse than in opset 1.
2. Opset 11 introduced a requirement that axes values be in the range [-r, r-1], where r = rank(data).
   The Reduce* operations in OpenVINO have had the same requirement since their first version.
*/

namespace set_13 {
ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node) {
return {make_ov_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0), supported_types_v2, false)};
}
} // namespace set_13

namespace set_18 {
// Placeholder
} // namespace set_18
} // namespace op
} // namespace onnx
} // namespace frontend
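
Besides the type check, the main split the refactored helper encodes is where the reduction axes come from: the set_1 handlers leave `axes_as_attr` at its default `true` and read the `axes` attribute, while the set_13 `reduce_sum` passes `false` so the axes are taken from the second input; if the chosen getter returns no axes node, the helper falls back to `set_1::identity`. The following is a small self-contained toy model of that dispatch, under those assumptions; `ToyNode` and `reduce` are illustrative stand-ins, not OpenVINO types.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for ov::frontend::onnx::Node, carrying the two possible sources of reduction axes.
struct ToyNode {
    std::optional<std::vector<int64_t>> axes_attribute;  // opset 1: axes come from the "axes" attribute
    std::optional<std::vector<int64_t>> axes_input;      // opset 13 ReduceSum: axes come from input 1
};

// Mirrors the axes_as_attr switch in make_ov_reduction_op: set_1 handlers keep the default true,
// the set_13 reduce_sum handler passes false so the axes are read from the second input.
std::string reduce(const ToyNode& node, bool axes_as_attr) {
    const auto& axes = axes_as_attr ? node.axes_attribute : node.axes_input;
    if (axes.has_value()) {
        return "Reduce over " + std::to_string(axes->size()) + " axis/axes";
    }
    // In reduce.cpp, a null axes node makes the helper return set_1::identity(node) instead.
    return "Identity (nothing to reduce)";
}

int main() {
    const ToyNode opset1_node{std::vector<int64_t>{0, 2}, std::nullopt};
    const ToyNode opset13_node{std::nullopt, std::vector<int64_t>{1}};
    const ToyNode opset13_no_axes{std::nullopt, std::nullopt};

    std::cout << reduce(opset1_node, /*axes_as_attr=*/true) << "\n";        // attribute path
    std::cout << reduce(opset13_node, /*axes_as_attr=*/false) << "\n";      // input path
    std::cout << reduce(opset13_no_axes, /*axes_as_attr=*/false) << "\n";   // identity fallback
    return 0;
}
```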
157 changes: 21 additions & 136 deletions src/frontends/onnx/frontend/src/op/reduce.hpp
@@ -10,162 +10,47 @@ namespace ov {
namespace frontend {
namespace onnx {
namespace op {
namespace set_13 {
/// \brief Compute the sum of the input tensor's elements along the provided
/// axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node);
} // namespace set_13
namespace set_1 {
/// \brief Compute the log sum of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
ov::OutputVector reduce_log_sum(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the log sum exponent of the input tensor's elements along
/// the provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the L1 norm of the input tensor's elements along the provided
/// axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_l1(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the L2 norm of the input tensor's elements along the provided
/// axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_l2(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the maximum value of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_max(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the mean value of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_mean(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the minimum value of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_min(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the product of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_prod(const ov::frontend::onnx::Node& node);
} // namespace set_1

/// \brief Compute the sum of the input tensor's elements along the provided
/// axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node);
} // namespace set_1
namespace set_13 {
ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node);
} // namespace set_13

/// \brief Compute the sum square of the input tensor's elements along the
/// provided axes.
///
/// \par Overview
/// The output tensor has the same rank as the input if Node attribute keepdims
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
/// dimension pruned.
///
/// \param[in] node The ONNX node representing operation.
///
/// \return The OV node equivalent of the ONNX operation.
///
namespace set_1 {
ov::OutputVector reduce_sum_square(const ov::frontend::onnx::Node& node);

} // namespace set_1
} // namespace op
} // namespace onnx
48 changes: 48 additions & 0 deletions src/frontends/onnx/tests/models/reduce_wrong_type_v1.prototxt
@@ -0,0 +1,48 @@
ir_version: 3
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "A"
output: "B"
op_type: "ReduceProd"
}
name: "compute_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 16
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "B"
type {
tensor_type {
elem_type: 16
shape {
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 1
}
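
The model declares its input and output with `elem_type: 16` (ONNX bfloat16) while importing opset 1, whose ReduceProd handler only accepts `supported_types_v1`, so conversion is expected to fail with the "Unsupported input type" error added in reduce.cpp. A test exercising this could look roughly like the sketch below; `convert_model` is a hypothetical stand-in for the real test helper (stubbed here so the sketch is self-contained), and only the expected error substring comes from the diff above.

```cpp
#include <gtest/gtest.h>

#include <stdexcept>
#include <string>

// Hypothetical stand-in for the ONNX frontend test helper that loads a prototxt model and
// converts it. It is stubbed to throw the error the new CHECK_VALID_NODE in reduce.cpp
// reports for a bf16 input under opset 1, so the sketch compiles and runs on its own.
static void convert_model(const std::string& /*model_path*/) {
    throw std::runtime_error("Unsupported input type bf16");
}

TEST(onnx_reduce, reject_bf16_input_on_opset1) {
    try {
        // reduce_wrong_type_v1.prototxt declares a bf16 (elem_type: 16) ReduceProd input
        // and imports opset 1, whose handlers only accept supported_types_v1.
        convert_model("models/reduce_wrong_type_v1.prototxt");
        FAIL() << "Conversion should have thrown for an unsupported input type";
    } catch (const std::exception& e) {
        EXPECT_NE(std::string(e.what()).find("Unsupported input type"), std::string::npos);
    }
}
```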
