diff --git a/src/core/include/openvino/op/topk.hpp b/src/core/include/openvino/op/topk.hpp index c54f48cf3185e5..27917bd8477326 100644 --- a/src/core/include/openvino/op/topk.hpp +++ b/src/core/include/openvino/op/topk.hpp @@ -16,9 +16,9 @@ namespace v1 { /// \brief Computes indices and values of the k maximum/minimum values /// for each slice along specified axis. /// \ingroup ov_ops_cpp_api -class OPENVINO_API TopK : public util::TopK_Base { +class OPENVINO_API TopK : public util::TopKBase { public: - OPENVINO_OP("TopK", "opset1", op::util::TopK_Base, 1); + OPENVINO_OP("TopK", "opset1", op::util::TopKBase, 1); using SortType = TopKSortType; using Mode = TopKMode; @@ -67,9 +67,9 @@ namespace v3 { /// \brief Computes indices and values of the k maximum/minimum values /// for each slice along specified axis. /// \ingroup ov_ops_cpp_api -class OPENVINO_API TopK : public util::TopK_Base { +class OPENVINO_API TopK : public util::TopKBase { public: - OPENVINO_OP("TopK", "opset3", op::util::TopK_Base, 3); + OPENVINO_OP("TopK", "opset3", op::util::TopKBase, 3); /// \brief Constructs a TopK operation TopK() = default; /// \brief Constructs a TopK operation with two outputs: values and indices. @@ -109,9 +109,9 @@ class OPENVINO_API TopK : public util::TopK_Base { namespace v11 { /// \brief Computes the top K elements of a given tensor along the specified axis. /// \ingroup ov_ops_cpp_api -class OPENVINO_API TopK : public util::TopK_Base { +class OPENVINO_API TopK : public util::TopKBase { public: - OPENVINO_OP("TopK", "opset11", op::util::TopK_Base, 11); + OPENVINO_OP("TopK", "opset11", op::util::TopKBase, 11); /// \brief Constructs a TopK operation TopK() = default; /// \brief Constructs a TopK operation with two outputs: values and indices. diff --git a/src/core/include/openvino/op/util/topk_base.hpp b/src/core/include/openvino/op/util/topk_base.hpp index fe56fc0d7553d2..7b6f15ac6e73a6 100644 --- a/src/core/include/openvino/op/util/topk_base.hpp +++ b/src/core/include/openvino/op/util/topk_base.hpp @@ -10,37 +10,36 @@ namespace ov { namespace op { namespace util { -class OPENVINO_API TopK_Base : public Op { +class OPENVINO_API TopKBase : public Op { public: using Mode = TopKMode; using SortType = TopKSortType; - OPENVINO_OP("TopK_Base", "util"); - TopK_Base() = default; + OPENVINO_OP("TopKBase", "util"); + TopKBase() = default; - /// \param arg The node producing the input data batch tensor. - /// \param strides The strides. - /// \param pads_begin The beginning of padding shape. - /// \param pads_end The end of padding shape. - /// \param kernel The kernel shape. - /// \param rounding_mode Whether to use ceiling or floor rounding type while - /// computing output shape. - /// \param auto_pad The pad type for automatically computing padding sizes. - TopK_Base(const Output& data, - const Output& k, - const int64_t axis, - const std::string& mode, - const std::string& sort, - const element::Type& index_element_type = element::i32, - const bool stable = false); + /// \brief The common base class for all TopK operator versions + /// + /// \param data The input tensor + /// \param k Specifies how many maximum/minimum elements should be computed + /// \param axis The axis along which TopK should be computed + /// \param mode Specifies whether the maximum or minimum elements are selected + /// \param sort Specifies the order of output elements and/or indices + /// Accepted values: none, index, value + /// \param index_element_type Specifies the type of produced indices + TopKBase(const Output& data, + const Output& k, + const int64_t axis, + const std::string& mode, + const std::string& sort, + const element::Type& index_element_type = element::i32); - TopK_Base(const Output& data, - const Output& k, - const int64_t axis, - const TopKMode mode, - const TopKSortType sort, - const element::Type& index_element_type = element::i32, - const bool stable = false); + TopKBase(const Output& data, + const Output& k, + const int64_t axis, + const TopKMode mode, + const TopKSortType sort, + const element::Type& index_element_type = element::i32); void validate_and_infer_types() override; bool visit_attributes(AttributeVisitor& visitor) override; diff --git a/src/core/shape_inference/include/topk_shape_inference.hpp b/src/core/shape_inference/include/topk_shape_inference.hpp index 3c732e85011b24..93893a450f1dcf 100644 --- a/src/core/shape_inference/include/topk_shape_inference.hpp +++ b/src/core/shape_inference/include/topk_shape_inference.hpp @@ -16,9 +16,9 @@ namespace util { // Helper to get correct K from tensor as shape. template struct GetK { - const util::TopK_Base* m_op; + const util::TopKBase* m_op; - GetK(const util::TopK_Base* op) : m_op{op} {} + GetK(const util::TopKBase* op) : m_op{op} {} template T operator()(const K k) const { @@ -43,7 +43,7 @@ struct GetK { * \return Vector of output shapes for */ template -std::vector shape_infer(const util::TopK_Base* op, +std::vector shape_infer(const util::TopKBase* op, const std::vector& input_shapes, const std::map& constant_data = {}) { using TDim = typename TShape::value_type; diff --git a/src/core/src/op/topk.cpp b/src/core/src/op/topk.cpp index 51b15d818078fe..89b07a323cea2b 100644 --- a/src/core/src/op/topk.cpp +++ b/src/core/src/op/topk.cpp @@ -114,7 +114,7 @@ op::v1::TopK::TopK(const Output& data, const std::string& mode, const std::string& sort, const element::Type& index_element_type) - : util::TopK_Base(data, k, axis, mode, sort, index_element_type) { + : util::TopKBase(data, k, axis, mode, sort, index_element_type) { constructor_validate_and_infer_types(); } @@ -124,7 +124,7 @@ op::v1::TopK::TopK(const Output& data, const Mode mode, const SortType sort, const element::Type& index_element_type) - : util::TopK_Base(data, k, axis, mode, sort, index_element_type) { + : util::TopKBase(data, k, axis, mode, sort, index_element_type) { constructor_validate_and_infer_types(); } @@ -233,7 +233,7 @@ op::v3::TopK::TopK(const Output& data, const Mode mode, const SortType sort, const element::Type& index_element_type) - : util::TopK_Base{data, k, axis, mode, sort, index_element_type} { + : util::TopKBase{data, k, axis, mode, sort, index_element_type} { constructor_validate_and_infer_types(); } @@ -335,7 +335,7 @@ ov::op::v11::TopK::TopK(const Output& data, const TopKSortType sort, const element::Type& index_element_type, const bool stable) - : util::TopK_Base{data, k, axis, mode, sort, index_element_type}, + : util::TopKBase{data, k, axis, mode, sort, index_element_type}, m_stable{stable} { constructor_validate_and_infer_types(); } @@ -351,12 +351,12 @@ void ov::op::v11::TopK::validate_and_infer_types() { AttributeAdapter(m_sort).get()); } - util::TopK_Base::validate_and_infer_types(); + util::TopKBase::validate_and_infer_types(); } bool ov::op::v11::TopK::visit_attributes(AttributeVisitor& visitor) { OV_OP_SCOPE(v11_TopK_visit_attributes); - util::TopK_Base::visit_attributes(visitor); + util::TopKBase::visit_attributes(visitor); visitor.on_attribute("stable", m_stable); return true; } diff --git a/src/core/src/op/util/topk_base.cpp b/src/core/src/op/util/topk_base.cpp index d30d23eb07faa2..3bc07aa0b7ba5b 100644 --- a/src/core/src/op/util/topk_base.cpp +++ b/src/core/src/op/util/topk_base.cpp @@ -11,28 +11,24 @@ #include "itt.hpp" #include "openvino/op/util/precision_sensitive_attribute.hpp" -using namespace std; - namespace { constexpr auto UNKNOWN_NORMALIZED_AXIS = std::numeric_limits::max(); } -ov::op::util::TopK_Base::TopK_Base(const Output& data, - const Output& k, - const int64_t axis, - const std::string& mode, - const std::string& sort, - const element::Type& index_element_type, - const bool stable) - : TopK_Base(data, k, axis, as_enum(mode), as_enum(sort), index_element_type) {} - -ov::op::util::TopK_Base::TopK_Base(const Output& data, - const Output& k, - const int64_t axis, - const TopKMode mode, - const TopKSortType sort, - const element::Type& index_element_type, - const bool stable) +ov::op::util::TopKBase::TopKBase(const Output& data, + const Output& k, + const int64_t axis, + const std::string& mode, + const std::string& sort, + const element::Type& index_element_type) + : TopKBase(data, k, axis, as_enum(mode), as_enum(sort), index_element_type) {} + +ov::op::util::TopKBase::TopKBase(const Output& data, + const Output& k, + const int64_t axis, + const TopKMode mode, + const TopKSortType sort, + const element::Type& index_element_type) : Op{{data, k}}, m_axis{axis}, m_normalized_axis{UNKNOWN_NORMALIZED_AXIS}, @@ -42,7 +38,7 @@ ov::op::util::TopK_Base::TopK_Base(const Output& data, ov::mark_as_precision_sensitive(input(1)); } -void ov::op::util::TopK_Base::validate_and_infer_types() { +void ov::op::util::TopKBase::validate_and_infer_types() { OV_OP_SCOPE(util_TopK_Base_validate_and_infer_types); k_type_check(get_input_element_type(1)); @@ -55,7 +51,7 @@ void ov::op::util::TopK_Base::validate_and_infer_types() { set_output_type(1, m_index_element_type, output_shapes[1]); } -bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) { +bool ov::op::util::TopKBase::visit_attributes(AttributeVisitor& visitor) { OV_OP_SCOPE(util_TopK_Base_visit_attributes); visitor.on_attribute("axis", m_axis); visitor.on_attribute("mode", m_mode); @@ -64,15 +60,15 @@ bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) { return true; } -void ov::op::util::TopK_Base::k_type_check(const element::Type& k_element_type) const { +void ov::op::util::TopKBase::k_type_check(const element::Type& k_element_type) const { NODE_VALIDATION_CHECK(this, k_element_type.is_integral_number(), "K input has to be an integer type, which does match the provided one:", k_element_type); } -size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr& node, - const element::Type& k_element_type) const { +size_t ov::op::util::TopKBase::read_k_from_constant_node(const std::shared_ptr& node, + const element::Type& k_element_type) const { k_type_check(k_element_type); const auto k_constant = ov::as_type_ptr(node); @@ -112,7 +108,7 @@ size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr } template -size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr& k_constant) const { +size_t ov::op::util::TopKBase::validate_and_get_k(const std::shared_ptr& k_constant) const { const auto k_const_contents = k_constant->get_vector(); NODE_VALIDATION_CHECK(this, @@ -132,11 +128,11 @@ size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr(k_const_contents[0]); } -void ov::op::util::TopK_Base::set_k(size_t k) { +void ov::op::util::TopKBase::set_k(size_t k) { this->input(1).replace_source_output(op::v0::Constant::create(element::i64, ov::Shape{}, {k})->output(0)); } -size_t ov::op::util::TopK_Base::get_k() const { +size_t ov::op::util::TopKBase::get_k() const { size_t k = 0; if (op::util::is_constant(input_value(1).get_node())) { k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1)); @@ -148,16 +144,16 @@ size_t ov::op::util::TopK_Base::get_k() const { return k; } -void ov::op::util::TopK_Base::set_axis(const int64_t axis) { +void ov::op::util::TopKBase::set_axis(const int64_t axis) { set_axis(get_input_partial_shape(0).rank(), axis); } -void ov::op::util::TopK_Base::set_axis(const Rank& input_rank, const int64_t axis) { +void ov::op::util::TopKBase::set_axis(const Rank& input_rank, const int64_t axis) { m_normalized_axis = input_rank.is_static() ? normalize_axis(this, axis, input_rank) : UNKNOWN_NORMALIZED_AXIS; m_axis = axis; } -uint64_t ov::op::util::TopK_Base::get_axis() const { +uint64_t ov::op::util::TopKBase::get_axis() const { NODE_VALIDATION_CHECK(this, m_normalized_axis != UNKNOWN_NORMALIZED_AXIS, "Normalized axis of TopK is unknown"); return m_normalized_axis; diff --git a/src/frontends/onnx/tests/onnx_tensor_names.cpp b/src/frontends/onnx/tests/onnx_tensor_names.cpp index ea5213ea35baf7..9f04fcf8188d9e 100644 --- a/src/frontends/onnx/tests/onnx_tensor_names.cpp +++ b/src/frontends/onnx/tests/onnx_tensor_names.cpp @@ -69,8 +69,8 @@ NGRAPH_TEST(onnx_tensor_names, node_multiple_outputs) { const auto ops = function->get_ordered_ops(); EXPECT_TRUE(matching_node_found_in_graph(ops, "x", {"x"})); - EXPECT_TRUE(matching_node_found_in_graph(ops, "indices", {"values"}, 0)); - EXPECT_TRUE(matching_node_found_in_graph(ops, "indices", {"indices"}, 1)); + EXPECT_TRUE(matching_node_found_in_graph(ops, "indices", {"values"}, 0)); + EXPECT_TRUE(matching_node_found_in_graph(ops, "indices", {"indices"}, 1)); const auto results = function->get_results(); EXPECT_TRUE(matching_node_found_in_graph(results, "values/sink_port_0", {"values"}));