Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference implementation for ScatterUpdate #1678

Merged
4 changes: 4 additions & 0 deletions ngraph/core/include/ngraph/op/scatter_update.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/scatter_base.hpp"
#include "ngraph/runtime/host_tensor.hpp"

namespace ngraph
{
Expand Down Expand Up @@ -49,6 +50,9 @@ namespace ngraph

virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;

bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,86 +1,122 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <string>
#include <cstring>
mitruska marked this conversation as resolved.
Show resolved Hide resolved

#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"

using namespace ngraph;

namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename dataType, typename indicesType, typename axisType>
void scatterUpdate(const dataType* inputData,
const indicesType* indices,
const dataType* updates,
const axisType* _axis,
dataType* outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape)
void scatter_update(const char* input_data,
const int64_t* indices,
const char* updates,
const int64_t& axis,
mitruska marked this conversation as resolved.
Show resolved Hide resolved
char* out_buf,
const size_t elem_size,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape)
{
int rank = static_cast<int>(dataShape.size());
if (_axis[0] < -rank || _axis[0] > rank - 1)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds axis value: ") +
std::to_string(_axis[0]);
throw ngraph_error(error);
}
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
CoordinateTransform indicesTransform{indicesShape};
// Copy inputs to out
std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));

Shape dataShapeIter = dataShape;
dataShapeIter.erase(dataShapeIter.begin() + axis);
CoordinateTransform dataTransfIter{dataShapeIter};
// Algorithm overview
// data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
// where first ... in the data corresponds to first axis dimensions,
// last ... in the data corresponds to the rank(data) - (axis + 1) dimensions.

CoordinateTransform updateTransform{updatesShape};
CoordinateTransform dataTransform{dataShape};
//
// for i_coord in indices[m, n, ..., p]:
// # get linear index
// i_idx = index(i_coord)
// # simultaneously iterate over two slices of data with same elements count
// for d_coord in slice data[..., i_idx, ...],
// u_coord in slice updates[..., i_coord, ...]
// data[index(d_coord)] = updates[index(u_coord)]

std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
CoordinateTransform indices_transform{indices_shape};
CoordinateTransform data_transform{data_shape};

for (const Coordinate& indicesCoordIt : indicesTransform)
size_t indices_ndim = indices_shape.size();
size_t updates_ndim = updates_shape.size();

// Create an outer CoordinateTransform for "update", which would allow to
// iterate only over "indices" dimensions:
// set to "1" all non-indices dimensions
// updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1]
Coordinate updates_indices_start_corner(updates_ndim, 0);
Coordinate updates_indices_end_corner(updates_ndim, 1);
for (size_t i = 0; i < indices_ndim; ++i)
{
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
updates_indices_end_corner[axis + i] = updates_shape[axis + i];
}
CoordinateTransform updates_indices_transform(
updates_shape, updates_indices_start_corner, updates_indices_end_corner);
// Is needed to simultaneously iterate over updates coordinates while
// iterating over indices.
auto updates_indices_coord_iter = updates_indices_transform.begin();

if (indices[indicesIdx] < 0)
{
std::string error =
std::string("ScatterUpdate layer has negative index value: ") +
std::to_string(indices[indicesIdx]);
throw ngraph_error(error);
}
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
if (dataShape[axis] <= idx)
for (const Coordinate& indices_cord : indices_transform)
{
const size_t indices_idx = indices_transform.index(indices_cord);
int64_t slice_index = indices[indices_idx];

// Define the extent of coordinates which will be updated.
Coordinate out_start_corner(data_shape.size(), 0);
Coordinate out_end_corner(data_shape);
out_start_corner[axis] = static_cast<size_t>(slice_index);
out_end_corner[axis] = out_start_corner[axis] + 1;
CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);

// Define the CoordinateTransform for updates coordinates.
// All except indices-dimensions.
Coordinate updates_update_start_corner = *updates_indices_coord_iter;
Coordinate updates_update_end_corner(updates_shape);
for (size_t i = 0; i < indices_ndim; ++i)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
"th axis";
throw ngraph_error(error);
updates_update_end_corner[axis + i] =
updates_update_start_corner[axis + i] + 1;
}

for (const Coordinate& dataCoordIt : dataTransfIter)
// The m, n, .., p symbols stand for values at those axes.
// The m+1 means value at axis m plus 1.
// udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0]
// updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
CoordinateTransform updates_update_transform(
updates_shape, updates_update_start_corner, updates_update_end_corner);
auto updates_update_coord_iter = updates_update_transform.begin();
for (const Coordinate& out_cord : out_transform)
{
Coordinate dataCoord = dataCoordIt;
dataCoord.insert(dataCoord.begin() + axis, idx);
const size_t startIndices = dataTransform.index(dataCoord);

auto updCoord = dataCoordIt;
updCoord.insert(
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
const size_t startUpd = updateTransform.index(updCoord);
outBuf[startIndices] = updates[startUpd];
const auto src_idx =
updates_update_transform.index(*updates_update_coord_iter) * elem_size;
std::copy(updates + src_idx,
updates + (src_idx + elem_size),
out_buf + out_transform.index(out_cord) * elem_size);
updates_update_coord_iter++;
}
updates_indices_coord_iter++;
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
}
}
}
111 changes: 111 additions & 0 deletions ngraph/core/src/op/scatter_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
//*****************************************************************************

#include "ngraph/op/scatter_update.hpp"
#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/validation_util.hpp"

using namespace std;
using namespace ngraph;
Expand All @@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
return make_shared<v3::ScatterUpdate>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}

bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
const auto& axis = inputs[3];
const auto& out = outputs[0];

const auto elem_size = data->get_element_type().size();
out->set_shape(data->get_shape());

int64_t axis_val = 0;
switch (axis->get_element_type())
{
case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis element type is not integral data type");
}

if (axis_val < 0)
{
axis_val =
ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
}

std::vector<int64_t> indices_casted_vector;
switch (indices->get_element_type())
{
case element::Type_t::i8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
default: throw ngraph_error("indices element type is not integral data type");
}

runtime::reference::scatter_update(data->get_data_ptr<char>(),
indices_casted_vector.data(),
updates->get_data_ptr<char>(),
axis_val,
out->get_data_ptr<char>(),
elem_size,
data->get_shape(),
indices->get_shape(),
updates->get_shape());

return true;
}
Loading