Skip to content

Commit

Permalink
TopK base class cleanup (#16154)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomasz Dołbniak authored Mar 8, 2023
1 parent 50b7687 commit e5ef0fe
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 71 deletions.
12 changes: 6 additions & 6 deletions src/core/include/openvino/op/topk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace v1 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset1", op::util::TopK_Base, 1);
OPENVINO_OP("TopK", "opset1", op::util::TopKBase, 1);

using SortType = TopKSortType;
using Mode = TopKMode;
Expand Down Expand Up @@ -67,9 +67,9 @@ namespace v3 {
/// \brief Computes indices and values of the k maximum/minimum values
/// for each slice along specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset3", op::util::TopK_Base, 3);
OPENVINO_OP("TopK", "opset3", op::util::TopKBase, 3);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
Expand Down Expand Up @@ -109,9 +109,9 @@ class OPENVINO_API TopK : public util::TopK_Base {
namespace v11 {
/// \brief Computes the top K elements of a given tensor along the specified axis.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API TopK : public util::TopK_Base {
class OPENVINO_API TopK : public util::TopKBase {
public:
OPENVINO_OP("TopK", "opset11", op::util::TopK_Base, 11);
OPENVINO_OP("TopK", "opset11", op::util::TopKBase, 11);
/// \brief Constructs a TopK operation
TopK() = default;
/// \brief Constructs a TopK operation with two outputs: values and indices.
Expand Down
49 changes: 24 additions & 25 deletions src/core/include/openvino/op/util/topk_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,36 @@
namespace ov {
namespace op {
namespace util {
class OPENVINO_API TopK_Base : public Op {
class OPENVINO_API TopKBase : public Op {
public:
using Mode = TopKMode;
using SortType = TopKSortType;

OPENVINO_OP("TopK_Base", "util");
TopK_Base() = default;
OPENVINO_OP("TopKBase", "util");
TopKBase() = default;

/// \param arg The node producing the input data batch tensor.
/// \param strides The strides.
/// \param pads_begin The beginning of padding shape.
/// \param pads_end The end of padding shape.
/// \param kernel The kernel shape.
/// \param rounding_mode Whether to use ceiling or floor rounding type while
/// computing output shape.
/// \param auto_pad The pad type for automatically computing padding sizes.
TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);
/// \brief The common base class for all TopK operator versions
///
/// \param data The input tensor
/// \param k Specifies how many maximum/minimum elements should be computed
/// \param axis The axis along which TopK should be computed
/// \param mode Specifies whether the maximum or minimum elements are selected
/// \param sort Specifies the order of output elements and/or indices
/// Accepted values: none, index, value
/// \param index_element_type Specifies the type of produced indices
TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type = element::i32);

TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32,
const bool stable = false);
TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type = element::i32);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
Expand Down
6 changes: 3 additions & 3 deletions src/core/shape_inference/include/topk_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace util {
// Helper to get correct K from tensor as shape.
template <class T>
struct GetK {
const util::TopK_Base* m_op;
const util::TopKBase* m_op;

GetK(const util::TopK_Base* op) : m_op{op} {}
GetK(const util::TopKBase* op) : m_op{op} {}

template <class K>
T operator()(const K k) const {
Expand All @@ -43,7 +43,7 @@ struct GetK {
* \return Vector of output shapes for
*/
template <class TShape>
std::vector<TShape> shape_infer(const util::TopK_Base* op,
std::vector<TShape> shape_infer(const util::TopKBase* op,
const std::vector<TShape>& input_shapes,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
using TDim = typename TShape::value_type;
Expand Down
12 changes: 6 additions & 6 deletions src/core/src/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: util::TopK_Base(data, k, axis, mode, sort, index_element_type) {
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}

Expand All @@ -124,7 +124,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopK_Base(data, k, axis, mode, sort, index_element_type) {
: util::TopKBase(data, k, axis, mode, sort, index_element_type) {
constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -233,7 +233,7 @@ op::v3::TopK::TopK(const Output<Node>& data,
const Mode mode,
const SortType sort,
const element::Type& index_element_type)
: util::TopK_Base{data, k, axis, mode, sort, index_element_type} {
: util::TopKBase{data, k, axis, mode, sort, index_element_type} {
constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -335,7 +335,7 @@ ov::op::v11::TopK::TopK(const Output<Node>& data,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
: util::TopK_Base{data, k, axis, mode, sort, index_element_type},
: util::TopKBase{data, k, axis, mode, sort, index_element_type},
m_stable{stable} {
constructor_validate_and_infer_types();
}
Expand All @@ -351,12 +351,12 @@ void ov::op::v11::TopK::validate_and_infer_types() {
AttributeAdapter<TopKSortType>(m_sort).get());
}

util::TopK_Base::validate_and_infer_types();
util::TopKBase::validate_and_infer_types();
}

bool ov::op::v11::TopK::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v11_TopK_visit_attributes);
util::TopK_Base::visit_attributes(visitor);
util::TopKBase::visit_attributes(visitor);
visitor.on_attribute("stable", m_stable);
return true;
}
Expand Down
54 changes: 25 additions & 29 deletions src/core/src/op/util/topk_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@
#include "itt.hpp"
#include "openvino/op/util/precision_sensitive_attribute.hpp"

using namespace std;

namespace {
constexpr auto UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
}

ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type,
const bool stable)
: TopK_Base(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type) {}

ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type,
const bool stable)
ov::op::util::TopKBase::TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const element::Type& index_element_type)
: TopKBase(data, k, axis, as_enum<TopKMode>(mode), as_enum<TopKSortType>(sort), index_element_type) {}

ov::op::util::TopKBase::TopKBase(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const TopKMode mode,
const TopKSortType sort,
const element::Type& index_element_type)
: Op{{data, k}},
m_axis{axis},
m_normalized_axis{UNKNOWN_NORMALIZED_AXIS},
Expand All @@ -42,7 +38,7 @@ ov::op::util::TopK_Base::TopK_Base(const Output<Node>& data,
ov::mark_as_precision_sensitive(input(1));
}

void ov::op::util::TopK_Base::validate_and_infer_types() {
void ov::op::util::TopKBase::validate_and_infer_types() {
OV_OP_SCOPE(util_TopK_Base_validate_and_infer_types);

k_type_check(get_input_element_type(1));
Expand All @@ -55,7 +51,7 @@ void ov::op::util::TopK_Base::validate_and_infer_types() {
set_output_type(1, m_index_element_type, output_shapes[1]);
}

bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) {
bool ov::op::util::TopKBase::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(util_TopK_Base_visit_attributes);
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("mode", m_mode);
Expand All @@ -64,15 +60,15 @@ bool ov::op::util::TopK_Base::visit_attributes(AttributeVisitor& visitor) {
return true;
}

void ov::op::util::TopK_Base::k_type_check(const element::Type& k_element_type) const {
void ov::op::util::TopKBase::k_type_check(const element::Type& k_element_type) const {
NODE_VALIDATION_CHECK(this,
k_element_type.is_integral_number(),
"K input has to be an integer type, which does match the provided one:",
k_element_type);
}

size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr<Node>& node,
const element::Type& k_element_type) const {
size_t ov::op::util::TopKBase::read_k_from_constant_node(const std::shared_ptr<Node>& node,
const element::Type& k_element_type) const {
k_type_check(k_element_type);

const auto k_constant = ov::as_type_ptr<op::v0::Constant>(node);
Expand Down Expand Up @@ -112,7 +108,7 @@ size_t ov::op::util::TopK_Base::read_k_from_constant_node(const shared_ptr<Node>
}

template <typename T>
size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr<op::v0::Constant>& k_constant) const {
size_t ov::op::util::TopKBase::validate_and_get_k(const std::shared_ptr<op::v0::Constant>& k_constant) const {
const auto k_const_contents = k_constant->get_vector<T>();

NODE_VALIDATION_CHECK(this,
Expand All @@ -132,11 +128,11 @@ size_t ov::op::util::TopK_Base::validate_and_get_k(const shared_ptr<op::v0::Cons
return static_cast<size_t>(k_const_contents[0]);
}

void ov::op::util::TopK_Base::set_k(size_t k) {
void ov::op::util::TopKBase::set_k(size_t k) {
this->input(1).replace_source_output(op::v0::Constant::create(element::i64, ov::Shape{}, {k})->output(0));
}

size_t ov::op::util::TopK_Base::get_k() const {
size_t ov::op::util::TopKBase::get_k() const {
size_t k = 0;
if (op::util::is_constant(input_value(1).get_node())) {
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(), get_input_element_type(1));
Expand All @@ -148,16 +144,16 @@ size_t ov::op::util::TopK_Base::get_k() const {
return k;
}

void ov::op::util::TopK_Base::set_axis(const int64_t axis) {
void ov::op::util::TopKBase::set_axis(const int64_t axis) {
set_axis(get_input_partial_shape(0).rank(), axis);
}

void ov::op::util::TopK_Base::set_axis(const Rank& input_rank, const int64_t axis) {
void ov::op::util::TopKBase::set_axis(const Rank& input_rank, const int64_t axis) {
m_normalized_axis = input_rank.is_static() ? normalize_axis(this, axis, input_rank) : UNKNOWN_NORMALIZED_AXIS;
m_axis = axis;
}

uint64_t ov::op::util::TopK_Base::get_axis() const {
uint64_t ov::op::util::TopKBase::get_axis() const {
NODE_VALIDATION_CHECK(this, m_normalized_axis != UNKNOWN_NORMALIZED_AXIS, "Normalized axis of TopK is unknown");

return m_normalized_axis;
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/onnx/tests/onnx_tensor_names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ NGRAPH_TEST(onnx_tensor_names, node_multiple_outputs) {

const auto ops = function->get_ordered_ops();
EXPECT_TRUE(matching_node_found_in_graph<op::Parameter>(ops, "x", {"x"}));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopK_Base>(ops, "indices", {"values"}, 0));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopK_Base>(ops, "indices", {"indices"}, 1));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopKBase>(ops, "indices", {"values"}, 0));
EXPECT_TRUE(matching_node_found_in_graph<ov::op::util::TopKBase>(ops, "indices", {"indices"}, 1));

const auto results = function->get_results();
EXPECT_TRUE(matching_node_found_in_graph<op::Result>(results, "values/sink_port_0", {"values"}));
Expand Down

0 comments on commit e5ef0fe

Please sign in to comment.