Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TopK v11 core operator #15910

Merged
merged 22 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@


def test_wrap_type_pattern_type():
last_opstet_number = 10
for i in range(1, last_opstet_number + 1):
last_opset_number = 11
for i in range(1, last_opset_number + 1):
WrapType(f"opset{i}.Parameter")
WrapType(f"opset{i}::Parameter")

# Negative check not to forget to update opset map in get_type function
expect_exception(lambda: WrapType(f"opset{last_opstet_number + 1}.Parameter"),
f"Unsupported opset type: opset{last_opstet_number + 1}")
expect_exception(lambda: WrapType(f"opset{last_opset_number + 1}.Parameter"),
f"Unsupported opset type: opset{last_opset_number + 1}")

# Generic negative test cases
expect_exception(lambda: WrapType(""))
Expand Down
4 changes: 4 additions & 0 deletions src/core/include/ngraph/op/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@ using ov::op::v1::TopK;
namespace v3 {
using ov::op::v3::TopK;
} // namespace v3

namespace v11 {
using ov::op::v11::TopK;
} // namespace v11
} // namespace op
} // namespace ngraph
1 change: 1 addition & 0 deletions src/core/include/ngraph/opsets/opset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,6 @@ const NGRAPH_API OpSet& get_opset7();
const NGRAPH_API OpSet& get_opset8();
const NGRAPH_API OpSet& get_opset9();
const NGRAPH_API OpSet& get_opset10();
const NGRAPH_API OpSet& get_opset11();
const NGRAPH_API std::map<std::string, std::function<const ngraph::OpSet&()>>& get_available_opsets();
} // namespace ngraph
136 changes: 68 additions & 68 deletions src/core/include/openvino/op/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@

#include "openvino/op/constant.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/util/topk_base.hpp"

namespace ov {
namespace op {
namespace v1 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public Op {
class OPENVINO_API TopK : public util::TopK_Base {
public:
OPENVINO_OP("TopK", "opset1", op::Op, 1);
OPENVINO_OP("TopK", "opset1", op::util::TopK_Base, 1);

using SortType = TopKSortType;
using Mode = TopKMode;
Expand Down Expand Up @@ -50,79 +51,25 @@ class OPENVINO_API TopK : public Op {
const SortType sort,
const element::Type& index_element_type = element::i32);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

/// \brief Returns axis value after normalization
/// \note If input rank required to normalization is dynamic, the exception is
/// thrown
uint64_t get_axis() const;
/// \brief Returns axis value before normalization
int64_t get_provided_axis() const {
return m_axis;
}
void set_axis(const int64_t axis);
void set_axis(const Rank& input_rank, const int64_t axis);
Mode get_mode() const {
return m_mode;
}
void set_mode(const Mode mode) {
m_mode = mode;
}
SortType get_sort_type() const {
return m_sort;
}
void set_sort_type(const SortType sort) {
m_sort = sort;
}
element::Type get_index_element_type() const {
return m_index_element_type;
}
void set_index_element_type(const element::Type& index_element_type) {
m_index_element_type = index_element_type;
}
/// \brief Returns the value of K, if available
///
/// \note If the second input to this op is a constant, the value is retrieved
/// and returned. If the input is not constant(dynamic) this method returns 0
size_t get_k() const;
void set_k(size_t k);
size_t get_default_output_index() const override {
return no_default_index();
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;

protected:
int64_t m_axis;
uint64_t m_normalized_axis;
Mode m_mode;
SortType m_sort;
element::Type m_index_element_type{element::i32};

virtual size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const;

template <typename T>
size_t validate_and_get_k(const std::shared_ptr<op::v0::Constant>& k_constant) const;
Shape compute_output_shape(const std::string& node_description,
const PartialShape input_partial_shape,
const int64_t k) const;
virtual void k_type_check(const element::Type& k_element_type) const;
virtual void k_type_check(const element::Type& k_element_type) const override;
};
} // namespace v1

namespace v3 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public v1::TopK {
class OPENVINO_API TopK : public util::TopK_Base {
public:
OPENVINO_OP("TopK", "opset3", op::Op, 3);
OPENVINO_OP("TopK", "opset3", op::util::TopK_Base, 3);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
Expand All @@ -147,23 +94,76 @@ class OPENVINO_API TopK : public v1::TopK {
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const Mode mode,
const SortType sort,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;

protected:
size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const override;
void k_type_check(const element::Type& k_element_type) const override;
};
} // namespace v3

namespace v11 {
/// \brief Computes the top K elements of a given tensor along the specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
public:
OPENVINO_OP("TopK", "opset11", op::util::TopK_Base, 11);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
///
/// \param data The input tensor
/// \param k Specifies how many maximum/minimum elements should be computed
/// \param axis The axis along which the TopK operation should be executed
/// \param mode Specifies whether TopK selects the largest or the smallest elements from each slice
/// \param sort Specifies the order of corresponding elements of the output tensor
/// \param index_element_type Specifies the data type type of of the elements in the 'indices' output tensor.
/// \param stable Specifies whether the equivalent elements should maintain their relative order
/// from the input tensor during sorting.
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);

/// \brief Constructs a TopK operation with two outputs: values and indices.
///
/// \param data The input tensor
/// \param k Specifies how many maximum/minimum elements should be computed
/// \param axis The axis along which the TopK operation should be executed
/// \param mode Specifies whether TopK selects the largest or the smallest elements from each slice
/// \param sort Specifies the order of corresponding elements of the output tensor
/// \param index_element_type Specifies the data type type of of the elements in the 'indices' output tensor.
/// \param stable Specifies whether the equivalent elements should maintain their relative order
/// from the input tensor during sorting.
TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool get_stable() const {
return m_stable;
}

void set_stable(const bool stable) {
m_stable = stable;
}

private:
bool m_stable;
};
} // namespace v11
} // namespace op
} // namespace ov
100 changes: 100 additions & 0 deletions src/core/include/openvino/op/util/topk_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "openvino/op/util/attr_types.hpp"

namespace ov {
namespace op {
namespace util {
class OPENVINO_API TopK_Base : public Op {
public:
using Mode = TopKMode;
using SortType = TopKSortType;

OPENVINO_OP("TopK_Base", "util");
TopK_Base() = default;

/// \param arg The node producing the input data batch tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param kernel The kernel shape.
/// \param rounding_mode Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad The pad type for automatically computing padding sizes.
TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);

TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;

/// \brief Returns axis value after normalization
/// \note If input rank required to normalization is dynamic, the exception is
/// thrown
uint64_t get_axis() const;
/// \brief Returns axis value before normalization
int64_t get_provided_axis() const {
return m_axis;
}
void set_axis(const int64_t axis);
void set_axis(const Rank& input_rank, const int64_t axis);
TopKMode get_mode() const {
return m_mode;
}
void set_mode(const TopKMode mode) {
m_mode = mode;
}
TopKSortType get_sort_type() const {
return m_sort;
}
void set_sort_type(const TopKSortType sort) {
m_sort = sort;
}
element::Type get_index_element_type() const {
return m_index_element_type;
}
void set_index_element_type(const element::Type& index_element_type) {
m_index_element_type = index_element_type;
}
/// \brief Returns the value of K, if available
///
/// \note If the second input to this op is a constant, the value is retrieved
/// and returned. If the input is not constant(dynamic) this method returns 0
size_t get_k() const;
void set_k(size_t k);
size_t get_default_output_index() const override {
return no_default_index();
}

protected:
int64_t m_axis;
uint64_t m_normalized_axis;
TopKMode m_mode;
TopKSortType m_sort;
element::Type m_index_element_type{element::i32};

virtual void k_type_check(const element::Type& k_element_type) const;
size_t read_k_from_constant_node(const std::shared_ptr<Node>& node, const element::Type& k_element_type) const;
template <typename T>
size_t validate_and_get_k(const std::shared_ptr<op::v0::Constant>& k_constant) const;
};
} // namespace util
} // namespace op
} // namespace ov
5 changes: 5 additions & 0 deletions src/core/include/openvino/opsets/opset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ const OPENVINO_API OpSet& get_opset9();
* @ingroup ov_opset_cpp_api
*/
const OPENVINO_API OpSet& get_opset10();
/**
* @brief Returns opset11
* @ingroup ov_opset_cpp_api
*/
const OPENVINO_API OpSet& get_opset11();

/**
* @brief Returns map of available opsets
Expand Down
15 changes: 15 additions & 0 deletions src/core/include/openvino/opsets/opset11.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/ops.hpp"

namespace ov {
namespace opset11 {
#define _OPENVINO_OP_REG(a, b) using b::a;
#include "openvino/opsets/opset11_tbl.hpp"
#undef _OPENVINO_OP_REG
} // namespace opset11
} // namespace ov
Loading