diff --git a/ngraph/core/include/ngraph/op/scatter_update.hpp b/ngraph/core/include/ngraph/op/scatter_update.hpp
index f42fb9685fe4f7..25a4b94719e611 100644
--- a/ngraph/core/include/ngraph/op/scatter_update.hpp
+++ b/ngraph/core/include/ngraph/op/scatter_update.hpp
@@ -18,6 +18,7 @@
 
 #include "ngraph/op/op.hpp"
 #include "ngraph/op/util/scatter_base.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
 
 namespace ngraph
 {
@@ -49,6 +50,9 @@ namespace ngraph
 
             virtual std::shared_ptr<Node>
                 clone_with_new_inputs(const OutputVector& inputs) const override;
+
+            bool evaluate(const HostTensorVector& outputs,
+                          const HostTensorVector& inputs) const override;
         };
     }
 }
diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp
index e3cae8c014750b..f8d00b17ebfe50 100644
--- a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp
+++ b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp
@@ -1,86 +1,120 @@
-// Copyright (C) 2020 Intel Corporation
-// SPDX-License-Identifier: Apache-2.0
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
 //
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
 
 #pragma once
 
-#include <cstring>
+#include "ngraph/check.hpp"
 #include "ngraph/coordinate_transform.hpp"
 #include "ngraph/shape.hpp"
 
-using namespace ngraph;
-
 namespace ngraph
 {
     namespace runtime
     {
        namespace reference
        {
-            template <typename dataType, typename indicesType, typename axisType>
-            void scatterUpdate(const dataType* inputData,
-                               const indicesType* indices,
-                               const dataType* updates,
-                               const axisType* _axis,
-                               dataType* outBuf,
-                               const Shape& dataShape,
-                               const Shape& indicesShape,
-                               const Shape& updatesShape)
+            void scatter_update(const char* input_data,
+                                const int64_t* indices,
+                                const char* updates,
+                                const int64_t axis,
+                                char* out_buf,
+                                const size_t elem_size,
+                                const Shape& data_shape,
+                                const Shape& indices_shape,
+                                const Shape& updates_shape)
             {
-                int rank = static_cast<int>(dataShape.size());
-                if (_axis[0] < -rank || _axis[0] > rank - 1)
-                {
-                    std::string error =
-                        std::string("ScatterUpdate layer has out of bounds axis value: ") +
-                        std::to_string(_axis[0]);
-                    throw ngraph_error(error);
-                }
-                size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
-                CoordinateTransform indicesTransform{indicesShape};
+                // Copy inputs to out
+                std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));
+
+                // Algorithm overview
+                // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
+                // where the first ... in data stands for the leading `axis` dimensions, and
+                // the last ... stands for the trailing rank(data) - (axis + 1) dimensions.
+                //
+                // for i_coord in indices[m, n, ..., p]:
+                //     # get linear index
+                //     i_idx = index(i_coord)
+                //     # simultaneously iterate over two slices of data with the same
+                //     # element count
+                //     for d_coord in slice data[..., i_idx, ...],
+                //         u_coord in slice updates[..., i_coord, ...]:
+                //         data[index(d_coord)] = updates[index(u_coord)]
 
-                Shape dataShapeIter = dataShape;
-                dataShapeIter.erase(dataShapeIter.begin() + axis);
-                CoordinateTransform dataTransfIter{dataShapeIter};
+                CoordinateTransform indices_transform{indices_shape};
+                CoordinateTransform data_transform{data_shape};
 
-                CoordinateTransform updateTransform{updatesShape};
-                CoordinateTransform dataTransform{dataShape};
+                size_t indices_ndim = indices_shape.size();
+                size_t updates_ndim = updates_shape.size();
 
-                std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
+                // Create an outer CoordinateTransform for "updates" that iterates only
+                // over the "indices" dimensions by clamping all non-indices dimensions
+                // to 1:
+                // updates[1, ..., 1, m, n, ..., p, 1, 1, ..., 1]
+                Coordinate updates_indices_start_corner(updates_ndim, 0);
+                Coordinate updates_indices_end_corner(updates_ndim, 1);
+                for (size_t i = 0; i < indices_ndim; ++i)
+                {
+                    updates_indices_end_corner[axis + i] = updates_shape[axis + i];
+                }
+                CoordinateTransform updates_indices_transform(
+                    updates_shape, updates_indices_start_corner, updates_indices_end_corner);
+                // Needed to iterate over updates coordinates in lockstep with the
+                // indices coordinates.
+                auto updates_indices_coord_iter = updates_indices_transform.begin();
 
-                for (const Coordinate& indicesCoordIt : indicesTransform)
+                for (const Coordinate& indices_cord : indices_transform)
                 {
-                    const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
+                    const size_t indices_idx = indices_transform.index(indices_cord);
+                    int64_t slice_index = indices[indices_idx];
 
-                    if (indices[indicesIdx] < 0)
-                    {
-                        std::string error =
-                            std::string("ScatterUpdate layer has negative index value: ") +
-                            std::to_string(indices[indicesIdx]);
-                        throw ngraph_error(error);
-                    }
-                    const size_t idx = static_cast<size_t>(indices[indicesIdx]);
-                    if (dataShape[axis] <= idx)
+                    // Define the extent of coordinates which will be updated.
+                    Coordinate out_start_corner(data_shape.size(), 0);
+                    Coordinate out_end_corner(data_shape);
+                    out_start_corner[axis] = static_cast<size_t>(slice_index);
+                    out_end_corner[axis] = out_start_corner[axis] + 1;
+                    CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);
+
+                    // Define the CoordinateTransform for updates coordinates:
+                    // all dimensions except the indices dimensions.
+                    Coordinate updates_update_start_corner = *updates_indices_coord_iter;
+                    Coordinate updates_update_end_corner(updates_shape);
+                    for (size_t i = 0; i < indices_ndim; ++i)
                     {
-                        std::string error =
-                            std::string("ScatterUpdate layer has out of bounds coordinate: ") +
-                            std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
-                            "th axis";
-                        throw ngraph_error(error);
+                        updates_update_end_corner[axis + i] =
+                            updates_update_start_corner[axis + i] + 1;
                     }
-
-                    for (const Coordinate& dataCoordIt : dataTransfIter)
+                    // The m, n, ..., p symbols stand for the values at those axes;
+                    // m+1 means the value at axis m plus 1.
+                    // updates_shape (start): [ 0, ..., 0, m  , n  , ..., p  ,  0, ...,  0]
+                    // updates_shape (end):   [-1, ..., -1, m+1, n+1, ..., p+1, -1, ..., -1]
+                    CoordinateTransform updates_update_transform(
+                        updates_shape, updates_update_start_corner, updates_update_end_corner);
+                    auto updates_update_coord_iter = updates_update_transform.begin();
+                    for (const Coordinate& out_cord : out_transform)
                     {
-                        Coordinate dataCoord = dataCoordIt;
-                        dataCoord.insert(dataCoord.begin() + axis, idx);
-                        const size_t startIndices = dataTransform.index(dataCoord);
-
-                        auto updCoord = dataCoordIt;
-                        updCoord.insert(
-                            updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
-                        const size_t startUpd = updateTransform.index(updCoord);
-                        outBuf[startIndices] = updates[startUpd];
+                        const auto src_idx =
+                            updates_update_transform.index(*updates_update_coord_iter) * elem_size;
+                        std::copy(updates + src_idx,
+                                  updates + (src_idx + elem_size),
+                                  out_buf + out_transform.index(out_cord) * elem_size);
+                        updates_update_coord_iter++;
                     }
+                    updates_indices_coord_iter++;
                 }
             }
-        } // namespace reference
-    } // namespace runtime
-} // namespace ngraph
+        }
+    }
+}
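The rewritten kernel is type-erased: it moves `elem_size` raw bytes per element instead of being templated on the data type, so one compiled instance serves every element type. A minimal sketch (not part of the patch; `main`, the literal values, and the include path are ours) of driving it directly, mirroring the first eval.cpp test below — axis 0 with indices {1, 2} overwrites rows 1 and 2 of a zeroed 3x3 tensor:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

#include "ngraph/runtime/reference/scatter_update.hpp"

int main()
{
    using namespace ngraph;

    std::vector<float> data(9, 0.0f);                               // shape {3, 3}, all zeros
    std::vector<int64_t> indices{1, 2};                             // shape {1, 2}
    std::vector<float> updates{1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}; // shape {1, 2, 3}
    std::vector<float> out(data.size());

    // The kernel copies data into out first, then scatters the update slices.
    runtime::reference::scatter_update(reinterpret_cast<const char*>(data.data()),
                                       indices.data(),
                                       reinterpret_cast<const char*>(updates.data()),
                                       /*axis=*/0,
                                       reinterpret_cast<char*>(out.data()),
                                       sizeof(float),
                                       Shape{3, 3},
                                       Shape{1, 2},
                                       Shape{1, 2, 3});

    for (float v : out)
        std::cout << v << ' '; // prints: 0 0 0 1 1.1 1.2 2 2.1 2.2
}
```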
diff --git a/ngraph/core/src/op/scatter_update.cpp b/ngraph/core/src/op/scatter_update.cpp
index 1600c8f30a1d8f..1ebf07e52995ec 100644
--- a/ngraph/core/src/op/scatter_update.cpp
+++ b/ngraph/core/src/op/scatter_update.cpp
@@ -15,7 +15,11 @@
 //*****************************************************************************
 
 #include "ngraph/op/scatter_update.hpp"
+#include "ngraph/runtime/reference/scatter_update.hpp"
 #include "ngraph/shape.hpp"
+#include "ngraph/type/element_type.hpp"
+#include "ngraph/type/element_type_traits.hpp"
+#include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
@@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
     return make_shared<v3::ScatterUpdate>(
         new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
 }
+
+bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
+                                     const HostTensorVector& inputs) const
+{
+    const auto& data = inputs[0];
+    const auto& indices = inputs[1];
+    const auto& updates = inputs[2];
+    const auto& axis = inputs[3];
+    const auto& out = outputs[0];
+
+    const auto elem_size = data->get_element_type().size();
+    out->set_shape(data->get_shape());
+
+    int64_t axis_val = 0;
+    switch (axis->get_element_type())
+    {
+    case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
+    case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
+    case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
+    case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
+    case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
+    case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
+    case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
+    case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
+    default: throw ngraph_error("axis element type is not integral data type");
+    }
+
+    if (axis_val < 0)
+    {
+        axis_val = ngraph::normalize_axis(
+            this, axis_val, static_cast<int64_t>(data->get_shape().size()));
+    }
+
+    std::vector<int64_t> indices_casted_vector;
+    switch (indices->get_element_type())
+    {
+    case element::Type_t::i8:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i16:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i32:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::i64:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u8:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u16:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u32:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    case element::Type_t::u64:
+    {
+        auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
+        indices_casted_vector =
+            std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
+        break;
+    }
+    default: throw ngraph_error("indices element type is not integral data type");
+    }
+
+    runtime::reference::scatter_update(data->get_data_ptr<char>(),
+                                       indices_casted_vector.data(),
+                                       updates->get_data_ptr<char>(),
+                                       axis_val,
+                                       out->get_data_ptr<char>(),
+                                       elem_size,
+                                       data->get_shape(),
+                                       indices->get_shape(),
+                                       updates->get_shape());
+
+    return true;
+}
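The negative-axis branch defers to `ngraph::normalize_axis`, which validates the range and reports errors against the node; functionally it is the same wrap-around the old kernel performed inline (`axis < 0 ? axis + rank : axis`). A simplified stand-in with a hypothetical name, just to make the arithmetic concrete:

```cpp
#include <cstdint>

// Simplified stand-in for ngraph::normalize_axis (the real helper also checks
// that axis lies in [-rank, rank - 1] and attaches node context to errors).
int64_t normalize_axis_sketch(int64_t axis, int64_t rank)
{
    return axis < 0 ? axis + rank : axis; // e.g. rank 2: axis -1 -> 1
}
```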
diff --git a/ngraph/test/eval.cpp b/ngraph/test/eval.cpp
index e32fdcbe40469b..792ee3892c2634 100644
--- a/ngraph/test/eval.cpp
+++ b/ngraph/test/eval.cpp
@@ -54,6 +54,7 @@
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/round.hpp"
 #include "ngraph/op/scatter_elements_update.hpp"
+#include "ngraph/op/scatter_update.hpp"
 #include "ngraph/op/shape_of.hpp"
 #include "ngraph/op/sigmoid.hpp"
 #include "ngraph/op/sign.hpp"
@@ -1937,3 +1938,180 @@ TEST(eval, reduce_logical_and__neg_axis)
                  }),
                  ngraph::ngraph_error);
 }
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+    auto arg2 = make_shared<op::Parameter>(element::i32, indices_shape);
+    auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+    auto arg4 = make_shared<op::Parameter>(element::i32, Shape{});
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i32>({}, {0})}));
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, data_shape);
+    auto arg2 = make_shared<op::Parameter>(element::i64, indices_shape);
+    auto arg3 = make_shared<op::Parameter>(element::f32, updates_shape);
+    auto arg4 = make_shared<op::Parameter>(element::i64, Shape{});
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i64>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({}, {0})}));
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_basic)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{1, 2, 3};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_negative_axis)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{3, 1, 2};
+    const Shape axis_shape{};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>(axis_shape, {-1})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_1d_axis)
+{
+    const Shape data_shape{3, 3};
+    const Shape indices_shape{1, 2};
+    const Shape updates_shape{3, 1, 2};
+
+    auto arg1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(fun->evaluate({result_tensor},
+                              {make_host_tensor<element::Type_t::f32>(
+                                   data_shape, std::vector<float>(shape_size(data_shape))),
+                               make_host_tensor<element::Type_t::i32>(indices_shape, {1, 2}),
+                               make_host_tensor<element::Type_t::f32>(
+                                   updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}),
+                               make_host_tensor<element::Type_t::i64>({1}, {1})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3}));
+    auto cval = read_vector<float>(result_tensor);
+    vector<float> out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f};
+    ASSERT_EQ(cval, out);
+}
+
+TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32)
+{
+    const Shape data_shape{3, 3, 2};
+    const Shape indices_shape{1, 1};
+    const Shape updates_shape{1, 1, 3, 2};
+
+    auto arg1 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg3 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
+    auto arg4 = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
+
+    auto scatter_update = make_shared<op::v3::ScatterUpdate>(arg1, arg2, arg3, arg4);
+    auto fun = make_shared<Function>(OutputVector{scatter_update},
+                                     ParameterVector{arg1, arg2, arg3, arg4});
+    auto result_tensor = make_shared<HostTensor>();
+    ASSERT_TRUE(
+        fun->evaluate({result_tensor},
+                      {make_host_tensor<element::Type_t::i32>(
+                           data_shape, std::vector<int32_t>(shape_size(data_shape))),
+                       make_host_tensor<element::Type_t::i32>(indices_shape, {1}),
+                       make_host_tensor<element::Type_t::i32>(updates_shape, {1, 2, 3, 4, 5, 6}),
+                       make_host_tensor<element::Type_t::i64>({}, {0})}));
+
+    EXPECT_EQ(result_tensor->get_element_type(), element::i32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2}));
+    auto cval = read_vector<int32_t>(result_tensor);
+    vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
+    ASSERT_EQ(cval, out);
+}
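As a cross-check on the negative-axis and 1d-axis expectations above: axis -1 over rank-2 data normalizes to axis 1, so indices {1, 2} select columns rather than rows, and the row-major updates tensor of shape {3, 1, 2} contributes updates[r][0][k] to data[r][indices[k]]. A hand-rolled sketch (not part of the patch; names and `main` are ours) reproducing the asserted vector:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> data(9, 0.0f);             // shape {3, 3}, row-major, all zeros
    const std::vector<std::size_t> indices{1, 2}; // target columns (axis 1 after normalization)
    const std::vector<float> updates{1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}; // shape {3, 1, 2}

    // Scatter along axis 1: row r, column indices[k] receives updates[r][0][k].
    for (std::size_t r = 0; r < 3; ++r)
        for (std::size_t k = 0; k < indices.size(); ++k)
            data[r * 3 + indices[k]] = updates[r * indices.size() + k];

    for (float v : data)
        std::cout << v << ' '; // prints: 0 1 1.1 0 1.2 2 0 2.1 2.2
}
```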
diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp
index e1230c844df40c..20152a8256be56 100644
--- a/ngraph/test/runtime/interpreter/int_executable.hpp
+++ b/ngraph/test/runtime/interpreter/int_executable.hpp
@@ -79,7 +79,6 @@
 #include "ngraph/runtime/reference/reverse_sequence.hpp"
 #include "ngraph/runtime/reference/round.hpp"
 #include "ngraph/runtime/reference/scatter_nd_update.hpp"
-#include "ngraph/runtime/reference/scatter_update.hpp"
 #include "ngraph/runtime/reference/select.hpp"
 #include "ngraph/runtime/reference/sigmoid.hpp"
 #include "ngraph/runtime/reference/sign.hpp"
@@ -1195,48 +1194,6 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ
             break;
         }
-        case OP_TYPEID::ScatterUpdate_v3:
-        {
-            const op::v3::ScatterUpdate* scatterUpd =
-                static_cast<const op::v3::ScatterUpdate*>(&node);
-
-            if (scatterUpd->get_input_element_type(3) != element::i64)
-                throw ngraph_error(
-                    "ScatterNDUpdate layer support only i64 'axis' input precision!");
-
-            auto idxType = scatterUpd->get_input_element_type(1);
-            if (idxType == element::i32)
-            {
-                reference::scatterUpdate<T, int32_t, int64_t>(
-                    args[0]->get_data_ptr<const T>(),
-                    args[1]->get_data_ptr<const int32_t>(),
-                    args[2]->get_data_ptr<const T>(),
-                    args[3]->get_data_ptr<const int64_t>(),
-                    out[0]->get_data_ptr<T>(),
-                    node.get_input_shape(0),
-                    node.get_input_shape(1),
-                    node.get_input_shape(2));
-            }
-            else if (idxType == element::i64)
-            {
-                reference::scatterUpdate<T, int64_t, int64_t>(
-                    args[0]->get_data_ptr<const T>(),
-                    args[1]->get_data_ptr<const int64_t>(),
-                    args[2]->get_data_ptr<const T>(),
-                    args[3]->get_data_ptr<const int64_t>(),
-                    out[0]->get_data_ptr<T>(),
-                    node.get_input_shape(0),
-                    node.get_input_shape(1),
-                    node.get_input_shape(2));
-            }
-            else
-            {
-                throw ngraph_error(
-                    "ScatterUpdate layer support only i32 and i64 'indices' input precision!");
-            }
-
-            break;
-        }
 
         // Fused Ops are not supported in interpreter. They need to be decomposed before execution
         case OP_TYPEID::DepthToSpace:
@@ -1255,6 +1212,7 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ
         case OP_TYPEID::NormalizeL2:
         case OP_TYPEID::PRelu:
        case OP_TYPEID::RNNCell:
+        case OP_TYPEID::ScatterUpdate_v3:
         case OP_TYPEID::Selu:
         case OP_TYPEID::ShuffleChannels:
         case OP_TYPEID::SpaceToDepth: