forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reference implementation for ScatterUpdate (openvinotoolkit#1678)
* Reference implementation for ScatterUpdate and use of it in evaluate. * Review comments. Clarify comments. * Update file directory. * Replace scatter_update reference implementation in ngraph/core/reference/ * Remove template code from ScatterUpdate reference implementation * Apply review requests Co-authored-by: mitruska <[email protected]>
- Loading branch information
1 parent
7888ef7
commit 1a5d687
Showing
5 changed files
with
388 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 94 additions & 60 deletions
154
ngraph/core/reference/include/ngraph/runtime/reference/scatter_update.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,120 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
//***************************************************************************** | ||
// Copyright 2017-2020 Intel Corporation | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
//***************************************************************************** | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include "ngraph/check.hpp" | ||
#include "ngraph/coordinate_transform.hpp" | ||
#include "ngraph/shape.hpp" | ||
|
||
using namespace ngraph; | ||
|
||
namespace ngraph | ||
{ | ||
namespace runtime | ||
{ | ||
namespace reference | ||
{ | ||
template <typename dataType, typename indicesType, typename axisType> | ||
void scatterUpdate(const dataType* inputData, | ||
const indicesType* indices, | ||
const dataType* updates, | ||
const axisType* _axis, | ||
dataType* outBuf, | ||
const Shape& dataShape, | ||
const Shape& indicesShape, | ||
const Shape& updatesShape) | ||
void scatter_update(const char* input_data, | ||
const int64_t* indices, | ||
const char* updates, | ||
const int64_t axis, | ||
char* out_buf, | ||
const size_t elem_size, | ||
const Shape& data_shape, | ||
const Shape& indices_shape, | ||
const Shape& updates_shape) | ||
{ | ||
int rank = static_cast<int>(dataShape.size()); | ||
if (_axis[0] < -rank || _axis[0] > rank - 1) | ||
{ | ||
std::string error = | ||
std::string("ScatterUpdate layer has out of bounds axis value: ") + | ||
std::to_string(_axis[0]); | ||
throw ngraph_error(error); | ||
} | ||
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0]; | ||
CoordinateTransform indicesTransform{indicesShape}; | ||
// Copy inputs to out | ||
std::memcpy(out_buf, input_data, elem_size * shape_size(data_shape)); | ||
|
||
// Algorithm overview | ||
// data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...] | ||
// where first ... in the data corresponds to first axis dimensions, | ||
// last ... in the data corresponds to the rank(data) - (axis + 1) dimensions. | ||
|
||
// | ||
// for i_coord in indices[m, n, ..., p]: | ||
// # get linear index | ||
// i_idx = index(i_coord) | ||
// # simultaneously iterate over two slices of data with same elements count | ||
// for d_coord in slice data[..., i_idx, ...], | ||
// u_coord in slice updates[..., i_coord, ...] | ||
// data[index(d_coord)] = updates[index(u_coord)] | ||
|
||
Shape dataShapeIter = dataShape; | ||
dataShapeIter.erase(dataShapeIter.begin() + axis); | ||
CoordinateTransform dataTransfIter{dataShapeIter}; | ||
CoordinateTransform indices_transform{indices_shape}; | ||
CoordinateTransform data_transform{data_shape}; | ||
|
||
CoordinateTransform updateTransform{updatesShape}; | ||
CoordinateTransform dataTransform{dataShape}; | ||
size_t indices_ndim = indices_shape.size(); | ||
size_t updates_ndim = updates_shape.size(); | ||
|
||
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape)); | ||
// Create an outer CoordinateTransform for "update", which would allow to | ||
// iterate only over "indices" dimensions: | ||
// set to "1" all non-indices dimensions | ||
// updates[1, ..., 1, m, n, ..., p, 1, 1,..., 1] | ||
Coordinate updates_indices_start_corner(updates_ndim, 0); | ||
Coordinate updates_indices_end_corner(updates_ndim, 1); | ||
for (size_t i = 0; i < indices_ndim; ++i) | ||
{ | ||
updates_indices_end_corner[axis + i] = updates_shape[axis + i]; | ||
} | ||
CoordinateTransform updates_indices_transform( | ||
updates_shape, updates_indices_start_corner, updates_indices_end_corner); | ||
// Is needed to simultaneously iterate over updates coordinates while | ||
// iterating over indices. | ||
auto updates_indices_coord_iter = updates_indices_transform.begin(); | ||
|
||
for (const Coordinate& indicesCoordIt : indicesTransform) | ||
for (const Coordinate& indices_cord : indices_transform) | ||
{ | ||
const size_t indicesIdx = indicesTransform.index(indicesCoordIt); | ||
const size_t indices_idx = indices_transform.index(indices_cord); | ||
int64_t slice_index = indices[indices_idx]; | ||
|
||
if (indices[indicesIdx] < 0) | ||
{ | ||
std::string error = | ||
std::string("ScatterUpdate layer has negative index value: ") + | ||
std::to_string(indices[indicesIdx]); | ||
throw ngraph_error(error); | ||
} | ||
const size_t idx = static_cast<size_t>(indices[indicesIdx]); | ||
if (dataShape[axis] <= idx) | ||
// Define the extent of coordinates which will be updated. | ||
Coordinate out_start_corner(data_shape.size(), 0); | ||
Coordinate out_end_corner(data_shape); | ||
out_start_corner[axis] = static_cast<size_t>(slice_index); | ||
out_end_corner[axis] = out_start_corner[axis] + 1; | ||
CoordinateTransform out_transform(data_shape, out_start_corner, out_end_corner); | ||
|
||
// Define the CoordinateTransform for updates coordinates. | ||
// All except indices-dimensions. | ||
Coordinate updates_update_start_corner = *updates_indices_coord_iter; | ||
Coordinate updates_update_end_corner(updates_shape); | ||
for (size_t i = 0; i < indices_ndim; ++i) | ||
{ | ||
std::string error = | ||
std::string("ScatterUpdate layer has out of bounds coordinate: ") + | ||
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) + | ||
"th axis"; | ||
throw ngraph_error(error); | ||
updates_update_end_corner[axis + i] = | ||
updates_update_start_corner[axis + i] + 1; | ||
} | ||
|
||
for (const Coordinate& dataCoordIt : dataTransfIter) | ||
// The m, n, .., p symbols stand for values at those axes. | ||
// The m+1 means value at axis m plus 1. | ||
// udpates_shape (start): [ 0, ..., 0, m , n , ... p , 0, ..., 0] | ||
// updates_shape (end): [-1, ..., -1, m+1, n+1, ... p+1, -1, ..., -1] | ||
CoordinateTransform updates_update_transform( | ||
updates_shape, updates_update_start_corner, updates_update_end_corner); | ||
auto updates_update_coord_iter = updates_update_transform.begin(); | ||
for (const Coordinate& out_cord : out_transform) | ||
{ | ||
Coordinate dataCoord = dataCoordIt; | ||
dataCoord.insert(dataCoord.begin() + axis, idx); | ||
const size_t startIndices = dataTransform.index(dataCoord); | ||
|
||
auto updCoord = dataCoordIt; | ||
updCoord.insert( | ||
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end()); | ||
const size_t startUpd = updateTransform.index(updCoord); | ||
outBuf[startIndices] = updates[startUpd]; | ||
const auto src_idx = | ||
updates_update_transform.index(*updates_update_coord_iter) * elem_size; | ||
std::copy(updates + src_idx, | ||
updates + (src_idx + elem_size), | ||
out_buf + out_transform.index(out_cord) * elem_size); | ||
updates_update_coord_iter++; | ||
} | ||
updates_indices_coord_iter++; | ||
} | ||
} | ||
} // namespace reference | ||
} // namespace runtime | ||
} // namespace ngraph | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.