From fbb575875610e76fcb7c79938d326dbd4781db06 Mon Sep 17 00:00:00 2001 From: Maxim Andronov Date: Fri, 7 Aug 2020 16:09:28 +0300 Subject: [PATCH] [NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494) --- .../src/mkldnn_plugin/CMakeLists.txt | 1 - .../src/mkldnn_plugin/nodes/list_tbl.hpp | 1 - .../src/mkldnn_plugin/nodes/scatter.cpp | 188 ------------------ .../single_layer_tests/scatter_ND_update.cpp | 5 +- .../single_layer_tests/scatter_update.cpp | 5 +- .../include/ngraph_functions/builders.hpp | 1 + .../src/scatter_ND_update.cpp | 6 +- .../runtime/interpreter/int_executable.hpp | 77 +++++++ .../runtime/interpreter/opset_int_tbl.hpp | 2 + .../reference/scatter_nd_update.hpp | 63 ++++++ .../interpreter/reference/scatter_update.hpp | 86 ++++++++ 11 files changed, 235 insertions(+), 200 deletions(-) delete mode 100644 inference-engine/src/mkldnn_plugin/nodes/scatter.cpp create mode 100644 ngraph/test/runtime/interpreter/reference/scatter_nd_update.hpp create mode 100644 ngraph/test/runtime/interpreter/reference/scatter_update.hpp diff --git a/inference-engine/src/mkldnn_plugin/CMakeLists.txt b/inference-engine/src/mkldnn_plugin/CMakeLists.txt index 9e24a738d1466d..b1122bb8c5a3a1 100644 --- a/inference-engine/src/mkldnn_plugin/CMakeLists.txt +++ b/inference-engine/src/mkldnn_plugin/CMakeLists.txt @@ -64,7 +64,6 @@ set(LAYERS ${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/nodes/scatter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp diff --git a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp index 1e57b9c7fb3418..2015c4cbf7dfc6 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp +++ b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp @@ -51,7 +51,6 @@ MKLDNN_EXTENSION_NODE(FillImpl, Fill); MKLDNN_EXTENSION_NODE(UniqueImpl, Unique); MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling); MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace); -MKLDNN_EXTENSION_NODE(ScatterImpl, ScatterUpdate); MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot); MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast); MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum); diff --git a/inference-engine/src/mkldnn_plugin/nodes/scatter.cpp b/inference-engine/src/mkldnn_plugin/nodes/scatter.cpp deleted file mode 100644 index 1a4a3edb928523..00000000000000 --- a/inference-engine/src/mkldnn_plugin/nodes/scatter.cpp +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (C) 2018-2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "base.hpp" - -#include -#include -#include -#include -#include -#include -#include "ie_parallel.hpp" -#include "common/simple_copy.h" - -namespace InferenceEngine { -namespace Extensions { -namespace Cpu { - -class ScatterImpl: public ExtLayerBase { -public: - explicit ScatterImpl(const CNNLayer* layer) { - try { - if (layer->insData.size() != 3 || layer->outData.size() != 1) - THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output tensors!"; - - - Precision inIdxPrecision = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getPrecision(); - if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32) - THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indexes' precision. Only FP32 or I32 are supported!"; - - Precision inDataPrecision = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getPrecision(); - if (inDataPrecision != layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getPrecision()) - THROW_IE_EXCEPTION << layer->name << " Precision should be equal for input tensors 'Data' and 'Updates'"; - - // Remove redundant dimensions - const SizeVector& data_dims = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getDims(); - if (data_dims.size() == 0 || - (data_dims.size() == 1 && data_dims[0] == 1) || - layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout() == Layout::SCALAR) - THROW_IE_EXCEPTION << layer->name << " 'Data' tensor rank should be >= 1"; - - axis = layer->GetParamAsInt("axis", 0); - - IE_ASSERT(-static_cast(data_dims.size()) <= axis && axis < static_cast(data_dims.size())) - << layer->name << " Incorrect input parameters dimensions and axis number!"; - - if (axis < 0) - axis += data_dims.size(); - - SizeVector dst_dims = layer->outData[0]->getTensorDesc().getDims(); - if (data_dims != dst_dims) - THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output dimensions!"; - - SizeVector idx_dims = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getDims(); - if (idx_dims.size() == 0 || - (idx_dims.size() == 1 && idx_dims[0] == 1) || - layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout() == Layout::SCALAR) - THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1"; - - SizeVector upd_dims = layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getDims(); - if (layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout() == Layout::SCALAR) - THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1"; - - if (idx_dims != upd_dims) - THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' and 'updates' tensors dimension"; - - for (size_t i = 0; i < idx_dims.size(); i++) { - if (i == static_cast(axis)) continue; - if (idx_dims[i] > data_dims[i]) - THROW_IE_EXCEPTION << layer->name << " Incorrect number of data and indexes dimensions!"; - } - - LayerConfig config; - DataConfig dataConfig, indexesConfig, updatesConfig; - Precision dataPrecision = layer->outData[0]->getTensorDesc().getPrecision(); - dataConfig.desc = TensorDesc(dataPrecision, data_dims, - layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout()); - dataConfig.constant = false; - dataConfig.inPlace = 0; - config.inConfs.push_back(dataConfig); - indexesConfig.desc = TensorDesc(inIdxPrecision, idx_dims, - layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout()); - config.inConfs.push_back(indexesConfig); - updatesConfig.desc = TensorDesc(dataPrecision, upd_dims, - layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout()); - config.inConfs.push_back(updatesConfig); - - DataConfig outConfig; - outConfig.desc = TensorDesc(dataPrecision, dst_dims, layer->outData[0]->getTensorDesc().getLayout()); - outConfig.constant = false; - outConfig.inPlace = 0; - config.outConfs.push_back(outConfig); - config.dynBatchSupport = false; - confs.push_back(config); - } catch (InferenceEngine::details::InferenceEngineException &ex) { - errorMsg = ex.what(); - } - } - - StatusCode execute(std::vector& inputs, std::vector& outputs, ResponseDesc *resp) noexcept override { - switch (inputs[SCATTER_INDEXES]->getTensorDesc().getPrecision()) { - case Precision::FP32: - scatter(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]); - break; - case Precision::I32: - scatter(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]); - break; - default: - return GENERAL_ERROR; - } - - return OK; - } - -private: - template - void scatter(Blob::Ptr data, Blob::Ptr indexes, Blob::Ptr updates, Blob::Ptr output) { - const uint8_t *src_data = data->cbuffer().as() + data->getTensorDesc().getBlockingDesc().getOffsetPadding(); - const index_t *src_index = indexes->cbuffer().as() + indexes->getTensorDesc().getBlockingDesc().getOffsetPadding(); - const uint8_t *src_updates = updates->cbuffer().as() + updates->getTensorDesc().getBlockingDesc().getOffsetPadding(); - uint8_t *dst_data = output->cbuffer().as() + output->getTensorDesc().getBlockingDesc().getOffsetPadding(); - size_t data_size = data->getTensorDesc().getPrecision().size(); - - InferenceEngine::SizeVector index_dims = indexes->getTensorDesc().getDims(); - InferenceEngine::SizeVector data_dims = data->getTensorDesc().getDims(); - InferenceEngine::SizeVector dataStrides = data->getTensorDesc().getBlockingDesc().getStrides(); - - if (src_data != dst_data) { - parallel_nt(0, [&](const int ithr, const int nthr) { - size_t start = 0, end = 0; - splitter(output->size(), nthr, ithr, start, end); - size_t size = (end - start) * data_size; - start *= data_size; - simple_copy(dst_data + start, size, src_data + start, size); - }); - } - - parallel_nt(0, [&](const int ithr, const int nthr) { - int j; - size_t i, dst_idx = 0, start = 0, end = 0; - SizeVector counters(index_dims.size(), 0); - splitter(indexes->size(), nthr, ithr, start, end); - for (j = index_dims.size() - 1, i = start; j >= 0; j--) { - counters[j] = i % index_dims[j]; - i /= index_dims[j]; - } - - for (i = 0; i < static_cast(axis); ++i) - dst_idx += counters[i] * dataStrides[i]; - for (i++; i < data_dims.size(); ++i) - dst_idx += counters[i] * dataStrides[i]; - - for (size_t iwork = start; iwork < end; iwork++) { - unsigned int idx = static_cast(src_index[iwork]); - if (idx < data_dims[axis]) - simple_copy(dst_data + data_size * (dst_idx + idx * dataStrides[axis]), data_size, - src_updates + iwork * data_size, data_size); - - for (j = index_dims.size() - 1; j >= 0; j--) { - counters[j]++; - if (counters[j] < index_dims[j]) { - if (j != static_cast(axis)) - dst_idx += dataStrides[j]; - break; - } else { - counters[j] = 0; - for (dst_idx = 0, i = 0; i < static_cast(axis); ++i) - dst_idx += counters[i] * dataStrides[i]; - for (i++; i < data_dims.size(); ++i) - dst_idx += counters[i] * dataStrides[i]; - } - } - } - }); - } - - int axis = 0; - const size_t SCATTER_DATA = 0; - const size_t SCATTER_INDEXES = 1; - const size_t SCATTER_UPDATES = 2; -}; - -REG_FACTORY_FOR(ScatterImpl, ScatterUpdate); - -} // namespace Cpu -} // namespace Extensions -} // namespace InferenceEngine diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_ND_update.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_ND_update.cpp index a571cb6f867b22..329a5dc5b1c071 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_ND_update.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_ND_update.cpp @@ -36,7 +36,6 @@ const auto ScatterNDUpdateCases = ::testing::Combine( ::testing::Values(CommonTestUtils::DEVICE_CPU) ); -// open after ops support in ngraph merged -// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName); -} // namespace \ No newline at end of file +} // namespace diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_update.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_update.cpp index f915fef1325238..dcb22e58c1430a 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_update.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/scatter_update.cpp @@ -41,7 +41,6 @@ const auto ScatterUpdateCase = ::testing::Combine( ::testing::Values(CommonTestUtils::DEVICE_CPU) ); -// open after ngraph reference implementation merged -// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName); -} // namespace \ No newline at end of file +} // namespace diff --git a/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp b/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp index f96c6c54932683..8c359a56f94ba3 100644 --- a/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp +++ b/inference-engine/tests/ngraph_functions/include/ngraph_functions/builders.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "ngraph_functions/utils/data_utils.hpp" diff --git a/inference-engine/tests/ngraph_functions/src/scatter_ND_update.cpp b/inference-engine/tests/ngraph_functions/src/scatter_ND_update.cpp index 547d5008d6be9e..5928fd3ddcbd88 100644 --- a/inference-engine/tests/ngraph_functions/src/scatter_ND_update.cpp +++ b/inference-engine/tests/ngraph_functions/src/scatter_ND_update.cpp @@ -13,10 +13,8 @@ std::shared_ptr makeScatterNDUpdate(const ngraph::Output &in const std::vector& indices, const ngraph::Output &update) { auto indicesNode = std::make_shared(indicesType, indicesShape, indices); - // blocked by ngraph merge - // auto dtsNode = std::make_shared(in, indicesNode, update); - // return dtsNode; - return nullptr; + auto dtsNode = std::make_shared(in, indicesNode, update); + return dtsNode; } } // namespace builder diff --git a/ngraph/test/runtime/interpreter/int_executable.hpp b/ngraph/test/runtime/interpreter/int_executable.hpp index f0bb110bc9fd69..cb7abc45dfdbe1 100644 --- a/ngraph/test/runtime/interpreter/int_executable.hpp +++ b/ngraph/test/runtime/interpreter/int_executable.hpp @@ -93,6 +93,8 @@ #include "op/group_conv.hpp" #include "reference/detection_output.hpp" +#include "reference/scatter_nd_update.hpp" +#include "reference/scatter_update.hpp" namespace ngraph { @@ -1144,6 +1146,81 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ break; } + case OP_TYPEID::ScatterNDUpdate_v3: + { + const op::ScatterNDUpdate* scatterNDUpd = + static_cast(&node); + auto idxType = scatterNDUpd->get_input_element_type(1); + if (idxType == element::i32) + { + reference::scatterNdUpdate(args[0]->get_data_ptr(), + args[1]->get_data_ptr(), + args[2]->get_data_ptr(), + out[0]->get_data_ptr(), + node.get_input_shape(0), + node.get_input_shape(1), + node.get_input_shape(2)); + } + else if (idxType == element::i64) + { + reference::scatterNdUpdate(args[0]->get_data_ptr(), + args[1]->get_data_ptr(), + args[2]->get_data_ptr(), + out[0]->get_data_ptr(), + node.get_input_shape(0), + node.get_input_shape(1), + node.get_input_shape(2)); + } + else + { + throw ngraph_error( + "ScatterNDUpdate layer support only i32 and i64 'indices' input precision!"); + } + + break; + } + case OP_TYPEID::ScatterUpdate_v3: + { + const op::v3::ScatterUpdate* scatterUpd = + static_cast(&node); + + if (scatterUpd->get_input_element_type(3) != element::i64) + throw ngraph_error( + "ScatterNDUpdate layer support only i64 'axis' input precision!"); + + auto idxType = scatterUpd->get_input_element_type(1); + if (idxType == element::i32) + { + reference::scatterUpdate( + args[0]->get_data_ptr(), + args[1]->get_data_ptr(), + args[2]->get_data_ptr(), + args[3]->get_data_ptr(), + out[0]->get_data_ptr(), + node.get_input_shape(0), + node.get_input_shape(1), + node.get_input_shape(2)); + } + else if (idxType == element::i64) + { + reference::scatterUpdate( + args[0]->get_data_ptr(), + args[1]->get_data_ptr(), + args[2]->get_data_ptr(), + args[3]->get_data_ptr(), + out[0]->get_data_ptr(), + node.get_input_shape(0), + node.get_input_shape(1), + node.get_input_shape(2)); + } + else + { + throw ngraph_error( + "ScatterUpdate layer support only i32 and i64 'indices' input precision!"); + } + + break; + } // Fused Ops are not supported in interpreter. They need to be decomposed before execution case OP_TYPEID::DepthToSpace: diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 9677db2fab4827..90b5e74f3c9906 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -37,4 +37,6 @@ NGRAPH_OP(EmbeddingSegmentsSum, op::v3) NGRAPH_OP(ExtractImagePatches, op::v3) NGRAPH_OP(ShapeOf, op::v3) NGRAPH_OP(NonZero, op::v3) +NGRAPH_OP(ScatterNDUpdate, op::v3) +NGRAPH_OP(ScatterUpdate, op::v3) #undef ID_SUFFIX diff --git a/ngraph/test/runtime/interpreter/reference/scatter_nd_update.hpp b/ngraph/test/runtime/interpreter/reference/scatter_nd_update.hpp new file mode 100644 index 00000000000000..37d3b5acb0f577 --- /dev/null +++ b/ngraph/test/runtime/interpreter/reference/scatter_nd_update.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/coordinate_transform.hpp" +#include "ngraph/shape.hpp" + +using namespace ngraph; + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void scatterNdUpdate(const dataType* inputData, + const indicesType* indices, + const dataType* updates, + dataType* outBuf, + const Shape& dataShape, + const Shape& indicesShape, + const Shape& updatesShape) + { + size_t numSlices = 1; + size_t sliceSize = 1; + for (size_t i = 0; i < indicesShape.size() - 1; i++) + { + numSlices *= indicesShape[i]; + } + for (size_t i = indicesShape.size() - 1; i < updatesShape.size(); i++) + { + sliceSize *= updatesShape[i]; + } + + const size_t k = indicesShape.back(); + std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape)); + CoordinateTransform dataTransform{dataShape}; + + for (size_t i = 0; i < numSlices; i++) + { + Coordinate coord; + for (size_t j = 0; j < k; j++) + { + coord.push_back(indices[i * k + j]); + } + for (size_t j = k; j < dataShape.size(); j++) + { + coord.push_back(0); + } + + const size_t startDataIdx = dataTransform.index(coord); + for (size_t j = 0; j < sliceSize; j++) + { + outBuf[startDataIdx + j] = updates[i * sliceSize + j]; + } + } + } + } // namespace reference + } // namespace runtime +} // namespace ngraph diff --git a/ngraph/test/runtime/interpreter/reference/scatter_update.hpp b/ngraph/test/runtime/interpreter/reference/scatter_update.hpp new file mode 100644 index 00000000000000..e3cae8c014750b --- /dev/null +++ b/ngraph/test/runtime/interpreter/reference/scatter_update.hpp @@ -0,0 +1,86 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "ngraph/coordinate_transform.hpp" +#include "ngraph/shape.hpp" + +using namespace ngraph; + +namespace ngraph +{ + namespace runtime + { + namespace reference + { + template + void scatterUpdate(const dataType* inputData, + const indicesType* indices, + const dataType* updates, + const axisType* _axis, + dataType* outBuf, + const Shape& dataShape, + const Shape& indicesShape, + const Shape& updatesShape) + { + int rank = static_cast(dataShape.size()); + if (_axis[0] < -rank || _axis[0] > rank - 1) + { + std::string error = + std::string("ScatterUpdate layer has out of bounds axis value: ") + + std::to_string(_axis[0]); + throw ngraph_error(error); + } + size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0]; + CoordinateTransform indicesTransform{indicesShape}; + + Shape dataShapeIter = dataShape; + dataShapeIter.erase(dataShapeIter.begin() + axis); + CoordinateTransform dataTransfIter{dataShapeIter}; + + CoordinateTransform updateTransform{updatesShape}; + CoordinateTransform dataTransform{dataShape}; + + std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape)); + + for (const Coordinate& indicesCoordIt : indicesTransform) + { + const size_t indicesIdx = indicesTransform.index(indicesCoordIt); + + if (indices[indicesIdx] < 0) + { + std::string error = + std::string("ScatterUpdate layer has negative index value: ") + + std::to_string(indices[indicesIdx]); + throw ngraph_error(error); + } + const size_t idx = static_cast(indices[indicesIdx]); + if (dataShape[axis] <= idx) + { + std::string error = + std::string("ScatterUpdate layer has out of bounds coordinate: ") + + std::to_string(idx) + " on 'data' input on " + std::to_string(axis) + + "th axis"; + throw ngraph_error(error); + } + + for (const Coordinate& dataCoordIt : dataTransfIter) + { + Coordinate dataCoord = dataCoordIt; + dataCoord.insert(dataCoord.begin() + axis, idx); + const size_t startIndices = dataTransform.index(dataCoord); + + auto updCoord = dataCoordIt; + updCoord.insert( + updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end()); + const size_t startUpd = updateTransform.index(updCoord); + outBuf[startIndices] = updates[startUpd]; + } + } + } + } // namespace reference + } // namespace runtime +} // namespace ngraph