-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LPT] Concat with StridedSlice funcional tests
- Loading branch information
Showing
3 changed files
with
532 additions
and
1 deletion.
There are no files selected for viewing
284 changes: 284 additions & 0 deletions
284
...nctional/inference_engine/lp_transformations/concat_with_strided_slice_transformation.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "layer_transformation.hpp" | ||
|
||
#include <string> | ||
#include <sstream> | ||
#include <memory> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <transformations/utils/utils.hpp> | ||
#include <transformations/init_node_info.hpp> | ||
#include <low_precision/transformer.hpp> | ||
#include <low_precision/concat.hpp> | ||
#include <low_precision/concat_multi_channels.hpp> | ||
#include <low_precision/max_pool.hpp> | ||
#include <low_precision/strided_slice.hpp> | ||
|
||
#include "common_test_utils/ngraph_test_utils.hpp" | ||
#include "lpt_ngraph_functions/concat_function.hpp" | ||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" | ||
#include "simple_low_precision_transformer.hpp" | ||
|
||
using namespace testing; | ||
using namespace ngraph; | ||
using namespace ngraph::pass; | ||
|
||
namespace { | ||
|
||
class ConcatTransformationActualValues { | ||
public: | ||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1; | ||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2; | ||
}; | ||
|
||
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) { | ||
return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2; | ||
} | ||
|
||
class ConcatTransformationResultValues { | ||
public: | ||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1; | ||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2; | ||
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; | ||
ngraph::element::Type precisionBeforeConcat; | ||
ngraph::element::Type precisionAfterConcat; | ||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1; | ||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2; | ||
}; | ||
|
||
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) { | ||
return out << "_" << | ||
values.fakeQuantize1 << "_" << | ||
values.fakeQuantize2 << "_" << | ||
values.dequantizationAfter1 << "_" << | ||
values.dequantizationAfter2; | ||
} | ||
|
||
class ConcatTransformationTestValues { | ||
public: | ||
ngraph::pass::low_precision::LayerTransformation::Params params; | ||
bool multiChannels; | ||
bool ssBeforeConcat; | ||
bool ssAfterConcat; | ||
ConcatTransformationActualValues actual; | ||
ConcatTransformationResultValues result; | ||
}; | ||
|
||
inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) { | ||
return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result; | ||
} | ||
|
||
typedef std::tuple < | ||
ngraph::element::Type, | ||
ngraph::Shape, | ||
ConcatTransformationTestValues | ||
> ConcatTransformationParams; | ||
|
||
class ConcatWithStridedSliceTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> { | ||
public: | ||
void SetUp() override { | ||
const ngraph::element::Type precision = std::get<0>(GetParam()); | ||
const ngraph::Shape shape = std::get<1>(GetParam()); | ||
ConcatTransformationTestValues testValues = std::get<2>(GetParam()); | ||
|
||
actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithStridedSlice( | ||
precision, | ||
shape, | ||
testValues.actual.fakeQuantize1, | ||
testValues.actual.fakeQuantize2, | ||
testValues.ssBeforeConcat, | ||
testValues.ssAfterConcat); | ||
|
||
SimpleLowPrecisionTransformer transform; | ||
if (testValues.multiChannels) { | ||
transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params); | ||
} else { | ||
transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params); | ||
} | ||
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params); | ||
transform.add<ngraph::pass::low_precision::StridedSliceTransformation, ngraph::opset1::StridedSlice>(testValues.params); | ||
transform.transform(actualFunction); | ||
|
||
referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithStridedSlice( | ||
precision, | ||
shape, | ||
testValues.result.fakeQuantize1, | ||
testValues.result.fakeQuantize2, | ||
testValues.result.dequantizationBefore, | ||
testValues.result.precisionBeforeConcat, | ||
testValues.result.precisionAfterConcat, | ||
testValues.ssBeforeConcat, | ||
testValues.ssAfterConcat, | ||
testValues.result.dequantizationAfter1, | ||
testValues.result.dequantizationAfter2); | ||
} | ||
|
||
static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) { | ||
const ngraph::element::Type precision = std::get<0>(obj.param); | ||
const ngraph::Shape shape = std::get<1>(obj.param); | ||
const ConcatTransformationTestValues testValues = std::get<2>(obj.param); | ||
|
||
std::ostringstream result; | ||
result << | ||
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" << | ||
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << | ||
(testValues.ssBeforeConcat ? "SS_before_concat_" : "") << | ||
(testValues.ssAfterConcat ? "SS_after_cancat_" : "") << | ||
testValues.actual << "_" << | ||
testValues.result << "_"; | ||
return result.str(); | ||
} | ||
}; | ||
|
||
TEST_P(ConcatWithStridedSliceTransformation, CompareFunctions) { | ||
actualFunction->validate_nodes_and_infer_types(); | ||
auto res = compare_functions(referenceFunction, actualFunction, true); | ||
ASSERT_TRUE(res.first) << res.second; | ||
} | ||
|
||
const std::vector<ngraph::element::Type> precisions = { | ||
ngraph::element::f32, | ||
// ngraph::element::f16 | ||
}; | ||
|
||
const std::vector<ConcatTransformationTestValues> testValues = { | ||
// FQ with the same values, ss before concat, ss after concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
true, | ||
true, | ||
true, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} } | ||
}, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
{ngraph::element::f32, {}, { 0.01f }} | ||
} | ||
}, | ||
// FQ with different values, ss before concat, ss after concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
true, | ||
true, | ||
true, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } | ||
}, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f} }}, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} | ||
} | ||
}, | ||
// FQ with different values, ss after concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
true, | ||
false, | ||
true, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } | ||
}, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f} }}, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} | ||
} | ||
}, | ||
// FQ with different values, ss before concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
true, | ||
true, | ||
false, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } | ||
}, | ||
{ | ||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}, | ||
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} | ||
} | ||
}, | ||
// FQ with zero-point, ss before concat, ss after concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
true, | ||
true, | ||
true, | ||
{ | ||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} } | ||
}, | ||
{ | ||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, | ||
{ 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f} }, | ||
{ngraph::element::f32, {}, { 0.01f }}, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f} }}, | ||
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f, 0.005f, 0.005f} }} | ||
} | ||
}, | ||
// not multi channels concat, ss before concat, ss after concat | ||
{ | ||
LayerTransformation::createParamsU8I8(), | ||
false, | ||
true, | ||
true, | ||
{ | ||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, | ||
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} } | ||
}, | ||
{ | ||
{ 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f} }, | ||
{ 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f} }, | ||
{ngraph::element::f32, { 85 }, { 0.015f } }, | ||
ngraph::element::u8, | ||
ngraph::element::u8, | ||
{ngraph::element::f32, { 85 }, { 0.015f } }, | ||
{ngraph::element::f32, { 85 }, { 0.015f } } | ||
} | ||
}, | ||
}; | ||
|
||
const std::vector<ngraph::Shape> shapes = { | ||
{ 1, 4, 9, 9 }, | ||
{ 4, 4, 9, 9 } | ||
}; | ||
|
||
INSTANTIATE_TEST_CASE_P( | ||
smoke_LPT, | ||
ConcatWithStridedSliceTransformation, | ||
::testing::Combine( | ||
::testing::ValuesIn(precisions), | ||
::testing::ValuesIn(shapes), | ||
::testing::ValuesIn(testValues)), | ||
ConcatWithStridedSliceTransformation::getTestCaseName); | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.