-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revise ScatterElementsUpdate Op (#7162)
* Use ngraph rtti macros * Add visitor test * Add SSLT * Add hardcoded refs tests for ScatterElementsUpdate * Add ScatterElementsUpdate to trusted ops list * Add i16 case to backend tests * Add typed testcase generation, check for all supported types * Remove redundant parameters from generateScatterParams
- Loading branch information
Bartosz Lesniewski
authored
Aug 30, 2021
1 parent
4622f2f
commit 4a9ac22
Showing
7 changed files
with
260 additions
and
5 deletions.
There are no files selected for viewing
176 changes: 176 additions & 0 deletions
176
docs/template_plugin/tests/functional/op_reference/scatter_elements_update.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
#include <tuple> | ||
|
||
#include "base_reference_test.hpp" | ||
|
||
using namespace reference_tests; | ||
using namespace ngraph; | ||
using namespace InferenceEngine; | ||
|
||
namespace { | ||
struct ScatterElementsUpdateParams { | ||
ScatterElementsUpdateParams(const Tensor& paramData, | ||
const Tensor& paramIndices, | ||
const Tensor& paramUpdates, | ||
const Tensor& paramAxis, | ||
const Tensor& paramExpected) | ||
: input(paramData), | ||
indices(paramIndices), | ||
updates(paramUpdates), | ||
axis(paramAxis), | ||
expected(paramExpected) {} | ||
|
||
Tensor input; | ||
Tensor indices; | ||
Tensor updates; | ||
Tensor axis; | ||
Tensor expected; | ||
}; | ||
|
||
class ReferenceScatterElementsUpdateLayerTest : public testing::TestWithParam<ScatterElementsUpdateParams>, | ||
public CommonReferenceTest { | ||
public: | ||
void SetUp() override { | ||
auto params = GetParam(); | ||
function = CreateFunction(params); | ||
inputData = {params.input.data, | ||
params.indices.data, | ||
params.updates.data, | ||
params.axis.data}; | ||
refOutData = {params.expected.data}; | ||
} | ||
static std::string getTestCaseName(const testing::TestParamInfo<ScatterElementsUpdateParams>& obj) { | ||
auto param = obj.param; | ||
std::ostringstream result; | ||
result << "data_sh=" << param.input.shape; | ||
result << "_data_pr=" << param.input.type; | ||
result << "_indx_sh=" << param.indices.shape; | ||
result << "_indx_pr=" << param.indices.type; | ||
result << "_updt_sh=" << param.updates.shape; | ||
result << "_updt_pr=" << param.updates.type; | ||
result << "_axis_sh=" << param.axis.shape; | ||
result << "_axis_pr=" << param.axis.type; | ||
return result.str(); | ||
} | ||
private: | ||
static std::shared_ptr<Function> CreateFunction(const ScatterElementsUpdateParams& params) { | ||
const auto A = std::make_shared<op::Parameter>(params.input.type, params.input.shape); | ||
const auto B = std::make_shared<op::Parameter>(params.indices.type, params.indices.shape); | ||
const auto C = std::make_shared<op::Parameter>(params.updates.type, params.updates.shape); | ||
const auto D = std::make_shared<op::Parameter>(params.axis.type, params.axis.shape); | ||
auto scatterElts = std::make_shared<op::ScatterElementsUpdate>(A, B, C, D); | ||
return std::make_shared<Function>(NodeVector{scatterElts}, ParameterVector{A, B, C, D}); | ||
} | ||
}; | ||
|
||
TEST_P(ReferenceScatterElementsUpdateLayerTest, CompareWithHardcodedRefs) { | ||
Exec(); | ||
} | ||
|
||
template <element::Type_t ET, element::Type_t ET_IND> | ||
std::vector<ScatterElementsUpdateParams> generateScatterParams() { | ||
using T = typename element_type_traits<ET>::value_type; | ||
using T_INT = typename element_type_traits<ET_IND>::value_type; | ||
std::vector<ScatterElementsUpdateParams> scatterParams{ | ||
// axis = 0 | ||
ScatterElementsUpdateParams(Tensor({2, 2}, element::Type(ET), std::vector<T>{1, 2, 3, 4}), // input | ||
Tensor({2, 2}, element::Type(ET_IND), std::vector<T_INT>{1, 1, 0, 0}), // indices | ||
Tensor({2, 2}, element::Type(ET), std::vector<T>{10, 20, 30, 40}), // updates | ||
Tensor({1}, element::Type(ET_IND), std::vector<T_INT>{0}), // axis | ||
Tensor({2, 2}, element::Type(ET), std::vector<T>{30, 40, 10, 20})), // expected | ||
// axis = 1 | ||
ScatterElementsUpdateParams(Tensor({2, 1}, element::Type(ET), std::vector<T>{1, 2}), // input | ||
Tensor({2, 1}, element::Type(ET_IND), std::vector<T_INT>{0, 0}), // indices | ||
Tensor({2, 1}, element::Type(ET), std::vector<T>{10, 20}), // updates | ||
Tensor({1}, element::Type(ET_IND), std::vector<T_INT>{1}), // axis | ||
Tensor({2, 1}, element::Type(ET), std::vector<T>{10, 20})), // expected | ||
}; | ||
return scatterParams; | ||
} | ||
|
||
std::vector<ScatterElementsUpdateParams> generateScatterCombinedParams() { | ||
const std::vector<std::vector<ScatterElementsUpdateParams>> scatterTypeParams{ | ||
// i16 | ||
generateScatterParams<element::Type_t::i16, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::i16, element::Type_t::u64>(), | ||
// i32 | ||
generateScatterParams<element::Type_t::i32, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::i32, element::Type_t::u64>(), | ||
// i64 | ||
generateScatterParams<element::Type_t::i64, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::i64, element::Type_t::u64>(), | ||
// u32 | ||
generateScatterParams<element::Type_t::u32, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::u32, element::Type_t::u64>(), | ||
// u64 | ||
generateScatterParams<element::Type_t::u64, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::u64, element::Type_t::u64>(), | ||
// f16 | ||
generateScatterParams<element::Type_t::f16, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::f16, element::Type_t::u64>(), | ||
// f32 | ||
generateScatterParams<element::Type_t::f32, element::Type_t::i8>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::u8>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::i16>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::u16>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::i32>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::u32>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::i64>(), | ||
generateScatterParams<element::Type_t::f32, element::Type_t::u64>(), | ||
}; | ||
std::vector<ScatterElementsUpdateParams> combinedParams; | ||
for (const auto& params : scatterTypeParams) { | ||
combinedParams.insert(combinedParams.end(), params.begin(), params.end()); | ||
} | ||
return combinedParams; | ||
} | ||
INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdate_With_Hardcoded_Refs, | ||
ReferenceScatterElementsUpdateLayerTest, | ||
::testing::ValuesIn(generateScatterCombinedParams()), | ||
ReferenceScatterElementsUpdateLayerTest::getTestCaseName); | ||
} // namespace |
51 changes: 51 additions & 0 deletions
51
.../tests/functional/inference_engine/serialization/single_layer/scatter_elements_update.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <vector> | ||
#include <ngraph/opsets/opset3.hpp> | ||
|
||
#include "shared_test_classes/single_layer/scatter_elements_update.hpp" | ||
#include "common_test_utils/test_constants.hpp" | ||
|
||
using namespace LayerTestsDefinitions; | ||
using namespace ngraph::opset3; | ||
|
||
namespace { | ||
TEST_P(ScatterElementsUpdateLayerTest, Serialize) { | ||
Serialize(); | ||
} | ||
// map<inputShape, map<indicesShape, axis>> | ||
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape { | ||
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}}, | ||
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}}, | ||
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}}, | ||
}; | ||
// index value should not be random data | ||
const std::vector<std::vector<size_t>> idxValue = { | ||
{1, 0, 4, 6, 2, 3, 7, 5} | ||
}; | ||
|
||
const std::vector<InferenceEngine::Precision> inputPrecisions = { | ||
InferenceEngine::Precision::FP32, | ||
InferenceEngine::Precision::FP16, | ||
InferenceEngine::Precision::I32, | ||
}; | ||
|
||
const std::vector<InferenceEngine::Precision> idxPrecisions = { | ||
InferenceEngine::Precision::I32, | ||
InferenceEngine::Precision::I64, | ||
}; | ||
|
||
const auto ScatterEltUpdateCases = ::testing::Combine( | ||
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)), | ||
::testing::ValuesIn(idxValue), | ||
::testing::ValuesIn(inputPrecisions), | ||
::testing::ValuesIn(idxPrecisions), | ||
::testing::Values(CommonTestUtils::DEVICE_CPU) | ||
); | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdateSerialization, ScatterElementsUpdateLayerTest, | ||
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName); | ||
|
||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "gtest/gtest.h" | ||
#include "ngraph/ngraph.hpp" | ||
#include "ngraph/op/util/attr_types.hpp" | ||
#include "ngraph/opsets/opset3.hpp" | ||
#include "util/visitor.hpp" | ||
|
||
using namespace ngraph; | ||
using ngraph::test::NodeBuilder; | ||
using ngraph::test::ValueMap; | ||
|
||
TEST(attributes, scatter_elements_update) { | ||
NodeBuilder::get_ops().register_factory<opset3::ScatterElementsUpdate>(); | ||
|
||
auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 4, 5, 7}); | ||
auto indices = std::make_shared<op::Parameter>(element::i16, Shape{2, 2, 2, 2}); | ||
auto updates = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 2, 2}); | ||
auto axis = std::make_shared<op::Parameter>(element::i16, Shape{}); | ||
|
||
auto scatter = std::make_shared<opset3::ScatterElementsUpdate>(data, indices, updates, axis); | ||
NodeBuilder builder(scatter); | ||
|
||
const auto expected_attr_count = 0; | ||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count); | ||
} |