Skip to content

Commit

Permalink
[core]Migrate Gathers operators to new API (openvinotoolkit#21390)
Browse files Browse the repository at this point in the history
* Migrate Gather operators to new API

* Remove redundant code form reference

* Use IF_TYPE_OF macro

* Remove unused include

* Use common utils in gather base

* Fix normalize after merge issues
  • Loading branch information
praasz authored and akuporos committed Dec 8, 2023
1 parent 34589a7 commit a8aee30
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 222 deletions.
5 changes: 1 addition & 4 deletions src/core/include/openvino/op/util/gather_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ class OPENVINO_API GatherBase : public Op {
void validate_and_infer_types() override;
virtual int64_t get_axis() const;

OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;

OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;
Expand Down
6 changes: 2 additions & 4 deletions src/core/reference/include/openvino/reference/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void gather(const T* const data,

int64_t batch_data_mul = shape_size(span(data_shape).subspan(batch_dims));
int64_t batch_out_mul = shape_size(span(out_shape).subspan(batch_dims));
int64_t batch_indices_mul = shape_size(span(indices_shape).subspan(batch_dims));

int64_t axis_size = data_shape[axis];
int64_t data_offset, out_offset, idx;
Expand All @@ -40,17 +39,16 @@ void gather(const T* const data,
data_offset = batch_data_mul * batch + inner_size * axis_size * outer_idx;
out_offset = batch_out_mul * batch + indices_size * inner_size * outer_idx;
for (int64_t i = 0; i < indices_size; i++) {
idx = indices[i + batch_indices_mul * batch];
idx = indices[i + indices_size * batch];
if (idx < 0)
idx += axis_size;
// for out of bound values have to be filled with zeros
if (idx >= axis_size || idx < 0)
continue;

const auto src_begin = std::next(data, data_offset + inner_size * idx);
const auto src_end = std::next(src_begin, inner_size);
const auto out_ptr = std::next(out, out_offset + inner_size * i);
std::copy(src_begin, src_end, out_ptr);
std::copy_n(src_begin, inner_size, out_ptr);
}
}
}
Expand Down
64 changes: 33 additions & 31 deletions src/core/src/op/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,37 @@
#include "openvino/op/gather.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "validation_util.hpp"

namespace ov {

op::v1::Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
namespace op {
namespace v1 {
Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
: GatherBase(params, indices, axes) {
constructor_validate_and_infer_types();
}

int64_t op::v1::Gather::get_axis() const {
OPENVINO_SUPPRESS_DEPRECATED_START
if (!get_constant_from_source(input_value(2))) {
OPENVINO_SUPPRESS_DEPRECATED_END
return AXIS_NOT_SET_VALUE;
}
return GatherBase::get_axis();
int64_t Gather::get_axis() const {
return ov::util::get_constant_from_source(input_value(2)) ? GatherBase::get_axis() : AXIS_NOT_SET_VALUE;
}

std::shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v1::Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
}
} // namespace v1

op::v7::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
namespace v7 {
Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: GatherBase(data, indices, axis, batch_dims) {
constructor_validate_and_infer_types();
}

void op::v7::Gather::validate_and_infer_types() {
void Gather::validate_and_infer_types() {
OV_OP_SCOPE(v7_Gather_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
Expand All @@ -47,37 +45,39 @@ void op::v7::Gather::validate_and_infer_types() {
get_input_element_type(2).is_integral_number(),
"Axis element type must be of an integral number type.");

op::util::GatherBase::validate_and_infer_types();
util::GatherBase::validate_and_infer_types();
}

int64_t op::v7::Gather::get_batch_dims() const {
int64_t Gather::get_batch_dims() const {
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
else
return m_batch_dims;
}

bool op::v7::Gather::visit_attributes(AttributeVisitor& visitor) {
bool Gather::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v7_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}

std::shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v7_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
} // namespace v7

op::v8::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
namespace v8 {
Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: GatherBase(data, indices, axis, batch_dims) {
constructor_validate_and_infer_types();
}

void op::v8::Gather::validate_and_infer_types() {
void Gather::validate_and_infer_types() {
OV_OP_SCOPE(v8_Gather_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
Expand All @@ -90,22 +90,24 @@ void op::v8::Gather::validate_and_infer_types() {
op::util::GatherBase::validate_and_infer_types();
}

int64_t op::v8::Gather::get_batch_dims() const {
int64_t Gather::get_batch_dims() const {
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
else
return m_batch_dims;
}

bool op::v8::Gather::visit_attributes(AttributeVisitor& visitor) {
bool Gather::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v8_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}

std::shared_ptr<Node> op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v8_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v8::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
} // namespace v8
} // namespace op
} // namespace ov
Loading

0 comments on commit a8aee30

Please sign in to comment.