Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TopK base class cleanup #16154

Merged
merged 4 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/core/include/openvino/op/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace v1 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset1", op::util::TopK_Base, 1);
OPENVINO_OP("TopK", "opset1", op::util::TopKBase, 1);

using SortType = TopKSortType;
using Mode = TopKMode;
Expand Down Expand Up @@ -67,9 +67,9 @@ namespace v3 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset3", op::util::TopK_Base, 3);
OPENVINO_OP("TopK", "opset3", op::util::TopKBase, 3);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
Expand Down Expand Up @@ -109,9 +109,9 @@ class OPENVINO_API TopK : public util::TopK_Base {
namespace v11 {
/// \brief Computes the top K elements of a given tensor along the specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset11", op::util::TopK_Base, 11);
OPENVINO_OP("TopK", "opset11", op::util::TopKBase, 11);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
Expand Down
49 changes: 24 additions & 25 deletions src/core/include/openvino/op/util/topk_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,36 @@
namespace ov {
namespace op {
namespace util {
class OPENVINO_API TopK_Base : public Op {
class OPENVINO_API TopKBase : public Op {
public:
using Mode = TopKMode;
using SortType = TopKSortType;

OPENVINO_OP("TopK_Base", "util");
TopK_Base() = default;
OPENVINO_OP("TopKBase", "util");
TopKBase() = default;

/// \param arg The node producing the input data batch tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param kernel The kernel shape.
/// \param rounding_mode Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad The pad type for automatically computing padding sizes.
TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);
/// \brief The common base class for all TopK operator versions
///
/// \param data The input tensor
/// \param k Specifies how many maximum/minimum elements should be computed
/// \param axis The axis along which TopK should be computed
/// \param mode Specifies whether the maximum or minimum elements are selected
/// \param sort Specifies the order of output elements and/or indices
/// Accepted values: none, index, value
/// \param index_element_type Specifies the type of produced indices
TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32);

TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);
TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
Expand Down
6 changes: 3 additions & 3 deletions src/core/shape_inference/include/topk_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace util {
// Helper to get correct K from tensor as shape.
template <class T>
struct GetK {
const util::TopK_Base* m_op;
const util::TopKBase* m_op;

GetK(const util::TopK_Base* op) : m_op{op} {}
GetK(const util::TopKBase* op) : m_op{op} {}

template <class K>
T operator()(const K k) const {
Expand All @@ -43,7 +43,7 @@ struct GetK {
* \return Vector of output shapes for
*/
template <class TShape>
std::vector<TShape> shape_infer(const util::TopK_Base* op,
std::vector<TShape> shape_infer(const util::TopKBase* op,
const std::vector<TShape>& input_shapes,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
using TDim = typename TShape::value_type;
Expand Down
12 changes: 6 additions & 6 deletions src/core/src/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: util::TopK_Base(data, k, axis, mode, sort, index_element_type) {
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}

Expand All @@ -124,7 +124,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopK_Base(data, k, axis, mode, sort, index_element_type) {
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -233,7 +233,7 @@ op::v3::TopK::TopK(const Output<Node>& data,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopK_Base{data, k, axis, mode, sort, index_element_type} {
: util::TopKBase{data, k, axis, mode, sort, index_element_type} {
constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -335,7 +335,7 @@ ov::op::v11::TopK::TopK(const Output<Node>& data,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
: util::TopK_Base{data, k, axis, mode, sort, index_element_type},
: util::TopKBase{data, k, axis, mode, sort, index_element_type},
m_stable{stable} {
constructor_validate_and_infer_types();
}
Expand All @@ -351,12 +351,12 @@ void ov::op::v11::TopK::validate_and_infer_types() {
AttributeAdapter<TopKSortType>(m_sort).get());
}

util::TopK_Base::validate_and_infer_types();
util::TopKBase::validate_and_infer_types();
}

bool ov::op::v11::TopK::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v11_TopK_visit_attributes);
util::TopK_Base::visit_attributes(visitor);
util::TopKBase::visit_attributes(visitor);
visitor.on_attribute("stable", m_stable);
return true;
}
Expand Down
54 changes: 25 additions & 29 deletions src/core/src/op/util/topk_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@
#include "itt.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"

using namespace std;

namespace {
constexpr auto UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
}

ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type,
const bool stable)
: TopK_Base(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type) {}

ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
ov::op::util::TopKBase::TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: TopKBase(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type) {}

ov::op::util::TopKBase::TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type)
: Op{{data, k}},
m_axis{axis},
m_normalized_axis{UNKNOWN_NORMALIZED_AXIS},
Expand All @@ -42,7 +38,7 @@ ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
ov::mark_as_precision_sensitive(input(1));
}

void ov::op::util::TopK_Base::validate_and_infer_types() {
void ov::op::util::TopKBase::validate_and_infer_types() {
OV_OP_SCOPE(util_TopK_Base_validate_and_infer_types);

k_type_check(get_input_element_type(1));
Expand All @@ -55,7 +51,7 @@ void ov::op::util::TopK_Base::validate_and_infer_types() {
set_output_type(1, m_index_element_type, output_shapes[1]);
}

bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) {
bool ov::op::util::TopKBase::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(util_TopK_Base_visit_attributes);
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("mode", m_mode);
Expand All @@ -64,15 +60,15 @@ bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) {
return true;
}

void ov::op::util::TopK_Base::k_type_check(const element::Type& k_element_type) const {
void ov::op::util::TopKBase::k_type_check(const element::Type& k_element_type) const {
NODE_VALIDATION_CHECK(this,
k_element_type.is_integral_number(),
"K input has to be an integer type, which does match the provided one:",
k_element_type);
}

size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const {
size_t ov::op::util::TopKBase::read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const {
k_type_check(k_element_type);

const auto k_constant = ov::as_type_ptr<op::v0::Constant>(node);
Expand Down Expand Up @@ -112,7 +108,7 @@ size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr<Node>
}

template <typename T>
size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr<op::v0::Constant>& k_constant) const {
size_t ov::op::util::TopKBase::validate_and_get_k(const std::shared_ptr<op::v0::Constant>& k_constant) const {
const auto k_const_contents = k_constant->get_vector<T>();

NODE_VALIDATION_CHECK(this,
Expand All @@ -132,11 +128,11 @@ size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr<op::v0::Cons
return static_cast<size_t>(k_const_contents[0]);
}

void ov::op::util::TopK_Base::set_k(size_t k) {
void ov::op::util::TopKBase::set_k(size_t k) {
this->input(1).replace_source_output(op::v0::Constant::create(element::i64, ov::Shape{}, {k})->output(0));
}

size_t ov::op::util::TopK_Base::get_k() const {
size_t ov::op::util::TopKBase::get_k() const {
size_t k = 0;
if (op::util::is_constant(input_value(1).get_node())) {
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1));
Expand All @@ -148,16 +144,16 @@ size_t ov::op::util::TopK_Base::get_k() const {
return k;
}

void ov::op::util::TopK_Base::set_axis(const int64_t axis) {
void ov::op::util::TopKBase::set_axis(const int64_t axis) {
set_axis(get_input_partial_shape(0).rank(), axis);
}

void ov::op::util::TopK_Base::set_axis(const Rank& input_rank, const int64_t axis) {
void ov::op::util::TopKBase::set_axis(const Rank& input_rank, const int64_t axis) {
m_normalized_axis = input_rank.is_static() ? normalize_axis(this, axis, input_rank) : UNKNOWN_NORMALIZED_AXIS;
m_axis = axis;
}

uint64_t ov::op::util::TopK_Base::get_axis() const {
uint64_t ov::op::util::TopKBase::get_axis() const {
NODE_VALIDATION_CHECK(this, m_normalized_axis != UNKNOWN_NORMALIZED_AXIS, "Normalized axis of TopK is unknown");

return m_normalized_axis;
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/onnx/tests/onnx_tensor_names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ NGRAPH_TEST(onnx_tensor_names, node_multiple_outputs) {

const auto ops = function->get_ordered_ops();
EXPECT_TRUE(matching_node_found_in_graph<op::Parameter>(ops, "x", {"x"}));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopK_Base>(ops, "indices", {"values"}, 0));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopK_Base>(ops, "indices", {"indices"}, 1));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopKBase>(ops, "indices", {"values"}, 0));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopKBase>(ops, "indices", {"indices"}, 1));

const auto results = function->get_results();
EXPECT_TRUE(matching_node_found_in_graph<op::Result>(results, "values/sink_port_0", {"values"}));
Expand Down