forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revise GatherTree reference implementation (openvinotoolkit#7275)
* Add visitor api test * Review ngraph op shell with type_prop tests * Add op to list of trusted operations * Change name of struct with information of inputs * Add include of array data structure to fix windowds compilation error * Add template plugin test class * Remove usage of CoordinateTransform index function call from reference implementation * Rename SLT test suite * Add template plugin unit test * Add serialization SLTs * Add indentation on GatherTreeParams class data members
- Loading branch information
Showing
8 changed files
with
506 additions
and
93 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
docs/template_plugin/tests/functional/op_reference/gather_tree.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <limits> | ||
#include <algorithm> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
|
||
#include "base_reference_test.hpp" | ||
|
||
using namespace reference_tests; | ||
using namespace ngraph; | ||
using namespace InferenceEngine; | ||
|
||
namespace { | ||
struct GatherTreeParams { | ||
template <class IN_ET> | ||
GatherTreeParams(const ngraph::Shape inShape, std::vector<IN_ET> stepIds, const std::vector<IN_ET> parentIds, | ||
const std::vector<IN_ET> maxSeqLen, const std::vector<IN_ET> endToken, std::vector<IN_ET> output) : | ||
stepIdsTensor(inShape, element::from<IN_ET>(), stepIds), parentIdsTensor(inShape, element::from<IN_ET>(), parentIds), | ||
maxSeqLenTensor(ngraph::Shape{inShape[1]}, element::from<IN_ET>(), maxSeqLen), endTokenTensor(ngraph::Shape{}, element::from<IN_ET>(), endToken), | ||
expectedTensor(inShape, element::from<IN_ET>(), output) {} | ||
Tensor stepIdsTensor; | ||
Tensor parentIdsTensor; | ||
Tensor maxSeqLenTensor; | ||
Tensor endTokenTensor; | ||
Tensor expectedTensor; | ||
}; | ||
|
||
class ReferenceGatherTreeTest : public testing::TestWithParam<GatherTreeParams>, public CommonReferenceTest { | ||
public: | ||
void SetUp() override { | ||
auto params = GetParam(); | ||
function = CreateFunction(params); | ||
inputData = {params.stepIdsTensor.data, params.parentIdsTensor.data, params.maxSeqLenTensor.data, params.endTokenTensor.data}; | ||
refOutData = {params.expectedTensor.data}; | ||
} | ||
static std::string getTestCaseName(const testing::TestParamInfo<GatherTreeParams>& obj) { | ||
auto param = obj.param; | ||
std::ostringstream result; | ||
result << "iType=" << param.stepIdsTensor.type << "_"; | ||
result << "iShape=" << param.stepIdsTensor.shape; | ||
return result.str(); | ||
} | ||
|
||
private: | ||
static std::shared_ptr<Function> CreateFunction(const GatherTreeParams& params) { | ||
const auto stepIds = std::make_shared<op::Parameter>(params.stepIdsTensor.type, params.stepIdsTensor.shape); | ||
const auto parentIds = std::make_shared<op::Parameter>(params.parentIdsTensor.type, params.parentIdsTensor.shape); | ||
const auto maxSeqLen = std::make_shared<op::Parameter>(params.maxSeqLenTensor.type, params.maxSeqLenTensor.shape); | ||
const auto endToken = std::make_shared<op::Parameter>(params.endTokenTensor.type, params.endTokenTensor.shape); | ||
const auto gatherTree = std::make_shared<op::v1::GatherTree>(stepIds, parentIds, maxSeqLen, endToken); | ||
return std::make_shared<Function>(NodeVector {gatherTree}, ParameterVector {stepIds, parentIds, maxSeqLen, endToken}); | ||
} | ||
}; | ||
|
||
TEST_P(ReferenceGatherTreeTest, CompareWithRefs) { | ||
Exec(); | ||
} | ||
|
||
template <element::Type_t IN_ET> | ||
std::vector<GatherTreeParams> generateGatherTreeParams() { | ||
using T = typename element_type_traits<IN_ET>::value_type; | ||
std::vector<GatherTreeParams> gatherTreeParams { | ||
GatherTreeParams(Shape{4, 1, 3}, | ||
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1}, | ||
std::vector<T>{0, 0, 0, 0, 1, 1, 2, 1, 2, -1, -1, -1}, | ||
std::vector<T>{3}, | ||
std::vector<T>{10}, | ||
std::vector<T>{2, 2, 2, 6, 5, 6, 7, 8, 9, 10, 10, 10}), | ||
GatherTreeParams(Shape{2, 2, 2}, | ||
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8}, | ||
std::vector<T>{0, 0, 0, 0, 0, 0, 0, 0}, | ||
std::vector<T>{2, 4}, | ||
std::vector<T>{0}, | ||
std::vector<T>{1, 1, 3, 3, 5, 6, 7, 8}) | ||
}; | ||
return gatherTreeParams; | ||
} | ||
|
||
std::vector<GatherTreeParams> generateGatherTreeCombinedParams() { | ||
const std::vector<std::vector<GatherTreeParams>> gatherTreeTypeParams { | ||
generateGatherTreeParams<element::Type_t::f32>(), | ||
generateGatherTreeParams<element::Type_t::i32>()}; | ||
std::vector<GatherTreeParams> combinedParams; | ||
|
||
for (const auto& params : gatherTreeTypeParams) { | ||
combinedParams.insert(combinedParams.end(), params.begin(), params.end()); | ||
} | ||
return combinedParams; | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_With_Hardcoded_Refs, ReferenceGatherTreeTest, | ||
testing::ValuesIn(generateGatherTreeCombinedParams()), ReferenceGatherTreeTest::getTestCaseName); | ||
} // namespace |
41 changes: 41 additions & 0 deletions
41
...rence-engine/tests/functional/inference_engine/serialization/single_layer/gather_tree.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <vector> | ||
|
||
#include "shared_test_classes/single_layer/gather_tree.hpp" | ||
#include "common_test_utils/test_constants.hpp" | ||
|
||
using namespace LayerTestsDefinitions; | ||
|
||
namespace { | ||
|
||
TEST_P(GatherTreeLayerTest, Serialize) { | ||
Serialize(); | ||
} | ||
|
||
const std::vector<InferenceEngine::Precision> netPrecisions = { | ||
InferenceEngine::Precision::FP32, | ||
InferenceEngine::Precision::I32 | ||
}; | ||
|
||
const std::vector<std::vector<size_t>> inputShapes = { {5, 1, 10}, {1, 1, 10}, {20, 1, 10}, {20, 20, 10} }; | ||
|
||
const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = { | ||
ngraph::helpers::InputLayerType::CONSTANT, | ||
ngraph::helpers::InputLayerType::PARAMETER | ||
}; | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_Serialization, GatherTreeLayerTest, | ||
::testing::Combine( | ||
::testing::ValuesIn(inputShapes), | ||
::testing::ValuesIn(secondaryInputTypes), | ||
::testing::ValuesIn(netPrecisions), | ||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED), | ||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED), | ||
::testing::Values(InferenceEngine::Layout::ANY), | ||
::testing::Values(InferenceEngine::Layout::ANY), | ||
::testing::Values(CommonTestUtils::DEVICE_CPU)), | ||
GatherTreeLayerTest::getTestCaseName); | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.