-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Revise spec * Comparison backend test POC * Split Comparison ops tests into separate files * remove comparison.cpp, remove unused imports, replace for_each with range based for * remove unnecessary ngraph:: prefixes * Fix links in spec * Add Less to trusted ops list * Add missing ',' * Use builder in backend tests * Remove old backend tests for less, equal
- Loading branch information
Bartosz Lesniewski
authored
Jul 28, 2021
1 parent
6ee9285
commit 1471095
Showing
6 changed files
with
243 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
docs/template_plugin/tests/functional/op_reference/comparison.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
#include <vector> | ||
|
||
#include "base_reference_test.hpp" | ||
#include "ngraph_functions/builders.hpp" | ||
|
||
namespace reference_tests { | ||
namespace ComparisonOpsRefTestDefinitions { | ||
|
||
struct RefComparisonParams { | ||
ngraph::helpers::ComparisonTypes compType; | ||
Tensor input1; | ||
Tensor input2; | ||
Tensor expected; | ||
}; | ||
|
||
struct Builder : ParamsBuilder<RefComparisonParams> { | ||
REFERENCE_TESTS_ADD_SET_PARAM(Builder, compType); | ||
REFERENCE_TESTS_ADD_SET_PARAM(Builder, input1); | ||
REFERENCE_TESTS_ADD_SET_PARAM(Builder, input2); | ||
REFERENCE_TESTS_ADD_SET_PARAM(Builder, expected); | ||
}; | ||
|
||
class ReferenceComparisonLayerTest : public testing::TestWithParam<RefComparisonParams>, public CommonReferenceTest { | ||
public: | ||
void SetUp() override { | ||
const auto& params = GetParam(); | ||
function = CreateFunction(params.compType, params.input1.shape, params.input2.shape, params.input1.type, params.expected.type); | ||
inputData = {params.input1.data, params.input2.data}; | ||
refOutData = {params.expected.data}; | ||
} | ||
static std::string getTestCaseName(const testing::TestParamInfo<RefComparisonParams>& obj) { | ||
const auto& param = obj.param; | ||
std::ostringstream result; | ||
result << "comparisonType=" << param.compType << "_"; | ||
result << "inpt_shape1=" << param.input1.shape << "_"; | ||
result << "inpt_shape2=" << param.input2.shape << "_"; | ||
result << "iType=" << param.input1.type << "_"; | ||
result << "oType=" << param.expected.type; | ||
return result.str(); | ||
} | ||
|
||
private: | ||
static std::shared_ptr<ngraph::Function> CreateFunction(ngraph::helpers::ComparisonTypes comp_op_type, const ngraph::PartialShape& input_shape1, | ||
const ngraph::PartialShape& input_shape2, const ngraph::element::Type& input_type, | ||
const ngraph::element::Type& expected_output_type) { | ||
const auto in = std::make_shared<ngraph::op::Parameter>(input_type, input_shape1); | ||
const auto in2 = std::make_shared<ngraph::op::Parameter>(input_type, input_shape2); | ||
const auto comp = ngraph::builder::makeComparison(in, in2, comp_op_type); | ||
return std::make_shared<ngraph::Function>(ngraph::NodeVector {comp}, ngraph::ParameterVector {in, in2}); | ||
} | ||
}; | ||
} // namespace ComparisonOpsRefTestDefinitions | ||
} // namespace reference_tests |
84 changes: 84 additions & 0 deletions
84
docs/template_plugin/tests/functional/op_reference/equal.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
|
||
#include "comparison.hpp" | ||
|
||
using namespace ngraph; | ||
using namespace InferenceEngine; | ||
using ComparisonTypes = ngraph::helpers::ComparisonTypes; | ||
|
||
|
||
namespace reference_tests { | ||
namespace ComparisonOpsRefTestDefinitions { | ||
namespace { | ||
|
||
TEST_P(ReferenceComparisonLayerTest, EqualCompareWithHardcodedRefs) { | ||
Exec(); | ||
} | ||
|
||
template <element::Type_t IN_ET> | ||
std::vector<RefComparisonParams> generateComparisonParams(const element::Type& type) { | ||
using T = typename element_type_traits<IN_ET>::value_type; | ||
std::vector<RefComparisonParams> compParams { | ||
// 1D // 2D // 3D // 4D | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{2, 2}, type, std::vector<T> {0, 12, 23, 0}}) | ||
.input2({{2, 2}, type, std::vector<T> {0, 12, 23, 0}}) | ||
.expected({{2, 2}, element::boolean, std::vector<char> {1, 1, 1, 1}}), | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{2, 3}, type, std::vector<T> {0, 6, 45, 1, 21, 21}}) | ||
.input2({{2, 3}, type, std::vector<T> {1, 18, 23, 1, 19, 21}}) | ||
.expected({{2, 3}, element::boolean, std::vector<char> {0, 0, 0, 1, 0, 1}}), | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{1}, type, std::vector<T> {53}}) | ||
.input2({{1}, type, std::vector<T> {53}}) | ||
.expected({{1}, element::boolean, std::vector<char> {1}}), | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{2, 4}, type, std::vector<T> {0, 12, 23, 0, 1, 5, 11, 8}}) | ||
.input2({{2, 4}, type, std::vector<T> {0, 12, 23, 0, 10, 5, 11, 8}}) | ||
.expected({{2, 4}, element::boolean, std::vector<char> {1, 1, 1, 1, 0, 1, 1, 1}}), | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{3, 1, 2}, type, std::vector<T> {2, 1, 4, 1, 3, 1}}) | ||
.input2({{1, 2, 1}, type, std::vector<T> {1, 1}}) | ||
.expected({{3, 2, 2}, element::boolean, std::vector<char> {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}}), | ||
Builder {} | ||
.compType(ComparisonTypes::EQUAL) | ||
.input1({{2, 1, 2, 1}, type, std::vector<T> {2, 1, 4, 1}}) | ||
.input2({{1, 2, 1}, type, std::vector<T> {1, 1}}) | ||
.expected({{2, 1, 2, 1}, element::boolean, std::vector<char> {0, 1, 0, 1}})}; | ||
return compParams; | ||
} | ||
|
||
std::vector<RefComparisonParams> generateComparisonCombinedParams() { | ||
const std::vector<std::vector<RefComparisonParams>> compTypeParams { | ||
generateComparisonParams<element::Type_t::f32>(element::f32), | ||
generateComparisonParams<element::Type_t::f16>(element::f16), | ||
generateComparisonParams<element::Type_t::i32>(element::i32), | ||
generateComparisonParams<element::Type_t::u32>(element::u32), | ||
generateComparisonParams<element::Type_t::u8>(element::boolean)}; | ||
std::vector<RefComparisonParams> combinedParams; | ||
|
||
for (const auto& params : compTypeParams) { | ||
combinedParams.insert(combinedParams.end(), params.begin(), params.end()); | ||
} | ||
return combinedParams; | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(smoke_Comparison_With_Hardcoded_Refs, ReferenceComparisonLayerTest, ::testing::ValuesIn(generateComparisonCombinedParams()), | ||
ReferenceComparisonLayerTest::getTestCaseName); | ||
} // namespace | ||
} // namespace ComparisonOpsRefTestDefinitions | ||
} // namespace reference_tests |
82 changes: 82 additions & 0 deletions
82
docs/template_plugin/tests/functional/op_reference/less.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
|
||
#include "comparison.hpp" | ||
|
||
using namespace ngraph; | ||
using namespace InferenceEngine; | ||
using ComparisonTypes = ngraph::helpers::ComparisonTypes; | ||
|
||
namespace reference_tests { | ||
namespace ComparisonOpsRefTestDefinitions { | ||
namespace { | ||
TEST_P(ReferenceComparisonLayerTest, LessCompareWithHardcodedRefs) { | ||
Exec(); | ||
} | ||
|
||
template <element::Type_t IN_ET> | ||
std::vector<RefComparisonParams> generateComparisonParams(const element::Type& type) { | ||
using T = typename element_type_traits<IN_ET>::value_type; | ||
std::vector<RefComparisonParams> compParams { | ||
// 1D // 2D // 3D // 4D | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{2, 2}, type, std::vector<T> {0, 12, 23, 0}}) | ||
.input2({{2, 2}, type, std::vector<T> {0, 12, 23, 0}}) | ||
.expected({{2, 2}, element::boolean, std::vector<char> {0, 0, 0, 0}}), | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{2, 3}, type, std::vector<T> {0, 6, 45, 1, 21, 21}}) | ||
.input2({{2, 3}, type, std::vector<T> {1, 18, 23, 1, 19, 21}}) | ||
.expected({{2, 3}, element::boolean, std::vector<char> {1, 1, 0, 0, 0, 0}}), | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{1}, type, std::vector<T> {53}}) | ||
.input2({{1}, type, std::vector<T> {53}}) | ||
.expected({{1}, element::boolean, std::vector<char> {0}}), | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{2, 4}, type, std::vector<T> {0, 12, 23, 0, 1, 5, 11, 8}}) | ||
.input2({{2, 4}, type, std::vector<T> {0, 12, 23, 0, 10, 5, 11, 8}}) | ||
.expected({{2, 4}, element::boolean, std::vector<char> {0, 0, 0, 0, 1, 0, 0, 0}}), | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{3, 1, 2}, type, std::vector<T> {2, 1, 4, 1, 3, 1}}) | ||
.input2({{1, 2, 1}, type, std::vector<T> {1, 1}}) | ||
.expected({{3, 2, 2}, element::boolean, std::vector<char> {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}), | ||
Builder {} | ||
.compType(ComparisonTypes::LESS) | ||
.input1({{2, 1, 2, 1}, type, std::vector<T> {2, 1, 4, 1}}) | ||
.input2({{1, 2, 1}, type, std::vector<T> {1, 1}}) | ||
.expected({{2, 1, 2, 1}, element::boolean, std::vector<char> {0, 0, 0, 0}})}; | ||
return compParams; | ||
} | ||
|
||
std::vector<RefComparisonParams> generateComparisonCombinedParams() { | ||
const std::vector<std::vector<RefComparisonParams>> compTypeParams { | ||
generateComparisonParams<element::Type_t::f32>(element::f32), | ||
generateComparisonParams<element::Type_t::f16>(element::f16), | ||
generateComparisonParams<element::Type_t::i32>(element::i32), | ||
generateComparisonParams<element::Type_t::u32>(element::u32), | ||
generateComparisonParams<element::Type_t::u8>(element::boolean)}; | ||
std::vector<RefComparisonParams> combinedParams; | ||
|
||
for (const auto& params : compTypeParams) { | ||
combinedParams.insert(combinedParams.end(), params.begin(), params.end()); | ||
} | ||
return combinedParams; | ||
} | ||
|
||
} // namespace | ||
INSTANTIATE_TEST_SUITE_P(smoke_Comparison_With_Hardcoded_Refs, ReferenceComparisonLayerTest, ::testing::ValuesIn(generateComparisonCombinedParams()), | ||
ReferenceComparisonLayerTest::getTestCaseName); | ||
} // namespace ComparisonOpsRefTestDefinitions | ||
} // namespace reference_tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ | |
'HSwish-4', | ||
'HardSigmoid-1', | ||
'Interpolate-4', | ||
'Less-1', | ||
'LRN-1', | ||
'LSTMCell-4', | ||
'LSTMSequence-5', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters