Skip to content

Commit

Permalink
Reference implementation for ScatterUpdate (#1678)
Browse files Browse the repository at this point in the history
* Reference implementation for ScatterUpdate and use of it in evaluate.

* Review comments. Clarify comments.

* Update file directory.

* Replace scatter_update reference implementation in ngraph/core/reference/

* Remove template code from ScatterUpdate reference implementation

* Apply review requests

Co-authored-by: mitruska <[email protected]>
  • Loading branch information
arogowie-intel and mitruska authored Aug 25, 2020
1 parent db2e5c0 commit 393e929
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 103 deletions.
4 changes: 4 additions & 0 deletions ngraph/core/include/ngraph/op/scatter_update.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/scatter_base.hpp"
#include "ngraph/runtime/host_tensor.hpp"

namespace ngraph
{
Expand Down Expand Up @@ -49,6 +50,9 @@ namespace ngraph

virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& inputs) const override;

bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,86 +1,120 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <string>
#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"

using namespace ngraph;

namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename dataType, typename indicesType, typename axisType>
void scatterUpdate(const dataType* inputData,
const indicesType* indices,
const dataType* updates,
const axisType* _axis,
dataType* outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape)
void scatter_update(const char* input_data,
const int64_t* indices,
const char* updates,
const int64_t axis,
char* out_buf,
const size_t elem_size,
const Shape& data_shape,
const Shape& indices_shape,
const Shape& updates_shape)
{
int rank = static_cast<int>(dataShape.size());
if (_axis[0] < -rank || _axis[0] > rank - 1)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds axis value: ") +
std::to_string(_axis[0]);
throw ngraph_error(error);
}
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
CoordinateTransform indicesTransform{indicesShape};
// Copy inputs to out
std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape));

// Algorithm overview
// data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
// where first ... in the data corresponds to first axis dimensions,
// last ... in the data corresponds to the rank(data) - (axis + 1) dimensions.

//
// for i_coord in indices[m, n, ..., p]:
// # get linear index
// i_idx = index(i_coord)
// # simultaneously iterate over two slices of data with same elements count
// for d_coord in slice data[..., i_idx, ...],
// u_coord in slice updates[..., i_coord, ...]
// data[index(d_coord)] = updates[index(u_coord)]

Shape dataShapeIter = dataShape;
dataShapeIter.erase(dataShapeIter.begin() + axis);
CoordinateTransform dataTransfIter{dataShapeIter};
CoordinateTransform indices_transform{indices_shape};
CoordinateTransform data_transform{data_shape};

CoordinateTransform updateTransform{updatesShape};
CoordinateTransform dataTransform{dataShape};
size_t indices_ndim = indices_shape.size();
size_t updates_ndim = updates_shape.size();

std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
// Create an outer CoordinateTransform for "update", which would allow to
// iterate only over "indices" dimensions:
// set to "1" all non-indices dimensions
// updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1]
Coordinate updates_indices_start_corner(updates_ndim, 0);
Coordinate updates_indices_end_corner(updates_ndim, 1);
for (size_t i = 0; i < indices_ndim; ++i)
{
updates_indices_end_corner[axis + i] = updates_shape[axis + i];
}
CoordinateTransform updates_indices_transform(
updates_shape, updates_indices_start_corner, updates_indices_end_corner);
// Is needed to simultaneously iterate over updates coordinates while
// iterating over indices.
auto updates_indices_coord_iter = updates_indices_transform.begin();

for (const Coordinate& indicesCoordIt : indicesTransform)
for (const Coordinate& indices_cord : indices_transform)
{
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
const size_t indices_idx = indices_transform.index(indices_cord);
int64_t slice_index = indices[indices_idx];

if (indices[indicesIdx] < 0)
{
std::string error =
std::string("ScatterUpdate layer has negative index value: ") +
std::to_string(indices[indicesIdx]);
throw ngraph_error(error);
}
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
if (dataShape[axis] <= idx)
// Define the extent of coordinates which will be updated.
Coordinate out_start_corner(data_shape.size(), 0);
Coordinate out_end_corner(data_shape);
out_start_corner[axis] = static_cast<size_t>(slice_index);
out_end_corner[axis] = out_start_corner[axis] + 1;
CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner);

// Define the CoordinateTransform for updates coordinates.
// All except indices-dimensions.
Coordinate updates_update_start_corner = *updates_indices_coord_iter;
Coordinate updates_update_end_corner(updates_shape);
for (size_t i = 0; i < indices_ndim; ++i)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
"th axis";
throw ngraph_error(error);
updates_update_end_corner[axis + i] =
updates_update_start_corner[axis + i] + 1;
}

for (const Coordinate& dataCoordIt : dataTransfIter)
// The m, n, .., p symbols stand for values at those axes.
// The m+1 means value at axis m plus 1.
// udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0]
// updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1]
CoordinateTransform updates_update_transform(
updates_shape, updates_update_start_corner, updates_update_end_corner);
auto updates_update_coord_iter = updates_update_transform.begin();
for (const Coordinate& out_cord : out_transform)
{
Coordinate dataCoord = dataCoordIt;
dataCoord.insert(dataCoord.begin() + axis, idx);
const size_t startIndices = dataTransform.index(dataCoord);

auto updCoord = dataCoordIt;
updCoord.insert(
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
const size_t startUpd = updateTransform.index(updCoord);
outBuf[startIndices] = updates[startUpd];
const auto src_idx =
updates_update_transform.index(*updates_update_coord_iter) * elem_size;
std::copy(updates + src_idx,
updates + (src_idx + elem_size),
out_buf + out_transform.index(out_cord) * elem_size);
updates_update_coord_iter++;
}
updates_indices_coord_iter++;
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
}
}
}
111 changes: 111 additions & 0 deletions ngraph/core/src/op/scatter_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
//*****************************************************************************

#include "ngraph/op/scatter_update.hpp"
#include "ngraph/runtime/reference/scatter_update.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/validation_util.hpp"

using namespace std;
using namespace ngraph;
Expand All @@ -36,3 +40,110 @@ shared_ptr<Node> op::v3::ScatterUpdate::clone_with_new_inputs(const OutputVector
return make_shared<v3::ScatterUpdate>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}

bool op::v3::ScatterUpdate::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
const auto& data = inputs[0];
const auto& indices = inputs[1];
const auto& updates = inputs[2];
const auto& axis = inputs[3];
const auto& out = outputs[0];

const auto elem_size = data->get_element_type().size();
out->set_shape(data->get_shape());

int64_t axis_val = 0;
switch (axis->get_element_type())
{
case element::Type_t::i8: axis_val = axis->get_data_ptr<element::Type_t::i8>()[0]; break;
case element::Type_t::i16: axis_val = axis->get_data_ptr<element::Type_t::i16>()[0]; break;
case element::Type_t::i32: axis_val = axis->get_data_ptr<element::Type_t::i32>()[0]; break;
case element::Type_t::i64: axis_val = axis->get_data_ptr<element::Type_t::i64>()[0]; break;
case element::Type_t::u8: axis_val = axis->get_data_ptr<element::Type_t::u8>()[0]; break;
case element::Type_t::u16: axis_val = axis->get_data_ptr<element::Type_t::u16>()[0]; break;
case element::Type_t::u32: axis_val = axis->get_data_ptr<element::Type_t::u32>()[0]; break;
case element::Type_t::u64: axis_val = axis->get_data_ptr<element::Type_t::u64>()[0]; break;
default: throw ngraph_error("axis element type is not integral data type");
}

if (axis_val < 0)
{
axis_val =
ngraph::normalize_axis(this, axis_val, static_cast<int64_t>(data->get_shape().size()));
}

std::vector<int64_t> indices_casted_vector;
switch (indices->get_element_type())
{
case element::Type_t::i8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::i64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::i64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u8:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u8>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u16:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u16>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u32:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u32>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
case element::Type_t::u64:
{
auto indices_ptr = indices->get_data_ptr<element::Type_t::u64>();
indices_casted_vector =
std::vector<int64_t>(indices_ptr, indices_ptr + indices->get_element_count());
break;
}
default: throw ngraph_error("indices element type is not integral data type");
}

runtime::reference::scatter_update(data->get_data_ptr<char>(),
indices_casted_vector.data(),
updates->get_data_ptr<char>(),
axis_val,
out->get_data_ptr<char>(),
elem_size,
data->get_shape(),
indices->get_shape(),
updates->get_shape());

return true;
}
Loading

0 comments on commit 393e929

Please sign in to comment.