forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revise erf OP (openvinotoolkit#6477)
* Revise Erf OP sepc Signed-off-by: Luwei Zhou <[email protected]> * Revise the NGraph Erf OP implment to switch to RTTI. Signed-off-by: Luwei Zhou <[email protected]> * Remove the duplicated Erf in the activation type vector. Signed-off-by: Luwei Zhou <[email protected]> * Add NGraph visitor API test case. Signed-off-by: Luwei Zhou <[email protected]> * Enalbe the Erf visitor API CmakeLists.txt. Signed-off-by: Luwei Zhou <[email protected]> * Revise the Erf OP backend test Signed-off-by: Luwei Zhou <[email protected]> * Migrate to use the template test. * Add erf type_prop test. * Update the license * Unary Visitor test template fix -Migrate OP Tanh to use RTTI; -Remove the using namespace in the header file -Migrate the Swish and Tanh visitor test to use template code Signed-off-by: Luwei Zhou <[email protected]> * Revert "Unary Visitor test template fix" This reverts commit b686c93. * Update the doc format. * Update the document format and description. Signed-off-by: Luwei Zhou <[email protected]> * Add Erf OP into the layer test summary list * Migrate the Erf backend test into template_plugin infrastructure * Update the Erf supported input type. * Remove the boolean type support in erf reference implement. validate_and_infer_elementwise_arithmetic() will fail with boolean type. * Update the erf test with all supported types. * Update with separate namespace of CommonReferenceTest
- Loading branch information
1 parent
a5a0f30
commit f874e3f
Showing
11 changed files
with
133 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
docs/template_plugin/tests/functional/op_reference/erf.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <ie_core.hpp> | ||
#include <ie_ngraph_utils.hpp> | ||
#include <limits> | ||
#include <ngraph/ngraph.hpp> | ||
#include <shared_test_classes/base/layer_test_utils.hpp> | ||
#include <tuple> | ||
|
||
#include "base_reference_test.hpp" | ||
|
||
using namespace reference_tests; | ||
using namespace ngraph; | ||
using namespace InferenceEngine; | ||
|
||
struct ErfParams { | ||
template <class IT> | ||
ErfParams(const ngraph::PartialShape& shape, const ngraph::element::Type& iType, const std::vector<IT>& iValues) | ||
: pshape(shape), inType(iType), outType(iType), inputData(CreateBlob(iType, iValues)) { | ||
std::vector<IT> oValues; | ||
std::vector<double> output; | ||
for (auto element : iValues) | ||
output.push_back(static_cast<double>(element)); | ||
|
||
std::transform(output.begin(), output.end(), output.begin(), [](double input) -> double { | ||
return std::erf(input); | ||
}); | ||
|
||
if (std::is_integral<IT>()) { | ||
std::transform(output.begin(), output.end(), output.begin(), [](double input) -> double { | ||
return std::round(input); | ||
}); | ||
} | ||
|
||
for (auto element : output) | ||
oValues.push_back(static_cast<IT>(element)); | ||
refData = CreateBlob(outType, oValues); | ||
} | ||
ngraph::PartialShape pshape; | ||
ngraph::element::Type inType; | ||
ngraph::element::Type outType; | ||
InferenceEngine::Blob::Ptr inputData; | ||
InferenceEngine::Blob::Ptr refData; | ||
}; | ||
|
||
class ReferenceErfLayerTest : public testing::TestWithParam<ErfParams>, public CommonReferenceTest { | ||
public: | ||
void SetUp() override { | ||
auto params = GetParam(); | ||
function = CreateFunction(params.pshape, params.inType, params.outType); | ||
inputData = {params.inputData}; | ||
refOutData = {params.refData}; | ||
} | ||
static std::string getTestCaseName(const testing::TestParamInfo<ErfParams>& obj) { | ||
auto param = obj.param; | ||
std::ostringstream result; | ||
result << "shape=" << param.pshape << "_"; | ||
result << "iType=" << param.inType << "_"; | ||
result << "oType=" << param.outType; | ||
return result.str(); | ||
} | ||
|
||
private: | ||
static std::shared_ptr<Function> CreateFunction(const PartialShape& input_shape, const element::Type& input_type, | ||
const element::Type& expected_output_type) { | ||
const auto in = std::make_shared<op::Parameter>(input_type, input_shape); | ||
const auto erf = std::make_shared<op::Erf>(in); | ||
return std::make_shared<Function>(NodeVector {erf}, ParameterVector {in}); | ||
} | ||
}; | ||
|
||
TEST_P(ReferenceErfLayerTest, CompareWithRefs) { | ||
Exec(); | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P( | ||
smoke_Erf_With_Hardcoded_Refs, ReferenceErfLayerTest, | ||
::testing::Values(ErfParams(ngraph::PartialShape {2, 5}, ngraph::element::f32, | ||
std::vector<float> {-INFINITY, -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, INFINITY}), | ||
ErfParams(ngraph::PartialShape {2, 5}, ngraph::element::f16, | ||
std::vector<float16> {-INFINITY, -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, INFINITY}), | ||
ErfParams(ngraph::PartialShape {2, 3}, ngraph::element::i32, | ||
std::vector<int32_t> {std::numeric_limits<int32_t>::min(), -2, -1, 1, 2, std::numeric_limits<int32_t>::max()}), | ||
ErfParams(ngraph::PartialShape {2, 3}, ngraph::element::u32, | ||
std::vector<uint32_t> {std::numeric_limits<uint32_t>::min(), 0, 1, 2, 3, std::numeric_limits<uint32_t>::max()}), | ||
ErfParams(ngraph::PartialShape {2, 3}, ngraph::element::i64, | ||
std::vector<int64_t> {std::numeric_limits<int64_t>::min(), -2, -1, 1, 2, std::numeric_limits<int64_t>::max()}), | ||
ErfParams(ngraph::PartialShape {2, 3}, ngraph::element::u64, | ||
std::vector<uint64_t> {std::numeric_limits<uint64_t>::min(), 0, 1, 2, 3, std::numeric_limits<uint64_t>::max()})), | ||
ReferenceErfLayerTest::getTestCaseName); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "unary_ops.hpp" | ||
|
||
using Type = ::testing::Types<ngraph::op::Erf>; | ||
|
||
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_erf, UnaryOperator, Type); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "unary_ops.hpp" | ||
|
||
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Erf, element::f32>>; | ||
|
||
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_atrribute, | ||
UnaryOperatorVisitor, | ||
Type, | ||
UnaryOperatorTypeName); |