Skip to content

Commit

Permalink
[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
mandrono committed Aug 3, 2020
1 parent e273820 commit 1f09059
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 200 deletions.
1 change: 0 additions & 1 deletion inference-engine/src/mkldnn_plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ set(LAYERS
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/scatter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp
Expand Down
1 change: 0 additions & 1 deletion inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ MKLDNN_EXTENSION_NODE(FillImpl, Fill);
MKLDNN_EXTENSION_NODE(UniqueImpl, Unique);
MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling);
MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace);
MKLDNN_EXTENSION_NODE(ScatterImpl, ScatterUpdate);
MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot);
MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast);
MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum);
Expand Down
188 changes: 0 additions & 188 deletions inference-engine/src/mkldnn_plugin/nodes/scatter.cpp

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_CPU)
);

// open after ops support in ngraph merged
// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);

} // namespace
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ const auto ScatterUpdateCase = ::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_CPU)
);

// open after ngraph reference implementation merged
// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);

} // namespace
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>

#include "ngraph_functions/utils/data_utils.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update) {
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
// blocked by ngraph merge
// auto dtsNode = std::make_shared<ngraph::opset3::ScatterNDUpdate>(in, indicesNode, update);
// return dtsNode;
return nullptr;
auto dtsNode = std::make_shared<ngraph::opset4::ScatterNDUpdate>(in, indicesNode, update);
return dtsNode;
}

} // namespace builder
Expand Down
78 changes: 78 additions & 0 deletions ngraph/test/runtime/interpreter/int_executable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@
#include "op/convolution.hpp"
#include "op/group_conv.hpp"

#include "reference/scatter_nd_update.hpp"
#include "reference/scatter_update.hpp"

namespace ngraph
{
namespace runtime
Expand Down Expand Up @@ -1129,6 +1132,81 @@ class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : publ
}
break;
}
case OP_TYPEID::ScatterNDUpdate_v3:
{
const op::ScatterNDUpdate* scatterNDUpd =
static_cast<const op::v3::ScatterNDUpdate*>(&node);
auto idxType = scatterNDUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
}

break;
}
case OP_TYPEID::ScatterUpdate_v3:
{
const op::v3::ScatterUpdate* scatterUpd =
static_cast<const op::v3::ScatterUpdate*>(&node);

if (scatterUpd->get_input_element_type(3) != element::i64)
throw ngraph_error(
"ScatterNDUpdate layer support only i64 'axis' input precision!");

auto idxType = scatterUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterUpdate<T, int32_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterUpdate<T, int64_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterUpdate layer support only i32 and i64 'indices' input precision!");
}

break;
}

// Fused Ops are not supported in interpreter. They need to be decomposed before execution
case OP_TYPEID::DepthToSpace:
Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/runtime/interpreter/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
NGRAPH_OP(ExtractImagePatches, op::v3)
NGRAPH_OP(ShapeOf, op::v3)
NGRAPH_OP(NonZero, op::v3)
NGRAPH_OP(ScatterNDUpdate, op::v3)
NGRAPH_OP(ScatterUpdate, op::v3)
#undef ID_SUFFIX
Loading

0 comments on commit 1f09059

Please sign in to comment.