From 882b20d64ba329675e7c7ec743e093c942f87c35 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 7 Aug 2020 10:46:54 +0200 Subject: [PATCH 1/6] Reference implementation for ScatterUpdate and use of it in evaluate. --- ngraph/src/ngraph/op/scatter_update.cpp | 133 ++++++++++++++++ ngraph/src/ngraph/op/scatter_update.hpp | 4 + .../runtime/reference/scatter_update.hpp | 115 ++++++++++++++ ngraph/test/eval.cpp | 150 ++++++++++++++++++ 4 files changed, 402 insertions(+) create mode 100644 ngraph/src/ngraph/runtime/reference/scatter_update.hpp diff --git a/ngraph/src/ngraph/op/scatter_update.cpp b/ngraph/src/ngraph/op/scatter_update.cpp index 1600c8f30a1d8f..71f066fc737f70 100644 --- a/ngraph/src/ngraph/op/scatter_update.cpp +++ b/ngraph/src/ngraph/op/scatter_update.cpp @@ -15,7 +15,11 @@ //***************************************************************************** #include "ngraph/op/scatter_update.hpp" +#include "ngraph/runtime/reference/scatter_update.hpp" #include "ngraph/shape.hpp" +#include "ngraph/type/element_type.hpp" +#include "ngraph/type/element_type_traits.hpp" +#include "ngraph/validation_util.hpp" using namespace std; using namespace ngraph; @@ -36,3 +40,132 @@ shared_ptr op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector return make_shared( new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3)); } + +namespace +{ + template + bool evaluate(const HostTensorPtr& data, + const HostTensorPtr& indices, + const HostTensorPtr& updates, + const HostTensorPtr& out, + const int64_t normalized_axis) + { + using DataType = typename element_type_traits
::value_type; + using IndicesType = typename element_type_traits::value_type; + + out->set_shape(data->get_shape()); + runtime::reference::scatter_update( + data->get_data_ptr(), + indices->get_data_ptr(), + updates->get_data_ptr(), + normalized_axis, + out->get_data_ptr(), + data->get_shape(), + indices->get_shape(), + updates->get_shape()); + + return true; + } + + template + bool evaluate(const HostTensorPtr& data, + const HostTensorPtr& indices, + const HostTensorPtr& updates, + const HostTensorPtr& out, + const int64_t normalized_axis) + { + // Dispatch specialization based on indicies data type. + bool rc = true; + + switch (indices->get_element_type()) + { + case element::Type_t::i8: + case element::Type_t::u8: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + case element::Type_t::i16: + case element::Type_t::u16: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + case element::Type_t::i32: + case element::Type_t::u32: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + case element::Type_t::i64: + case element::Type_t::u64: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + default: rc = false; break; + } + return rc; + } + + bool evaluate_scatter_update(const HostTensorPtr& data, + const HostTensorPtr& indices, + const HostTensorPtr& updates, + const HostTensorPtr& out, + const int64_t normalized_axis) + { + // Dispatch based on data, updates and output data type. + bool rc = true; + switch (out->get_element_type()) + { + case element::Type_t::i32: + case element::Type_t::u32: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + case element::Type_t::i64: + case element::Type_t::u64: + rc = evaluate(data, indices, updates, out, normalized_axis); + break; + TYPE_CASE(f16)(data, indices, updates, out, normalized_axis); + break; + TYPE_CASE(f32)(data, indices, updates, out, normalized_axis); + break; + default: rc = false; break; + } + return rc; + } +} + +bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const +{ + const auto& data = inputs[0]; + const auto& indices = inputs[1]; + const auto& updates = inputs[2]; + const auto& axis = inputs[3]; + const auto& out = outputs[0]; + + int64_t axis_val = 0; + switch (axis->get_element_type()) + { + case element::Type_t::i8: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::i16: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::i32: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::i64: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::u8: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::u16: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::u32: axis_val = axis->get_data_ptr()[0]; break; + case element::Type_t::u64: axis_val = axis->get_data_ptr()[0]; break; + default: throw ngraph_error("axis element type is not integral data type"); + } + + const auto& input_rank = get_input_partial_shape(0).rank(); + int64_t normalized_axis = axis_val; + + if (normalized_axis < 0) + { + if (input_rank.is_static()) + { + normalized_axis = ngraph::normalize_axis(this, axis_val, input_rank); + } + else + { + normalized_axis = ngraph::normalize_axis( + this, axis_val, static_cast(data->get_shape().size())); + } + } + + return evaluate_scatter_update(data, indices, updates, out, normalized_axis); +} diff --git a/ngraph/src/ngraph/op/scatter_update.hpp b/ngraph/src/ngraph/op/scatter_update.hpp index f42fb9685fe4f7..25a4b94719e611 100644 --- a/ngraph/src/ngraph/op/scatter_update.hpp +++ b/ngraph/src/ngraph/op/scatter_update.hpp @@ -18,6 +18,7 @@ #include "ngraph/op/op.hpp" #include "ngraph/op/util/scatter_base.hpp" +#include "ngraph/runtime/host_tensor.hpp" namespace ngraph { @@ -49,6 +50,9 @@ namespace ngraph virtual std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; + + bool evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const override; }; } } diff --git a/ngraph/src/ngraph/runtime/reference/scatter_update.hpp b/ngraph/src/ngraph/runtime/reference/scatter_update.hpp new file mode 100644 index 00000000000000..274ece922cc3e6 --- /dev/null +++ b/ngraph/src/ngraph/runtime/reference/scatter_update.hpp @@ -0,0 +1,115 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include + +#include "ngraph/check.hpp" +#include "ngraph/coordinate_transform.hpp" +#include "ngraph/shape.hpp" + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void scatter_update(const DataType* input_data, + const IndicesType* indices, + const DataType* updates, + const int64_t& axis, + DataType* out_buf, + const Shape& data_shape, + const Shape& indices_shape, + const Shape& updates_shape) + { + // Copy inputs to out + std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); + + // Algorithm overview + // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...] + // where first ... in the data corresponds to first axis dimensions, + // last ... in the data corresponds to the rank(data) - (axis + 1) dimensions. + + // + // for i_coord in indices[m, n, ..., p]: + // i_idx = index(i_coord) + // for d_coord in slice data[..., i_idx, ...] && + // for u_coord in slice updates[..., i_coord, ...]: + // data[index(d_coord)] = updates[index(u_coord)] + + CoordinateTransform indices_transform{indices_shape}; + CoordinateTransform data_transform{data_shape}; + + size_t indices_ndim = indices_shape.size(); + size_t updates_ndim = updates_shape.size(); + + // Create an outer CoordinateTransform for "update", which would allow to + // iterate only over "indicies" dimensions: + // set to "1" all non-indices dimensions + // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1] + Coordinate updates_indices_start_corner(updates_ndim, 0); + Coordinate updates_indices_end_corner(updates_ndim, 1); + for (size_t i = 0; i < indices_ndim; ++i) + { + updates_indices_end_corner[axis + i] = updates_shape[axis + i]; + } + CoordinateTransform updates_indices_transform( + updates_shape, updates_indices_start_corner, updates_indices_end_corner); + // Is needed to simultaneously iterate over updates coordinates while + // iterating over indices. + auto updates_indices_coord_iter = updates_indices_transform.begin(); + + for (const Coordinate& indices_cord : indices_transform) + { + const size_t indices_idx = indices_transform.index(indices_cord); + IndicesType slice_index = indices[indices_idx]; + + // Define the extent of coordinates which will be updated. + Coordinate out_start_corner(data_shape.size(), 0); + Coordinate out_end_corner(data_shape); + out_start_corner[axis] = static_cast(slice_index); + out_end_corner[axis] = out_start_corner[axis] + 1; + CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner); + + // Define the CoordinateTransform for updates coordinates. + // All except indices-dimensions. + Coordinate updates_update_start_corner = *updates_indices_coord_iter; + Coordinate updates_update_end_corner(updates_shape); + for (size_t i = 0; i < indices_ndim; ++i) + { + updates_update_end_corner[axis + i] = + updates_update_start_corner[axis + i] + 1; + } + // udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0] + // updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1] + CoordinateTransform updates_update_transform( + updates_shape, updates_update_start_corner, updates_update_end_corner); + auto updates_update_coord_iter = updates_update_transform.begin(); + for (const Coordinate& out_cord : out_transform) + { + out_buf[out_transform.index(out_cord)] = + updates[updates_update_transform.index(*updates_update_coord_iter)]; + updates_update_coord_iter++; + } + updates_indices_coord_iter++; + } + } + } + } +} diff --git a/ngraph/test/eval.cpp b/ngraph/test/eval.cpp index a65829e784ab2c..68d0567f958605 100644 --- a/ngraph/test/eval.cpp +++ b/ngraph/test/eval.cpp @@ -56,6 +56,7 @@ #include "ngraph/op/reshape.hpp" #include "ngraph/op/round.hpp" #include "ngraph/op/scatter_elements_update.hpp" +#include "ngraph/op/scatter_update.hpp" #include "ngraph/op/shape_of.hpp" #include "ngraph/op/sigmoid.hpp" #include "ngraph/op/sign.hpp" @@ -1915,3 +1916,152 @@ TEST(eval, reduce_logical_and__neg_axis) }), ngraph::ngraph_error); } + +TEST(eval, evaluate_static_scatter_update_basic) +{ + const Shape data_shape{3, 3}; + const Shape indices_shape{1, 2}; + const Shape updates_shape{1, 2, 3}; + + auto arg1 = make_shared(element::f32, data_shape); + auto arg2 = make_shared(element::i32, indices_shape); + auto arg3 = make_shared(element::f32, updates_shape); + auto arg4 = make_shared(element::i64, Shape{}); + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE(fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1, 2}), + make_host_tensor( + updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), + make_host_tensor({}, {0})})); + EXPECT_EQ(result_tensor->get_element_type(), element::f32); + EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3})); + auto cval = read_vector(result_tensor); + vector out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}; + ASSERT_EQ(cval, out); +} + +TEST(eval, evaluate_dynamic_scatter_update_basic) +{ + const Shape data_shape{3, 3}; + const Shape indices_shape{1, 2}; + const Shape updates_shape{1, 2, 3}; + + auto arg1 = make_shared(element::f32, PartialShape::dynamic()); + auto arg2 = make_shared(element::i32, PartialShape::dynamic()); + auto arg3 = make_shared(element::f32, PartialShape::dynamic()); + auto arg4 = make_shared(element::i64, PartialShape::dynamic()); + + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE(fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1, 2}), + make_host_tensor( + updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), + make_host_tensor({}, {0})})); + + EXPECT_EQ(result_tensor->get_element_type(), element::f32); + EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3})); + auto cval = read_vector(result_tensor); + vector out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}; + ASSERT_EQ(cval, out); +} + +TEST(eval, evaluate_dynamic_scatter_update_negative_axis) +{ + const Shape data_shape{3, 3}; + const Shape indices_shape{1, 2}; + const Shape updates_shape{3, 1, 2}; + const Shape axis_shape{}; + + auto arg1 = make_shared(element::f32, PartialShape::dynamic()); + auto arg2 = make_shared(element::i32, PartialShape::dynamic()); + auto arg3 = make_shared(element::f32, PartialShape::dynamic()); + auto arg4 = make_shared(element::i64, PartialShape::dynamic()); + + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE(fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1, 2}), + make_host_tensor( + updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), + make_host_tensor(axis_shape, {-1})})); + + EXPECT_EQ(result_tensor->get_element_type(), element::f32); + EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3})); + auto cval = read_vector(result_tensor); + vector out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f}; + ASSERT_EQ(cval, out); +} + +TEST(eval, evaluate_dynamic_scatter_update_1d_axis) +{ + const Shape data_shape{3, 3}; + const Shape indices_shape{1, 2}; + const Shape updates_shape{3, 1, 2}; + + auto arg1 = make_shared(element::f32, PartialShape::dynamic()); + auto arg2 = make_shared(element::i32, PartialShape::dynamic()); + auto arg3 = make_shared(element::f32, PartialShape::dynamic()); + auto arg4 = make_shared(element::i64, PartialShape::dynamic()); + + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE(fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1, 2}), + make_host_tensor( + updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), + make_host_tensor({1}, {1})})); + + EXPECT_EQ(result_tensor->get_element_type(), element::f32); + EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3})); + auto cval = read_vector(result_tensor); + vector out{0.f, 1.0f, 1.1f, 0.0f, 1.2f, 2.0f, 0.0f, 2.1f, 2.2f}; + ASSERT_EQ(cval, out); +} + +TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32) +{ + const Shape data_shape{3, 3, 2}; + const Shape indices_shape{1, 1}; + const Shape updates_shape{1, 1, 3, 2}; + + auto arg1 = make_shared(element::i32, PartialShape::dynamic()); + auto arg2 = make_shared(element::i32, PartialShape::dynamic()); + auto arg3 = make_shared(element::i32, PartialShape::dynamic()); + auto arg4 = make_shared(element::i64, PartialShape::dynamic()); + + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE( + fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1}), + make_host_tensor(updates_shape, {1, 2, 3, 4, 5, 6}), + make_host_tensor({}, {0})})); + + EXPECT_EQ(result_tensor->get_element_type(), element::i32); + EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{3, 3, 2})); + auto cval = read_vector(result_tensor); + vector out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}; + ASSERT_EQ(cval, out); +} From 5f7860ea693195c058504ab8227fdc85a50acd00 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 11 Aug 2020 10:39:22 +0200 Subject: [PATCH 2/6] Review comments. Clarify comments. --- ngraph/core/src/op/scatter_update.cpp | 11 ++--------- .../src/ngraph/runtime/reference/scatter_update.hpp | 10 +++++++--- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/ngraph/core/src/op/scatter_update.cpp b/ngraph/core/src/op/scatter_update.cpp index 71f066fc737f70..597edd6b1f55d2 100644 --- a/ngraph/core/src/op/scatter_update.cpp +++ b/ngraph/core/src/op/scatter_update.cpp @@ -156,15 +156,8 @@ bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, if (normalized_axis < 0) { - if (input_rank.is_static()) - { - normalized_axis = ngraph::normalize_axis(this, axis_val, input_rank); - } - else - { - normalized_axis = ngraph::normalize_axis( - this, axis_val, static_cast(data->get_shape().size())); - } + normalized_axis = + ngraph::normalize_axis(this, axis_val, static_cast(data->get_shape().size())); } return evaluate_scatter_update(data, indices, updates, out, normalized_axis); diff --git a/ngraph/src/ngraph/runtime/reference/scatter_update.hpp b/ngraph/src/ngraph/runtime/reference/scatter_update.hpp index 274ece922cc3e6..f46033eac1e1b8 100644 --- a/ngraph/src/ngraph/runtime/reference/scatter_update.hpp +++ b/ngraph/src/ngraph/runtime/reference/scatter_update.hpp @@ -48,9 +48,11 @@ namespace ngraph // // for i_coord in indices[m, n, ..., p]: + // # get linear index // i_idx = index(i_coord) - // for d_coord in slice data[..., i_idx, ...] && - // for u_coord in slice updates[..., i_coord, ...]: + // # simultaneously iterate over two slices of data with same elements count + // for d_coord in slice data[..., i_idx, ...], + // u_coord in slice updates[..., i_coord, ...] // data[index(d_coord)] = updates[index(u_coord)] CoordinateTransform indices_transform{indices_shape}; @@ -60,7 +62,7 @@ namespace ngraph size_t updates_ndim = updates_shape.size(); // Create an outer CoordinateTransform for "update", which would allow to - // iterate only over "indicies" dimensions: + // iterate only over "indices" dimensions: // set to "1" all non-indices dimensions // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1] Coordinate updates_indices_start_corner(updates_ndim, 0); @@ -96,6 +98,8 @@ namespace ngraph updates_update_end_corner[axis + i] = updates_update_start_corner[axis + i] + 1; } + // The m, n, .., p symbols stand for values at those axes. + // The m+1 means value at axis m plus 1. // udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0] // updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1] CoordinateTransform updates_update_transform( From 871d198f6f520016970cf16d6d32f34b72e70351 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 11 Aug 2020 10:53:03 +0200 Subject: [PATCH 3/6] Update file directory. --- .../include}/ngraph/runtime/reference/scatter_update.hpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ngraph/{src => core/include}/ngraph/runtime/reference/scatter_update.hpp (100%) diff --git a/ngraph/src/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp similarity index 100% rename from ngraph/src/ngraph/runtime/reference/scatter_update.hpp rename to ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp From 45e32df50d5cc42fd718176eaf848007392519be Mon Sep 17 00:00:00 2001 From: mitruska Date: Thu, 20 Aug 2020 16:58:19 +0200 Subject: [PATCH 4/6] Replace scatter_update reference implementation in ngraph/core/reference/ --- .../runtime/reference/scatter_update.hpp | 119 -------------- .../runtime/reference/scatter_update.hpp | 153 +++++++++++------- .../runtime/interpreter/int_executable.hpp | 8 +- 3 files changed, 97 insertions(+), 183 deletions(-) delete mode 100644 ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp diff --git a/ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp deleted file mode 100644 index f46033eac1e1b8..00000000000000 --- a/ngraph/core/include/ngraph/runtime/reference/scatter_update.hpp +++ /dev/null @@ -1,119 +0,0 @@ -//***************************************************************************** -// Copyright 2017-2020 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -//***************************************************************************** - -#pragma once - -#include - -#include "ngraph/check.hpp" -#include "ngraph/coordinate_transform.hpp" -#include "ngraph/shape.hpp" - -namespace ngraph -{ - namespace runtime - { - namespace reference - { - template - void scatter_update(const DataType* input_data, - const IndicesType* indices, - const DataType* updates, - const int64_t& axis, - DataType* out_buf, - const Shape& data_shape, - const Shape& indices_shape, - const Shape& updates_shape) - { - // Copy inputs to out - std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); - - // Algorithm overview - // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...] - // where first ... in the data corresponds to first axis dimensions, - // last ... in the data corresponds to the rank(data) - (axis + 1) dimensions. - - // - // for i_coord in indices[m, n, ..., p]: - // # get linear index - // i_idx = index(i_coord) - // # simultaneously iterate over two slices of data with same elements count - // for d_coord in slice data[..., i_idx, ...], - // u_coord in slice updates[..., i_coord, ...] - // data[index(d_coord)] = updates[index(u_coord)] - - CoordinateTransform indices_transform{indices_shape}; - CoordinateTransform data_transform{data_shape}; - - size_t indices_ndim = indices_shape.size(); - size_t updates_ndim = updates_shape.size(); - - // Create an outer CoordinateTransform for "update", which would allow to - // iterate only over "indices" dimensions: - // set to "1" all non-indices dimensions - // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1] - Coordinate updates_indices_start_corner(updates_ndim, 0); - Coordinate updates_indices_end_corner(updates_ndim, 1); - for (size_t i = 0; i < indices_ndim; ++i) - { - updates_indices_end_corner[axis + i] = updates_shape[axis + i]; - } - CoordinateTransform updates_indices_transform( - updates_shape, updates_indices_start_corner, updates_indices_end_corner); - // Is needed to simultaneously iterate over updates coordinates while - // iterating over indices. - auto updates_indices_coord_iter = updates_indices_transform.begin(); - - for (const Coordinate& indices_cord : indices_transform) - { - const size_t indices_idx = indices_transform.index(indices_cord); - IndicesType slice_index = indices[indices_idx]; - - // Define the extent of coordinates which will be updated. - Coordinate out_start_corner(data_shape.size(), 0); - Coordinate out_end_corner(data_shape); - out_start_corner[axis] = static_cast(slice_index); - out_end_corner[axis] = out_start_corner[axis] + 1; - CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner); - - // Define the CoordinateTransform for updates coordinates. - // All except indices-dimensions. - Coordinate updates_update_start_corner = *updates_indices_coord_iter; - Coordinate updates_update_end_corner(updates_shape); - for (size_t i = 0; i < indices_ndim; ++i) - { - updates_update_end_corner[axis + i] = - updates_update_start_corner[axis + i] + 1; - } - // The m, n, .., p symbols stand for values at those axes. - // The m+1 means value at axis m plus 1. - // udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0] - // updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1] - CoordinateTransform updates_update_transform( - updates_shape, updates_update_start_corner, updates_update_end_corner); - auto updates_update_coord_iter = updates_update_transform.begin(); - for (const Coordinate& out_cord : out_transform) - { - out_buf[out_transform.index(out_cord)] = - updates[updates_update_transform.index(*updates_update_coord_iter)]; - updates_update_coord_iter++; - } - updates_indices_coord_iter++; - } - } - } - } -} diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp index e3cae8c014750b..f46033eac1e1b8 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp @@ -1,86 +1,119 @@ -// Copyright (C) 2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation // +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** #pragma once -#include +#include + +#include "ngraph/check.hpp" #include "ngraph/coordinate_transform.hpp" #include "ngraph/shape.hpp" -using namespace ngraph; - namespace ngraph { namespace runtime { namespace reference { - template - void scatterUpdate(const dataType* inputData, - const indicesType* indices, - const dataType* updates, - const axisType* _axis, - dataType* outBuf, - const Shape& dataShape, - const Shape& indicesShape, - const Shape& updatesShape) + template + void scatter_update(const DataType* input_data, + const IndicesType* indices, + const DataType* updates, + const int64_t& axis, + DataType* out_buf, + const Shape& data_shape, + const Shape& indices_shape, + const Shape& updates_shape) { - int rank = static_cast(dataShape.size()); - if (_axis[0] < -rank || _axis[0] > rank - 1) - { - std::string error = - std::string("ScatterUpdate layer has out of bounds axis value: ") + - std::to_string(_axis[0]); - throw ngraph_error(error); - } - size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0]; - CoordinateTransform indicesTransform{indicesShape}; + // Copy inputs to out + std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); - Shape dataShapeIter = dataShape; - dataShapeIter.erase(dataShapeIter.begin() + axis); - CoordinateTransform dataTransfIter{dataShapeIter}; + // Algorithm overview + // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...] + // where first ... in the data corresponds to first axis dimensions, + // last ... in the data corresponds to the rank(data) - (axis + 1) dimensions. - CoordinateTransform updateTransform{updatesShape}; - CoordinateTransform dataTransform{dataShape}; + // + // for i_coord in indices[m, n, ..., p]: + // # get linear index + // i_idx = index(i_coord) + // # simultaneously iterate over two slices of data with same elements count + // for d_coord in slice data[..., i_idx, ...], + // u_coord in slice updates[..., i_coord, ...] + // data[index(d_coord)] = updates[index(u_coord)] - std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape)); + CoordinateTransform indices_transform{indices_shape}; + CoordinateTransform data_transform{data_shape}; - for (const Coordinate& indicesCoordIt : indicesTransform) + size_t indices_ndim = indices_shape.size(); + size_t updates_ndim = updates_shape.size(); + + // Create an outer CoordinateTransform for "update", which would allow to + // iterate only over "indices" dimensions: + // set to "1" all non-indices dimensions + // updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1] + Coordinate updates_indices_start_corner(updates_ndim, 0); + Coordinate updates_indices_end_corner(updates_ndim, 1); + for (size_t i = 0; i < indices_ndim; ++i) { - const size_t indicesIdx = indicesTransform.index(indicesCoordIt); + updates_indices_end_corner[axis + i] = updates_shape[axis + i]; + } + CoordinateTransform updates_indices_transform( + updates_shape, updates_indices_start_corner, updates_indices_end_corner); + // Is needed to simultaneously iterate over updates coordinates while + // iterating over indices. + auto updates_indices_coord_iter = updates_indices_transform.begin(); - if (indices[indicesIdx] < 0) - { - std::string error = - std::string("ScatterUpdate layer has negative index value: ") + - std::to_string(indices[indicesIdx]); - throw ngraph_error(error); - } - const size_t idx = static_cast(indices[indicesIdx]); - if (dataShape[axis] <= idx) + for (const Coordinate& indices_cord : indices_transform) + { + const size_t indices_idx = indices_transform.index(indices_cord); + IndicesType slice_index = indices[indices_idx]; + + // Define the extent of coordinates which will be updated. + Coordinate out_start_corner(data_shape.size(), 0); + Coordinate out_end_corner(data_shape); + out_start_corner[axis] = static_cast(slice_index); + out_end_corner[axis] = out_start_corner[axis] + 1; + CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner); + + // Define the CoordinateTransform for updates coordinates. + // All except indices-dimensions. + Coordinate updates_update_start_corner = *updates_indices_coord_iter; + Coordinate updates_update_end_corner(updates_shape); + for (size_t i = 0; i < indices_ndim; ++i) { - std::string error = - std::string("ScatterUpdate layer has out of bounds coordinate: ") + - std::to_string(idx) + " on 'data' input on " + std::to_string(axis) + - "th axis"; - throw ngraph_error(error); + updates_update_end_corner[axis + i] = + updates_update_start_corner[axis + i] + 1; } - - for (const Coordinate& dataCoordIt : dataTransfIter) + // The m, n, .., p symbols stand for values at those axes. + // The m+1 means value at axis m plus 1. + // udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0] + // updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1] + CoordinateTransform updates_update_transform( + updates_shape, updates_update_start_corner, updates_update_end_corner); + auto updates_update_coord_iter = updates_update_transform.begin(); + for (const Coordinate& out_cord : out_transform) { - Coordinate dataCoord = dataCoordIt; - dataCoord.insert(dataCoord.begin() + axis, idx); - const size_t startIndices = dataTransform.index(dataCoord); - - auto updCoord = dataCoordIt; - updCoord.insert( - updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end()); - const size_t startUpd = updateTransform.index(updCoord); - outBuf[startIndices] = updates[startUpd]; + out_buf[out_transform.index(out_cord)] = + updates[updates_update_transform.index(*updates_update_coord_iter)]; + updates_update_coord_iter++; } + updates_indices_coord_iter++; } } - } // namespace reference - } // namespace runtime -} // namespace ngraph + } + } +} diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index 04be016c8b1795..65ede84cb54976 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -1207,11 +1207,11 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ auto idxType = scatterUpd->get_input_element_type(1); if (idxType == element::i32) { - reference::scatterUpdate( + reference::scatter_update( args[0]->get_data_ptr(), args[1]->get_data_ptr(), args[2]->get_data_ptr(), - args[3]->get_data_ptr(), + *args[3]->get_data_ptr(), out[0]->get_data_ptr(), node.get_input_shape(0), node.get_input_shape(1), @@ -1219,11 +1219,11 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ } else if (idxType == element::i64) { - reference::scatterUpdate( + reference::scatter_update( args[0]->get_data_ptr(), args[1]->get_data_ptr(), args[2]->get_data_ptr(), - args[3]->get_data_ptr(), + *args[3]->get_data_ptr(), out[0]->get_data_ptr(), node.get_input_shape(0), node.get_input_shape(1), From e5f8df5a2f0e60b90c8d530cd89cf0de57f461f7 Mon Sep 17 00:00:00 2001 From: mitruska Date: Mon, 24 Aug 2020 13:31:45 +0200 Subject: [PATCH 5/6] Remove template code from ScatterUpdate reference implementation --- .../runtime/reference/scatter_update.hpp | 21 ++- ngraph/core/src/op/scatter_update.cpp | 171 ++++++++---------- ngraph/test/eval.cpp | 32 +++- .../runtime/interpreter/int_executable.hpp | 45 +---- 4 files changed, 122 insertions(+), 147 deletions(-) diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp index f46033eac1e1b8..18b645ef4c2fa1 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp @@ -28,18 +28,18 @@ namespace ngraph { namespace reference { - template - void scatter_update(const DataType* input_data, - const IndicesType* indices, - const DataType* updates, + void scatter_update(const char* input_data, + const int64_t* indices, + const char* updates, const int64_t& axis, - DataType* out_buf, + char* out_buf, + const size_t elem_size, const Shape& data_shape, const Shape& indices_shape, const Shape& updates_shape) { // Copy inputs to out - std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); + std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape)); // Algorithm overview // data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...] @@ -80,7 +80,7 @@ namespace ngraph for (const Coordinate& indices_cord : indices_transform) { const size_t indices_idx = indices_transform.index(indices_cord); - IndicesType slice_index = indices[indices_idx]; + int64_t slice_index = indices[indices_idx]; // Define the extent of coordinates which will be updated. Coordinate out_start_corner(data_shape.size(), 0); @@ -107,8 +107,11 @@ namespace ngraph auto updates_update_coord_iter = updates_update_transform.begin(); for (const Coordinate& out_cord : out_transform) { - out_buf[out_transform.index(out_cord)] = - updates[updates_update_transform.index(*updates_update_coord_iter)]; + const auto src_idx = + updates_update_transform.index(*updates_update_coord_iter) * elem_size; + std::copy(updates + src_idx, + updates + (src_idx + elem_size), + out_buf + out_transform.index(out_cord) * elem_size); updates_update_coord_iter++; } updates_indices_coord_iter++; diff --git a/ngraph/core/src/op/scatter_update.cpp b/ngraph/core/src/op/scatter_update.cpp index 597edd6b1f55d2..1ebf07e52995ec 100644 --- a/ngraph/core/src/op/scatter_update.cpp +++ b/ngraph/core/src/op/scatter_update.cpp @@ -41,93 +41,6 @@ shared_ptr op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3)); } -namespace -{ - template - bool evaluate(const HostTensorPtr& data, - const HostTensorPtr& indices, - const HostTensorPtr& updates, - const HostTensorPtr& out, - const int64_t normalized_axis) - { - using DataType = typename element_type_traits
::value_type; - using IndicesType = typename element_type_traits::value_type; - - out->set_shape(data->get_shape()); - runtime::reference::scatter_update( - data->get_data_ptr(), - indices->get_data_ptr(), - updates->get_data_ptr(), - normalized_axis, - out->get_data_ptr(), - data->get_shape(), - indices->get_shape(), - updates->get_shape()); - - return true; - } - - template - bool evaluate(const HostTensorPtr& data, - const HostTensorPtr& indices, - const HostTensorPtr& updates, - const HostTensorPtr& out, - const int64_t normalized_axis) - { - // Dispatch specialization based on indicies data type. - bool rc = true; - - switch (indices->get_element_type()) - { - case element::Type_t::i8: - case element::Type_t::u8: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - case element::Type_t::i16: - case element::Type_t::u16: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - case element::Type_t::i32: - case element::Type_t::u32: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - case element::Type_t::i64: - case element::Type_t::u64: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - default: rc = false; break; - } - return rc; - } - - bool evaluate_scatter_update(const HostTensorPtr& data, - const HostTensorPtr& indices, - const HostTensorPtr& updates, - const HostTensorPtr& out, - const int64_t normalized_axis) - { - // Dispatch based on data, updates and output data type. - bool rc = true; - switch (out->get_element_type()) - { - case element::Type_t::i32: - case element::Type_t::u32: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - case element::Type_t::i64: - case element::Type_t::u64: - rc = evaluate(data, indices, updates, out, normalized_axis); - break; - TYPE_CASE(f16)(data, indices, updates, out, normalized_axis); - break; - TYPE_CASE(f32)(data, indices, updates, out, normalized_axis); - break; - default: rc = false; break; - } - return rc; - } -} - bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const { @@ -137,6 +50,9 @@ bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, const auto& axis = inputs[3]; const auto& out = outputs[0]; + const auto elem_size = data->get_element_type().size(); + out->set_shape(data->get_shape()); + int64_t axis_val = 0; switch (axis->get_element_type()) { @@ -151,14 +67,83 @@ bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs, default: throw ngraph_error("axis element type is not integral data type"); } - const auto& input_rank = get_input_partial_shape(0).rank(); - int64_t normalized_axis = axis_val; - - if (normalized_axis < 0) + if (axis_val < 0) { - normalized_axis = + axis_val = ngraph::normalize_axis(this, axis_val, static_cast(data->get_shape().size())); } - return evaluate_scatter_update(data, indices, updates, out, normalized_axis); + std::vector indices_casted_vector; + switch (indices->get_element_type()) + { + case element::Type_t::i8: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::i16: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::i32: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::i64: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::u8: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::u16: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::u32: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + case element::Type_t::u64: + { + auto indices_ptr = indices->get_data_ptr(); + indices_casted_vector = + std::vector(indices_ptr, indices_ptr + indices->get_element_count()); + break; + } + default: throw ngraph_error("indices element type is not integral data type"); + } + + runtime::reference::scatter_update(data->get_data_ptr(), + indices_casted_vector.data(), + updates->get_data_ptr(), + axis_val, + out->get_data_ptr(), + elem_size, + data->get_shape(), + indices->get_shape(), + updates->get_shape()); + + return true; } diff --git a/ngraph/test/eval.cpp b/ngraph/test/eval.cpp index db05ee0e1645d6..792ee3892c2634 100644 --- a/ngraph/test/eval.cpp +++ b/ngraph/test/eval.cpp @@ -1939,7 +1939,7 @@ TEST(eval, reduce_logical_and__neg_axis) ngraph::ngraph_error); } -TEST(eval, evaluate_static_scatter_update_basic) +TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i32) { const Shape data_shape{3, 3}; const Shape indices_shape{1, 2}; @@ -1948,7 +1948,7 @@ TEST(eval, evaluate_static_scatter_update_basic) auto arg1 = make_shared(element::f32, data_shape); auto arg2 = make_shared(element::i32, indices_shape); auto arg3 = make_shared(element::f32, updates_shape); - auto arg4 = make_shared(element::i64, Shape{}); + auto arg4 = make_shared(element::i32, Shape{}); auto scatter_update = make_shared(arg1, arg2, arg3, arg4); auto fun = make_shared(OutputVector{scatter_update}, ParameterVector{arg1, arg2, arg3, arg4}); @@ -1957,6 +1957,34 @@ TEST(eval, evaluate_static_scatter_update_basic) {make_host_tensor( data_shape, std::vector(shape_size(data_shape))), make_host_tensor(indices_shape, {1, 2}), + make_host_tensor( + updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), + make_host_tensor({}, {0})})); + EXPECT_EQ(result_tensor->get_element_type(), element::f32); + EXPECT_EQ(result_tensor->get_shape(), (Shape{3, 3})); + auto cval = read_vector(result_tensor); + vector out{0.f, 0.f, 0.f, 1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}; + ASSERT_EQ(cval, out); +} + +TEST(eval, evaluate_static_scatter_update_basic_axes_indices_i64) +{ + const Shape data_shape{3, 3}; + const Shape indices_shape{1, 2}; + const Shape updates_shape{1, 2, 3}; + + auto arg1 = make_shared(element::f32, data_shape); + auto arg2 = make_shared(element::i64, indices_shape); + auto arg3 = make_shared(element::f32, updates_shape); + auto arg4 = make_shared(element::i64, Shape{}); + auto scatter_update = make_shared(arg1, arg2, arg3, arg4); + auto fun = make_shared(OutputVector{scatter_update}, + ParameterVector{arg1, arg2, arg3, arg4}); + auto result_tensor = make_shared(); + ASSERT_TRUE(fun->evaluate({result_tensor}, + {make_host_tensor( + data_shape, std::vector(shape_size(data_shape))), + make_host_tensor(indices_shape, {1, 2}), make_host_tensor( updates_shape, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}), make_host_tensor({}, {0})})); diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index 65ede84cb54976..387c9a84d27e24 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -79,7 +79,7 @@ #include "ngraph/runtime/reference/reverse_sequence.hpp" #include "ngraph/runtime/reference/round.hpp" #include "ngraph/runtime/reference/scatter_nd_update.hpp" -#include "ngraph/runtime/reference/scatter_update.hpp" +// #include "ngraph/runtime/reference/scatter_update.hpp" #include "ngraph/runtime/reference/select.hpp" #include "ngraph/runtime/reference/sigmoid.hpp" #include "ngraph/runtime/reference/sign.hpp" @@ -1195,48 +1195,6 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ break; } - case OP_TYPEID::ScatterUpdate_v3: - { - const op::v3::ScatterUpdate* scatterUpd = - static_cast(&node); - - if (scatterUpd->get_input_element_type(3) != element::i64) - throw ngraph_error( - "ScatterNDUpdate layer support only i64 'axis' input precision!"); - - auto idxType = scatterUpd->get_input_element_type(1); - if (idxType == element::i32) - { - reference::scatter_update( - args[0]->get_data_ptr(), - args[1]->get_data_ptr(), - args[2]->get_data_ptr(), - *args[3]->get_data_ptr(), - out[0]->get_data_ptr(), - node.get_input_shape(0), - node.get_input_shape(1), - node.get_input_shape(2)); - } - else if (idxType == element::i64) - { - reference::scatter_update( - args[0]->get_data_ptr(), - args[1]->get_data_ptr(), - args[2]->get_data_ptr(), - *args[3]->get_data_ptr(), - out[0]->get_data_ptr(), - node.get_input_shape(0), - node.get_input_shape(1), - node.get_input_shape(2)); - } - else - { - throw ngraph_error( - "ScatterUpdate layer support only i32 and i64 'indices' input precision!"); - } - - break; - } // Fused Ops are not supported in interpreter. They need to be decomposed before execution case OP_TYPEID::DepthToSpace: @@ -1255,6 +1213,7 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ case OP_TYPEID::NormalizeL2: case OP_TYPEID::PRelu: case OP_TYPEID::RNNCell: + case OP_TYPEID::ScatterUpdate_v3: case OP_TYPEID::Selu: case OP_TYPEID::ShuffleChannels: case OP_TYPEID::SpaceToDepth: From 6663b9823afd9715abc11e990fb40aac01db7604 Mon Sep 17 00:00:00 2001 From: mitruska Date: Mon, 24 Aug 2020 17:59:12 +0200 Subject: [PATCH 6/6] Apply review requests --- .../include/ngraph/runtime/reference/scatter_update.hpp | 4 +--- ngraph/test/runtime/interpreter/int_executable.hpp | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp index 18b645ef4c2fa1..f8d00b17ebfe50 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp @@ -16,8 +16,6 @@ #pragma once -#include - #include "ngraph/check.hpp" #include "ngraph/coordinate_transform.hpp" #include "ngraph/shape.hpp" @@ -31,7 +29,7 @@ namespace ngraph void scatter_update(const char* input_data, const int64_t* indices, const char* updates, - const int64_t& axis, + const int64_t axis, char* out_buf, const size_t elem_size, const Shape& data_shape, diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index 387c9a84d27e24..8a07add52bb14e 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -79,7 +79,6 @@ #include "ngraph/runtime/reference/reverse_sequence.hpp" #include "ngraph/runtime/reference/round.hpp" #include "ngraph/runtime/reference/scatter_nd_update.hpp" -// #include "ngraph/runtime/reference/scatter_update.hpp" #include "ngraph/runtime/reference/select.hpp" #include "ngraph/runtime/reference/sigmoid.hpp" #include "ngraph/runtime/reference/sign.hpp"