From 4a9ac22787455d911c8d5db9c304ded987c73340 Mon Sep 17 00:00:00 2001 From: Bartosz Lesniewski Date: Mon, 30 Aug 2021 12:47:30 +0200 Subject: [PATCH] Revise ScatterElementsUpdate Op (#7162) * Use ngraph rtti macros * Add visitor test * Add SSLT * Add hardcoded refs tests for ScatterElementsUpdate * Add ScatterElementsUpdate to trusted ops list * Add i16 case to backend tests * Add typed testcase generation, check for all supported types * Remove redundant parameters from generateScatterParams --- .../op_reference/scatter_elements_update.cpp | 176 ++++++++++++++++++ .../single_layer/scatter_elements_update.cpp | 51 +++++ .../layer_tests_summary/utils/constants.py | 1 + .../ngraph/op/scatter_elements_update.hpp | 6 +- .../core/src/op/scatter_elements_update.cpp | 2 +- ngraph/test/CMakeLists.txt | 1 + .../visitors/op/scatter_elements_update.cpp | 28 +++ 7 files changed, 260 insertions(+), 5 deletions(-) create mode 100644 docs/template_plugin/tests/functional/op_reference/scatter_elements_update.cpp create mode 100644 inference-engine/tests/functional/inference_engine/serialization/single_layer/scatter_elements_update.cpp create mode 100644 ngraph/test/visitors/op/scatter_elements_update.cpp diff --git a/docs/template_plugin/tests/functional/op_reference/scatter_elements_update.cpp b/docs/template_plugin/tests/functional/op_reference/scatter_elements_update.cpp new file mode 100644 index 00000000000000..3d844bfcf6dc0b --- /dev/null +++ b/docs/template_plugin/tests/functional/op_reference/scatter_elements_update.cpp @@ -0,0 +1,176 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include + +#include "base_reference_test.hpp" + +using namespace reference_tests; +using namespace ngraph; +using namespace InferenceEngine; + +namespace { +struct ScatterElementsUpdateParams { + ScatterElementsUpdateParams(const Tensor& paramData, + const Tensor& paramIndices, + const Tensor& paramUpdates, + const Tensor& paramAxis, + const Tensor& paramExpected) + : input(paramData), + indices(paramIndices), + updates(paramUpdates), + axis(paramAxis), + expected(paramExpected) {} + + Tensor input; + Tensor indices; + Tensor updates; + Tensor axis; + Tensor expected; +}; + +class ReferenceScatterElementsUpdateLayerTest : public testing::TestWithParam, + public CommonReferenceTest { +public: + void SetUp() override { + auto params = GetParam(); + function = CreateFunction(params); + inputData = {params.input.data, + params.indices.data, + params.updates.data, + params.axis.data}; + refOutData = {params.expected.data}; + } + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + auto param = obj.param; + std::ostringstream result; + result << "data_sh=" << param.input.shape; + result << "_data_pr=" << param.input.type; + result << "_indx_sh=" << param.indices.shape; + result << "_indx_pr=" << param.indices.type; + result << "_updt_sh=" << param.updates.shape; + result << "_updt_pr=" << param.updates.type; + result << "_axis_sh=" << param.axis.shape; + result << "_axis_pr=" << param.axis.type; + return result.str(); + } +private: + static std::shared_ptr CreateFunction(const ScatterElementsUpdateParams& params) { + const auto A = std::make_shared(params.input.type, params.input.shape); + const auto B = std::make_shared(params.indices.type, params.indices.shape); + const auto C = std::make_shared(params.updates.type, params.updates.shape); + const auto D = std::make_shared(params.axis.type, params.axis.shape); + auto scatterElts = std::make_shared(A, B, C, D); + return std::make_shared(NodeVector{scatterElts}, ParameterVector{A, B, C, D}); + } +}; + +TEST_P(ReferenceScatterElementsUpdateLayerTest, CompareWithHardcodedRefs) { + Exec(); +} + +template +std::vector generateScatterParams() { + using T = typename element_type_traits::value_type; + using T_INT = typename element_type_traits::value_type; + std::vector scatterParams{ + // axis = 0 + ScatterElementsUpdateParams(Tensor({2, 2}, element::Type(ET), std::vector{1, 2, 3, 4}), // input + Tensor({2, 2}, element::Type(ET_IND), std::vector{1, 1, 0, 0}), // indices + Tensor({2, 2}, element::Type(ET), std::vector{10, 20, 30, 40}), // updates + Tensor({1}, element::Type(ET_IND), std::vector{0}), // axis + Tensor({2, 2}, element::Type(ET), std::vector{30, 40, 10, 20})), // expected + // axis = 1 + ScatterElementsUpdateParams(Tensor({2, 1}, element::Type(ET), std::vector{1, 2}), // input + Tensor({2, 1}, element::Type(ET_IND), std::vector{0, 0}), // indices + Tensor({2, 1}, element::Type(ET), std::vector{10, 20}), // updates + Tensor({1}, element::Type(ET_IND), std::vector{1}), // axis + Tensor({2, 1}, element::Type(ET), std::vector{10, 20})), // expected + }; + return scatterParams; +} + +std::vector generateScatterCombinedParams() { + const std::vector> scatterTypeParams{ + // i16 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // i32 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // i64 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // u32 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // u64 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // f16 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + // f32 + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + generateScatterParams(), + }; + std::vector combinedParams; + for (const auto& params : scatterTypeParams) { + combinedParams.insert(combinedParams.end(), params.begin(), params.end()); + } + return combinedParams; +} +INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdate_With_Hardcoded_Refs, + ReferenceScatterElementsUpdateLayerTest, + ::testing::ValuesIn(generateScatterCombinedParams()), + ReferenceScatterElementsUpdateLayerTest::getTestCaseName); +} // namespace diff --git a/inference-engine/tests/functional/inference_engine/serialization/single_layer/scatter_elements_update.cpp b/inference-engine/tests/functional/inference_engine/serialization/single_layer/scatter_elements_update.cpp new file mode 100644 index 00000000000000..b41246eace7d1e --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/serialization/single_layer/scatter_elements_update.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "shared_test_classes/single_layer/scatter_elements_update.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; +using namespace ngraph::opset3; + +namespace { +TEST_P(ScatterElementsUpdateLayerTest, Serialize) { + Serialize(); +} +// map> +std::map, std::map, std::vector>> axesShapeInShape { + {{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}}, + {{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}}, + {{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}}, +}; +// index value should not be random data +const std::vector> idxValue = { + {1, 0, 4, 6, 2, 3, 7, 5} +}; + +const std::vector inputPrecisions = { + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::FP16, + InferenceEngine::Precision::I32, +}; + +const std::vector idxPrecisions = { + InferenceEngine::Precision::I32, + InferenceEngine::Precision::I64, +}; + +const auto ScatterEltUpdateCases = ::testing::Combine( + ::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)), + ::testing::ValuesIn(idxValue), + ::testing::ValuesIn(inputPrecisions), + ::testing::ValuesIn(idxPrecisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdateSerialization, ScatterElementsUpdateLayerTest, + ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/ie_test_utils/functional_test_utils/layer_tests_summary/utils/constants.py b/inference-engine/tests/ie_test_utils/functional_test_utils/layer_tests_summary/utils/constants.py index e35678eee29ae4..d40dc40480c2e3 100644 --- a/inference-engine/tests/ie_test_utils/functional_test_utils/layer_tests_summary/utils/constants.py +++ b/inference-engine/tests/ie_test_utils/functional_test_utils/layer_tests_summary/utils/constants.py @@ -105,6 +105,7 @@ 'ReverseSequence-1', 'Round-5', 'SpaceToDepth-1', + 'ScatterElementsUpdate-3', 'ScatterNDUpdate-4', 'Select-1', 'ShapeOf-1', diff --git a/ngraph/core/include/ngraph/op/scatter_elements_update.hpp b/ngraph/core/include/ngraph/op/scatter_elements_update.hpp index e49ed04975016d..78847c2a33eb95 100644 --- a/ngraph/core/include/ngraph/op/scatter_elements_update.hpp +++ b/ngraph/core/include/ngraph/op/scatter_elements_update.hpp @@ -15,10 +15,8 @@ namespace op { namespace v3 { class NGRAPH_API ScatterElementsUpdate : public Op { public: - static constexpr NodeTypeInfo type_info{"ScatterElementsUpdate", 3}; - const NodeTypeInfo& get_type_info() const override { - return type_info; - } + NGRAPH_RTTI_DECLARATION; + ScatterElementsUpdate() = default; /// \brief Constructs a ScatterElementsUpdate node diff --git a/ngraph/core/src/op/scatter_elements_update.cpp b/ngraph/core/src/op/scatter_elements_update.cpp index b4880bbb74b834..46275a703f4270 100644 --- a/ngraph/core/src/op/scatter_elements_update.cpp +++ b/ngraph/core/src/op/scatter_elements_update.cpp @@ -13,7 +13,7 @@ using namespace ngraph; using namespace std; -constexpr NodeTypeInfo op::v3::ScatterElementsUpdate::type_info; +NGRAPH_RTTI_DEFINITION(op::ScatterElementsUpdate, "ScatterElementsUpdate", 3); op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output& data, const Output& indices, diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 16ad46af9ed1c4..ec621a10509430 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -330,6 +330,7 @@ set(SRC visitors/op/rnn_cell.cpp visitors/op/roi_pooling.cpp visitors/op/round.cpp + visitors/op/scatter_elements_update.cpp visitors/op/select.cpp visitors/op/space_to_depth.cpp visitors/op/selu.cpp diff --git a/ngraph/test/visitors/op/scatter_elements_update.cpp b/ngraph/test/visitors/op/scatter_elements_update.cpp new file mode 100644 index 00000000000000..fdceae959257e4 --- /dev/null +++ b/ngraph/test/visitors/op/scatter_elements_update.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/op/util/attr_types.hpp" +#include "ngraph/opsets/opset3.hpp" +#include "util/visitor.hpp" + +using namespace ngraph; +using ngraph::test::NodeBuilder; +using ngraph::test::ValueMap; + +TEST(attributes, scatter_elements_update) { + NodeBuilder::get_ops().register_factory(); + + auto data = std::make_shared(element::f32, Shape{2, 4, 5, 7}); + auto indices = std::make_shared(element::i16, Shape{2, 2, 2, 2}); + auto updates = std::make_shared(element::f32, Shape{2, 2, 2, 2}); + auto axis = std::make_shared(element::i16, Shape{}); + + auto scatter = std::make_shared(data, indices, updates, axis); + NodeBuilder builder(scatter); + + const auto expected_attr_count = 0; + EXPECT_EQ(builder.get_value_map_size(), expected_attr_count); +}