Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Op][Core] Add ScatterNDUpdate-14 core and reference #23754

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a575dfb
[Op][Core] Add ScatterNDUpdate-14 core and reference
mmikolajcz Mar 28, 2024
25595f7
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 4, 2024
0294bab
Switch from label to symbols
mmikolajcz Apr 4, 2024
c8b50a7
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 4, 2024
8fc49ad
Add test cases for negative and duplicate indices
mmikolajcz Apr 4, 2024
ec0357b
Fix opset14 test
mmikolajcz Apr 5, 2024
5c72805
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 5, 2024
aaa6b10
Merge branch 'mateuszm/op/scatternd/core' of https://github.com/mmiko…
mmikolajcz Apr 8, 2024
909a7e6
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 8, 2024
5c778ca
Re-use existing references as reductions
mmikolajcz Apr 9, 2024
1e27f1a
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 9, 2024
b3c4423
Reduce binary size
mmikolajcz Apr 10, 2024
231ec81
Remove copy
mmikolajcz Apr 10, 2024
94a14e6
Add todo about future improvements
mmikolajcz Apr 11, 2024
472bc1e
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 11, 2024
6c4e598
Fix code style
mmikolajcz Apr 11, 2024
da7eb6b
Add requested changes
mmikolajcz Apr 15, 2024
2c472aa
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Apr 15, 2024
16b3fa3
Try to fix duplicate symbol error in CI
mmikolajcz Apr 15, 2024
f31fc48
Use enable_if
mmikolajcz Apr 16, 2024
9050404
Remove unused include
mmikolajcz Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/core/include/openvino/op/scatter_nd_update.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,54 @@ class OPENVINO_API ScatterNDUpdate : public util::ScatterNDBase {
bool has_evaluate() const override;
};
} // namespace v3
namespace v14 {
/// \brief Add updates to slices from inputs addressed by indices
/// \ingroup ov_ops_cpp_api
class OPENVINO_API ScatterNDUpdate : public util::ScatterNDBase {
public:
OPENVINO_OP("ScatterNDUpdate", "opset14", util::ScatterNDBase);

/// \brief Lists the supported reduction types for this version of the operator.
/// See the specification for the description of how reduction works with ScatterNDUpdate.
enum class Reduction { NONE, SUM, SUB, PROD, MIN, MAX };

ScatterNDUpdate() = default;
/// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs
/// \param reduction Reduction: Type of operation to perform on inputs
ScatterNDUpdate(const Output<Node>& inputs,
const Output<Node>& indices,
const Output<Node>& updates,
const Reduction reduction = Reduction::NONE);

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& output_values) const override;
bool evaluate_upper(TensorVector& output_values) const override;
bool evaluate_symbol(TensorSymbolVector& output_symbols) const override;
bool has_evaluate() const override;

Reduction get_reduction() const;

void set_reduction(const Reduction reduction);

private:
Reduction m_reduction = Reduction::NONE;
};
} // namespace v14
OPENVINO_API
std::ostream& operator<<(std::ostream& s, const v14::ScatterNDUpdate::Reduction& reduction);

} // namespace op
template <>
class OPENVINO_API AttributeAdapter<op::v14::ScatterNDUpdate::Reduction>
: public EnumAttributeAdapterBase<op::v14::ScatterNDUpdate::Reduction> {
public:
AttributeAdapter(op::v14::ScatterNDUpdate::Reduction& value)
: EnumAttributeAdapterBase<op::v14::ScatterNDUpdate::Reduction>(value) {}

OPENVINO_RTTI("AttributeAdapter<v14::ScatterNDUpdate::Reduction>");
};
} // namespace ov
2 changes: 1 addition & 1 deletion src/core/include/openvino/opsets/opset14_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ _OPENVINO_OP_REG(Reshape, ov::op::v1)
_OPENVINO_OP_REG(Result, ov::op::v0)
_OPENVINO_OP_REG(ReverseSequence, ov::op::v0)
_OPENVINO_OP_REG(ROIPooling, ov::op::v0)
_OPENVINO_OP_REG(ScatterNDUpdate, ov::op::v3)
_OPENVINO_OP_REG(Select, ov::op::v1)
_OPENVINO_OP_REG(Selu, ov::op::v0)
_OPENVINO_OP_REG(Sign, ov::op::v0)
Expand Down Expand Up @@ -221,3 +220,4 @@ _OPENVINO_OP_REG(FakeConvert, ov::op::v13)
// New operations added in opset14
_OPENVINO_OP_REG(ConvertPromoteTypes, ov::op::v14)
_OPENVINO_OP_REG(Inverse, ov::op::v14)
_OPENVINO_OP_REG(ScatterNDUpdate, ov::op::v14)
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,93 @@
#include <cstring>
#include <numeric>

#include "add.hpp"
#include "and.hpp"
#include "maximum.hpp"
#include "minimum.hpp"
#include "multiply.hpp"
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
#include "openvino/core/shape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "or.hpp"
#include "subtract.hpp"
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
#include "utils/span.hpp"
#include "xor.hpp"

namespace ov {
namespace reference {
using Reduction = ov::op::v14::ScatterNDUpdate::Reduction;
template <typename T>
using reduction_function = T (*)(const T, const T);

namespace func {
// TODO move this functions to other reference implementations to reduce binary size. Binary for
// ScatterElementsUpdate-12 can also be updated. Ticket: CVS-138266
template <class T>
constexpr T add(const T a, const T b) {
return a + b;
}
template <class T>
constexpr T subtract(const T a, const T b) {
return a - b;
}

template <class T>
constexpr T logical_and(const T a, const T b) {
return static_cast<bool>(a) && static_cast<bool>(b);
}

template <class T>
constexpr T logical_or(const T a, const T b) {
return static_cast<bool>(a) || static_cast<bool>(b);
}

} // namespace func

template <typename T>
reduction_function<T> reduction_functor_for(Reduction reduction_type) {
switch (reduction_type) {
case Reduction::MAX:
return func::max<T>;
case Reduction::MIN:
return func::min<T>;
case Reduction::PROD:
return func::multiply<T>;
case Reduction::SUM:
return func::add<T>;
case Reduction::SUB:
return func::subtract<T>;
case Reduction::NONE:
default:
return nullptr;
}
}

template <>
reduction_function<char> reduction_functor_for<char>(const Reduction reduction_type) {
switch (reduction_type) {
case Reduction::MIN:
case Reduction::PROD:
return func::logical_and<char>;
case Reduction::SUM:
case Reduction::MAX:
return func::logical_or<char>;
case Reduction::SUB:
return func::logical_xor<char>;
case Reduction::NONE:
default:
return nullptr;
}
}

template <typename dataType, typename indicesType>
void scatterNdUpdate(const dataType* const inputData,
const indicesType* const indices,
const dataType* const updates,
dataType* const outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape) {
const Shape& updatesShape,
const Reduction reduction_type = Reduction::NONE) {
const auto update_chunk_shape = span(dataShape).drop_front(indicesShape.back());
const auto update_el_number = shape_size(update_chunk_shape);

Expand All @@ -32,9 +106,8 @@ void scatterNdUpdate(const dataType* const inputData,
};
return padding;
}();

const auto reduction = reduction_functor_for<dataType>(reduction_type);
std::vector<indicesType> indicesCopy(indices, indices + shape_size(indicesShape));

const auto num_of_updates = shape_size(span(indicesShape).drop_back(1));
for (size_t i = 0; i != num_of_updates; ++i) {
const auto indices_coord = indicesCopy.data() + i * indicesShape.back();
Expand All @@ -52,10 +125,17 @@ void scatterNdUpdate(const dataType* const inputData,
const auto out_index = std::inner_product(begin(coord), end(coord), begin(input_data_dim_pading), uint64_t(0));

const auto update_data = updates + i * update_el_number;
const auto update_mem_size = update_el_number * sizeof(dataType);
OPENVINO_ASSERT(out_index >= 0 && out_index + update_el_number <= shape_size(dataShape),
"Index is out of bounds");
std::memcpy(outBuf + out_index, update_data, update_mem_size);
if (reduction) {
std::transform(outBuf + out_index,
outBuf + out_index + update_el_number,
update_data,
outBuf + out_index,
reduction);
} else {
std::memcpy(outBuf + out_index, update_data, update_el_number * sizeof(dataType));
}
}
}
} // namespace reference
Expand Down
141 changes: 130 additions & 11 deletions src/core/src/op/scatter_nd_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct Evaluate : public element::NoAction<bool> {
Tensor& output,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape) {
const Shape& updates_shape,
const v14::ScatterNDUpdate::Reduction reduction = v14::ScatterNDUpdate::Reduction::NONE) {
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
using namespace ov::element;
return IF_TYPE_OF(sctter_nd_eval_idx_type,
OV_PP_ET_LIST(i32, i64),
Expand All @@ -34,34 +35,37 @@ struct Evaluate : public element::NoAction<bool> {
output.data<DT>(),
data_shape,
indices_shape,
updates_shape);
updates_shape,
reduction);
}

private:
struct EvaluateByIndicesType : public element::NoAction<bool> {
using element::NoAction<bool>::visit;

template <element::Type_t INDICES_ET, class DT, class IT = fundamental_type_for<INDICES_ET>>
static result_type visit(const DT* const data,
const Tensor& indices,
const DT* const updates,
DT* const output,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape) {
static result_type visit(
const DT* const data,
const Tensor& indices,
const DT* const updates,
DT* const output,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape,
const v14::ScatterNDUpdate::Reduction reduction = v14::ScatterNDUpdate::Reduction::NONE) {
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
reference::scatterNdUpdate(data,
indices.data<IT>(),
updates,
output,
data_shape,
indices_shape,
updates_shape);
updates_shape,
reduction);
return true;
}
};
};
} // namespace scatter_nd_update

namespace v3 {
std::shared_ptr<Node> ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_clone_with_new_inputs);
Expand Down Expand Up @@ -140,5 +144,120 @@ bool ScatterNDUpdate::evaluate_symbol(TensorSymbolVector& output_symbols) const
return default_symbol_evaluator(this, {0, 2}, output_symbols);
}
} // namespace v3

namespace v14 {
ScatterNDUpdate::ScatterNDUpdate(const Output<Node>& inputs,
const Output<Node>& indices,
const Output<Node>& updates,
const ScatterNDUpdate::Reduction reduction)
: op::util::ScatterNDBase(inputs, indices, updates),
m_reduction{reduction} {
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> ScatterNDUpdate::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v14_ScatterNDUpdate_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<ScatterNDUpdate>(new_args.at(0), new_args.at(1), new_args.at(2), m_reduction);
}

bool ScatterNDUpdate::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v14_ScatterNDUpdate_visit_attributes);
visitor.on_attribute("reduction", m_reduction);
return true;
}

bool ScatterNDUpdate::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v14_ScatterNDUpdate_evaluate);
OPENVINO_ASSERT(inputs.size() == 3);
OPENVINO_ASSERT(outputs.size() == 1);

const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
auto& output = outputs[0];
const auto& data_shape = data.get_shape();
const auto& indices_shape = indices.get_shape();
const auto& updates_shape = updates.get_shape();
output.set_shape(data_shape);
using namespace ov::element;
return IF_TYPE_OF_CONVERT_TENSORS(v14_ScatterNDUpdate_evaluate,
this,
outputs,
inputs,
OV_PP_ET_LIST(boolean, f32, i32, i64, u32, u64),
scatter_nd_update::Evaluate,
data.get_element_type(),
data,
indices,
updates,
output,
data_shape,
indices_shape,
updates_shape,
m_reduction);
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
}

bool ScatterNDUpdate::has_evaluate() const {
OV_OP_SCOPE(v14_ScatterNDUpdate_has_evaluate);

switch (get_output_element_type(0)) {
case element::boolean:
case element::f16:
case element::f32:
case element::i32:
case element::i64:
case element::u32:
case element::u64:
break;
default:
return false;
}
switch (get_input_element_type(1)) {
case element::i32:
case element::i64:
return true;
default:
return false;
}
mmikolajcz marked this conversation as resolved.
Show resolved Hide resolved
}

ScatterNDUpdate::Reduction ScatterNDUpdate::get_reduction() const {
return m_reduction;
}

void ScatterNDUpdate::set_reduction(const ScatterNDUpdate::Reduction reduction) {
m_reduction = reduction;
}
bool ScatterNDUpdate::evaluate_lower(TensorVector& output_values) const {
OV_OP_SCOPE(v14_ScatterNDUpdate_evaluate_lower);
return get_input_tensor(1).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
}

bool ScatterNDUpdate::evaluate_upper(TensorVector& output_values) const {
OV_OP_SCOPE(v14_ScatterNDUpdate_evaluate_upper);
return get_input_tensor(1).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
}

bool ScatterNDUpdate::evaluate_symbol(TensorSymbolVector& output_symbols) const {
OV_OP_SCOPE(v14_ScatterNDUpdate_evaluate_symbol);
return default_symbol_evaluator(this, {0, 2}, output_symbols);
}

} // namespace v14
std::ostream& operator<<(std::ostream& s, const v14::ScatterNDUpdate::Reduction& reduction) {
return s << as_string(reduction);
}
} // namespace op
template <>
OPENVINO_API EnumNames<op::v14::ScatterNDUpdate::Reduction>& EnumNames<op::v14::ScatterNDUpdate::Reduction>::get() {
static auto enum_names =
EnumNames<op::v14::ScatterNDUpdate::Reduction>("op::v14::ScatterNDUpdate::Reduction",
{{"none", op::v14::ScatterNDUpdate::Reduction::NONE},
{"sum", op::v14::ScatterNDUpdate::Reduction::SUM},
{"sub", op::v14::ScatterNDUpdate::Reduction::SUB},
{"prod", op::v14::ScatterNDUpdate::Reduction::PROD},
{"min", op::v14::ScatterNDUpdate::Reduction::MIN},
{"max", op::v14::ScatterNDUpdate::Reduction::MAX}});
return enum_names;
}
} // namespace ov
Loading
Loading