Skip to content

Commit

Permalink
Revise ScatterElementsUpdate Op (#7162)
Browse files Browse the repository at this point in the history
* Use ngraph rtti macros

* Add visitor test

* Add SSLT

* Add hardcoded refs tests for ScatterElementsUpdate

* Add ScatterElementsUpdate to trusted ops list

* Add i16 case to backend tests

* Add typed testcase generation, check for all supported types

* Remove redundant parameters from generateScatterParams
  • Loading branch information
Bartosz Lesniewski authored Aug 30, 2021
1 parent 4622f2f commit 4a9ac22
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <ngraph/ngraph.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <tuple>

#include "base_reference_test.hpp"

using namespace reference_tests;
using namespace ngraph;
using namespace InferenceEngine;

namespace {
struct ScatterElementsUpdateParams {
ScatterElementsUpdateParams(const Tensor& paramData,
const Tensor& paramIndices,
const Tensor& paramUpdates,
const Tensor& paramAxis,
const Tensor& paramExpected)
: input(paramData),
indices(paramIndices),
updates(paramUpdates),
axis(paramAxis),
expected(paramExpected) {}

Tensor input;
Tensor indices;
Tensor updates;
Tensor axis;
Tensor expected;
};

class ReferenceScatterElementsUpdateLayerTest : public testing::TestWithParam<ScatterElementsUpdateParams>,
public CommonReferenceTest {
public:
void SetUp() override {
auto params = GetParam();
function = CreateFunction(params);
inputData = {params.input.data,
params.indices.data,
params.updates.data,
params.axis.data};
refOutData = {params.expected.data};
}
static std::string getTestCaseName(const testing::TestParamInfo<ScatterElementsUpdateParams>& obj) {
auto param = obj.param;
std::ostringstream result;
result << "data_sh=" << param.input.shape;
result << "_data_pr=" << param.input.type;
result << "_indx_sh=" << param.indices.shape;
result << "_indx_pr=" << param.indices.type;
result << "_updt_sh=" << param.updates.shape;
result << "_updt_pr=" << param.updates.type;
result << "_axis_sh=" << param.axis.shape;
result << "_axis_pr=" << param.axis.type;
return result.str();
}
private:
static std::shared_ptr<Function> CreateFunction(const ScatterElementsUpdateParams& params) {
const auto A = std::make_shared<op::Parameter>(params.input.type, params.input.shape);
const auto B = std::make_shared<op::Parameter>(params.indices.type, params.indices.shape);
const auto C = std::make_shared<op::Parameter>(params.updates.type, params.updates.shape);
const auto D = std::make_shared<op::Parameter>(params.axis.type, params.axis.shape);
auto scatterElts = std::make_shared<op::ScatterElementsUpdate>(A, B, C, D);
return std::make_shared<Function>(NodeVector{scatterElts}, ParameterVector{A, B, C, D});
}
};

TEST_P(ReferenceScatterElementsUpdateLayerTest, CompareWithHardcodedRefs) {
Exec();
}

template <element::Type_t ET, element::Type_t ET_IND>
std::vector<ScatterElementsUpdateParams> generateScatterParams() {
using T = typename element_type_traits<ET>::value_type;
using T_INT = typename element_type_traits<ET_IND>::value_type;
std::vector<ScatterElementsUpdateParams> scatterParams{
// axis = 0
ScatterElementsUpdateParams(Tensor({2, 2}, element::Type(ET), std::vector<T>{1, 2, 3, 4}), // input
Tensor({2, 2}, element::Type(ET_IND), std::vector<T_INT>{1, 1, 0, 0}), // indices
Tensor({2, 2}, element::Type(ET), std::vector<T>{10, 20, 30, 40}), // updates
Tensor({1}, element::Type(ET_IND), std::vector<T_INT>{0}), // axis
Tensor({2, 2}, element::Type(ET), std::vector<T>{30, 40, 10, 20})), // expected
// axis = 1
ScatterElementsUpdateParams(Tensor({2, 1}, element::Type(ET), std::vector<T>{1, 2}), // input
Tensor({2, 1}, element::Type(ET_IND), std::vector<T_INT>{0, 0}), // indices
Tensor({2, 1}, element::Type(ET), std::vector<T>{10, 20}), // updates
Tensor({1}, element::Type(ET_IND), std::vector<T_INT>{1}), // axis
Tensor({2, 1}, element::Type(ET), std::vector<T>{10, 20})), // expected
};
return scatterParams;
}

std::vector<ScatterElementsUpdateParams> generateScatterCombinedParams() {
const std::vector<std::vector<ScatterElementsUpdateParams>> scatterTypeParams{
// i16
generateScatterParams<element::Type_t::i16, element::Type_t::i8>(),
generateScatterParams<element::Type_t::i16, element::Type_t::u8>(),
generateScatterParams<element::Type_t::i16, element::Type_t::i16>(),
generateScatterParams<element::Type_t::i16, element::Type_t::u16>(),
generateScatterParams<element::Type_t::i16, element::Type_t::i32>(),
generateScatterParams<element::Type_t::i16, element::Type_t::u32>(),
generateScatterParams<element::Type_t::i16, element::Type_t::i64>(),
generateScatterParams<element::Type_t::i16, element::Type_t::u64>(),
// i32
generateScatterParams<element::Type_t::i32, element::Type_t::i8>(),
generateScatterParams<element::Type_t::i32, element::Type_t::u8>(),
generateScatterParams<element::Type_t::i32, element::Type_t::i16>(),
generateScatterParams<element::Type_t::i32, element::Type_t::u16>(),
generateScatterParams<element::Type_t::i32, element::Type_t::i32>(),
generateScatterParams<element::Type_t::i32, element::Type_t::u32>(),
generateScatterParams<element::Type_t::i32, element::Type_t::i64>(),
generateScatterParams<element::Type_t::i32, element::Type_t::u64>(),
// i64
generateScatterParams<element::Type_t::i64, element::Type_t::i8>(),
generateScatterParams<element::Type_t::i64, element::Type_t::u8>(),
generateScatterParams<element::Type_t::i64, element::Type_t::i16>(),
generateScatterParams<element::Type_t::i64, element::Type_t::u16>(),
generateScatterParams<element::Type_t::i64, element::Type_t::i32>(),
generateScatterParams<element::Type_t::i64, element::Type_t::u32>(),
generateScatterParams<element::Type_t::i64, element::Type_t::i64>(),
generateScatterParams<element::Type_t::i64, element::Type_t::u64>(),
// u32
generateScatterParams<element::Type_t::u32, element::Type_t::i8>(),
generateScatterParams<element::Type_t::u32, element::Type_t::u8>(),
generateScatterParams<element::Type_t::u32, element::Type_t::i16>(),
generateScatterParams<element::Type_t::u32, element::Type_t::u16>(),
generateScatterParams<element::Type_t::u32, element::Type_t::i32>(),
generateScatterParams<element::Type_t::u32, element::Type_t::u32>(),
generateScatterParams<element::Type_t::u32, element::Type_t::i64>(),
generateScatterParams<element::Type_t::u32, element::Type_t::u64>(),
// u64
generateScatterParams<element::Type_t::u64, element::Type_t::i8>(),
generateScatterParams<element::Type_t::u64, element::Type_t::u8>(),
generateScatterParams<element::Type_t::u64, element::Type_t::i16>(),
generateScatterParams<element::Type_t::u64, element::Type_t::u16>(),
generateScatterParams<element::Type_t::u64, element::Type_t::i32>(),
generateScatterParams<element::Type_t::u64, element::Type_t::u32>(),
generateScatterParams<element::Type_t::u64, element::Type_t::i64>(),
generateScatterParams<element::Type_t::u64, element::Type_t::u64>(),
// f16
generateScatterParams<element::Type_t::f16, element::Type_t::i8>(),
generateScatterParams<element::Type_t::f16, element::Type_t::u8>(),
generateScatterParams<element::Type_t::f16, element::Type_t::i16>(),
generateScatterParams<element::Type_t::f16, element::Type_t::u16>(),
generateScatterParams<element::Type_t::f16, element::Type_t::i32>(),
generateScatterParams<element::Type_t::f16, element::Type_t::u32>(),
generateScatterParams<element::Type_t::f16, element::Type_t::i64>(),
generateScatterParams<element::Type_t::f16, element::Type_t::u64>(),
// f32
generateScatterParams<element::Type_t::f32, element::Type_t::i8>(),
generateScatterParams<element::Type_t::f32, element::Type_t::u8>(),
generateScatterParams<element::Type_t::f32, element::Type_t::i16>(),
generateScatterParams<element::Type_t::f32, element::Type_t::u16>(),
generateScatterParams<element::Type_t::f32, element::Type_t::i32>(),
generateScatterParams<element::Type_t::f32, element::Type_t::u32>(),
generateScatterParams<element::Type_t::f32, element::Type_t::i64>(),
generateScatterParams<element::Type_t::f32, element::Type_t::u64>(),
};
std::vector<ScatterElementsUpdateParams> combinedParams;
for (const auto& params : scatterTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}
INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdate_With_Hardcoded_Refs,
ReferenceScatterElementsUpdateLayerTest,
::testing::ValuesIn(generateScatterCombinedParams()),
ReferenceScatterElementsUpdateLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <ngraph/opsets/opset3.hpp>

#include "shared_test_classes/single_layer/scatter_elements_update.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;

namespace {
TEST_P(ScatterElementsUpdateLayerTest, Serialize) {
Serialize();
}
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}},
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}},
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}},
};
// index value should not be random data
const std::vector<std::vector<size_t>> idxValue = {
{1, 0, 4, 6, 2, 3, 7, 5}
};

const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};

const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};

const auto ScatterEltUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);

INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdateSerialization, ScatterElementsUpdateLayerTest,
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
'ReverseSequence-1',
'Round-5',
'SpaceToDepth-1',
'ScatterElementsUpdate-3',
'ScatterNDUpdate-4',
'Select-1',
'ShapeOf-1',
Expand Down
6 changes: 2 additions & 4 deletions ngraph/core/include/ngraph/op/scatter_elements_update.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ namespace op {
namespace v3 {
class NGRAPH_API ScatterElementsUpdate : public Op {
public:
static constexpr NodeTypeInfo type_info{"ScatterElementsUpdate", 3};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
NGRAPH_RTTI_DECLARATION;

ScatterElementsUpdate() = default;
/// \brief Constructs a ScatterElementsUpdate node

Expand Down
2 changes: 1 addition & 1 deletion ngraph/core/src/op/scatter_elements_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
using namespace ngraph;
using namespace std;

constexpr NodeTypeInfo op::v3::ScatterElementsUpdate::type_info;
NGRAPH_RTTI_DEFINITION(op::ScatterElementsUpdate, "ScatterElementsUpdate", 3);

op::v3::ScatterElementsUpdate::ScatterElementsUpdate(const Output<Node>& data,
const Output<Node>& indices,
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ set(SRC
visitors/op/rnn_cell.cpp
visitors/op/roi_pooling.cpp
visitors/op/round.cpp
visitors/op/scatter_elements_update.cpp
visitors/op/select.cpp
visitors/op/space_to_depth.cpp
visitors/op/selu.cpp
Expand Down
28 changes: 28 additions & 0 deletions ngraph/test/visitors/op/scatter_elements_update.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "util/visitor.hpp"

using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;

TEST(attributes, scatter_elements_update) {
NodeBuilder::get_ops().register_factory<opset3::ScatterElementsUpdate>();

auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 4, 5, 7});
auto indices = std::make_shared<op::Parameter>(element::i16, Shape{2, 2, 2, 2});
auto updates = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 2, 2});
auto axis = std::make_shared<op::Parameter>(element::i16, Shape{});

auto scatter = std::make_shared<opset3::ScatterElementsUpdate>(data, indices, updates, axis);
NodeBuilder builder(scatter);

const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}

0 comments on commit 4a9ac22

Please sign in to comment.