Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core]Migrate Gathers operators to new API #21390

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/core/include/openvino/op/util/gather_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ class OPENVINO_API GatherBase : public Op {
void validate_and_infer_types() override;
virtual int64_t get_axis() const;

OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;

OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;
Expand Down
6 changes: 2 additions & 4 deletions src/core/reference/include/openvino/reference/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void gather(const T* const data,

int64_t batch_data_mul = shape_size(span(data_shape).subspan(batch_dims));
int64_t batch_out_mul = shape_size(span(out_shape).subspan(batch_dims));
int64_t batch_indices_mul = shape_size(span(indices_shape).subspan(batch_dims));

int64_t axis_size = data_shape[axis];
int64_t data_offset, out_offset, idx;
Expand All @@ -40,17 +39,16 @@ void gather(const T* const data,
data_offset = batch_data_mul * batch + inner_size * axis_size * outer_idx;
out_offset = batch_out_mul * batch + indices_size * inner_size * outer_idx;
for (int64_t i = 0; i < indices_size; i++) {
idx = indices[i + batch_indices_mul * batch];
idx = indices[i + indices_size * batch];
if (idx < 0)
idx += axis_size;
// for out of bound values have to be filled with zeros
if (idx >= axis_size || idx < 0)
continue;

const auto src_begin = std::next(data, data_offset + inner_size * idx);
const auto src_end = std::next(src_begin, inner_size);
const auto out_ptr = std::next(out, out_offset + inner_size * i);
std::copy(src_begin, src_end, out_ptr);
std::copy_n(src_begin, inner_size, out_ptr);
}
}
}
Expand Down
64 changes: 33 additions & 31 deletions src/core/src/op/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,37 @@
#include "openvino/op/gather.hpp"

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "validation_util.hpp"

namespace ov {

op::v1::Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
namespace op {
namespace v1 {
Gather::Gather(const Output<Node>& params, const Output<Node>& indices, const Output<Node>& axes)
: GatherBase(params, indices, axes) {
constructor_validate_and_infer_types();
}

int64_t op::v1::Gather::get_axis() const {
OPENVINO_SUPPRESS_DEPRECATED_START
if (!get_constant_from_source(input_value(2))) {
OPENVINO_SUPPRESS_DEPRECATED_END
return AXIS_NOT_SET_VALUE;
}
return GatherBase::get_axis();
int64_t Gather::get_axis() const {
return ov::util::get_constant_from_source(input_value(2)) ? GatherBase::get_axis() : AXIS_NOT_SET_VALUE;
}

std::shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v1_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v1::Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
}
} // namespace v1

op::v7::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
namespace v7 {
Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: GatherBase(data, indices, axis, batch_dims) {
constructor_validate_and_infer_types();
}

void op::v7::Gather::validate_and_infer_types() {
void Gather::validate_and_infer_types() {
OV_OP_SCOPE(v7_Gather_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
Expand All @@ -47,37 +45,39 @@ void op::v7::Gather::validate_and_infer_types() {
get_input_element_type(2).is_integral_number(),
"Axis element type must be of an integral number type.");

op::util::GatherBase::validate_and_infer_types();
util::GatherBase::validate_and_infer_types();
}

int64_t op::v7::Gather::get_batch_dims() const {
int64_t Gather::get_batch_dims() const {
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
else
return m_batch_dims;
}

bool op::v7::Gather::visit_attributes(AttributeVisitor& visitor) {
bool Gather::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v7_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}

std::shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v7_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
} // namespace v7

op::v8::Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
namespace v8 {
Gather::Gather(const Output<Node>& data,
const Output<Node>& indices,
const Output<Node>& axis,
const int64_t batch_dims)
: GatherBase(data, indices, axis, batch_dims) {
constructor_validate_and_infer_types();
}

void op::v8::Gather::validate_and_infer_types() {
void Gather::validate_and_infer_types() {
OV_OP_SCOPE(v8_Gather_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_integral_number(),
Expand All @@ -90,22 +90,24 @@ void op::v8::Gather::validate_and_infer_types() {
op::util::GatherBase::validate_and_infer_types();
}

int64_t op::v8::Gather::get_batch_dims() const {
int64_t Gather::get_batch_dims() const {
if (m_batch_dims < 0 && get_input_partial_shape(1).rank().is_static())
return m_batch_dims + get_input_partial_shape(1).rank().get_length();
else
return m_batch_dims;
}

bool op::v8::Gather::visit_attributes(AttributeVisitor& visitor) {
bool Gather::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v8_Gather_visit_attributes);
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}

std::shared_ptr<Node> op::v8::Gather::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Gather::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v8_Gather_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<v8::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
return std::make_shared<Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
}
} // namespace v8
} // namespace op
} // namespace ov
Loading
Loading