Skip to content

Commit

Permalink
[LPT] Concat with StridedSlice funcional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Feb 4, 2021
1 parent fff29a6 commit 15cf049
Show file tree
Hide file tree
Showing 3 changed files with 531 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "layer_transformation.hpp"

#include <string>
#include <sstream>
#include <memory>

#include <gtest/gtest.h>

#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <low_precision/transformer.hpp>
#include <low_precision/concat.hpp>
#include <low_precision/concat_multi_channels.hpp>
#include <low_precision/max_pool.hpp>
#include <low_precision/strided_slice.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"
#include "lpt_ngraph_functions/concat_function.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "simple_low_precision_transformer.hpp"

using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;

namespace {

class ConcatTransformationActualValues {
public:
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
};

inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) {
return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
}

class ConcatTransformationResultValues {
public:
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
ngraph::element::Type precisionBeforeConcat;
ngraph::element::Type precisionAfterConcat;
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1;
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2;
};

inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) {
return out << "_" <<
values.fakeQuantize1 << "_" <<
values.fakeQuantize2 << "_" <<
values.dequantizationAfter1 << "_" <<
values.dequantizationAfter2;
}

class ConcatTransformationTestValues {
public:
ngraph::pass::low_precision::LayerTransformation::Params params;
bool multiChannels;
bool ssBeforeConcat;
bool ssAfterConcat;
ConcatTransformationActualValues actual;
ConcatTransformationResultValues result;
};

inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result;
}

typedef std::tuple <
ngraph::element::Type,
ngraph::Shape,
ConcatTransformationTestValues
> ConcatTransformationParams;

class ConcatWithStridedSliceTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
public:
void SetUp() override {
const ngraph::element::Type precision = std::get<0>(GetParam());
const ngraph::Shape shape = std::get<1>(GetParam());
ConcatTransformationTestValues testValues = std::get<2>(GetParam());

actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithStridedSlice(
precision,
shape,
testValues.actual.fakeQuantize1,
testValues.actual.fakeQuantize2,
testValues.ssBeforeConcat,
testValues.ssAfterConcat);

SimpleLowPrecisionTransformer transform;
if (testValues.multiChannels) {
transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
} else {
transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
}
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
transform.add<ngraph::pass::low_precision::StridedSliceTransformation, ngraph::opset1::StridedSlice>(testValues.params);
transform.transform(actualFunction);

referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithStridedSlice(
precision,
shape,
testValues.result.fakeQuantize1,
testValues.result.fakeQuantize2,
testValues.result.dequantizationBefore,
testValues.result.precisionBeforeConcat,
testValues.result.precisionAfterConcat,
testValues.ssBeforeConcat,
testValues.ssAfterConcat,
testValues.result.dequantizationAfter1,
testValues.result.dequantizationAfter2);
}

static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
const ngraph::element::Type precision = std::get<0>(obj.param);
const ngraph::Shape shape = std::get<1>(obj.param);
const ConcatTransformationTestValues testValues = std::get<2>(obj.param);

std::ostringstream result;
result <<
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
(testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") <<
(testValues.ssBeforeConcat ? "SS_before_concat_" : "") <<
(testValues.ssAfterConcat ? "SS_after_cancat_" : "") <<
testValues.actual << "_" <<
testValues.result << "_";
return result.str();
}
};

TEST_P(ConcatWithStridedSliceTransformation, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
auto res = compare_functions(referenceFunction, actualFunction, true);
ASSERT_TRUE(res.first) << res.second;
}

const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
// ngraph::element::f16
};

const std::vector<ConcatTransformationTestValues> testValues = {
// FQ with the same values, ss before concat, ss after concat
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { 0.01f }},
{ngraph::element::f32, {}, { 0.01f }}
}
},
// FQ with different values, ss before concat, ss after concat
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with different values, ss after concat
{
LayerTransformation::createParamsU8I8(),
true,
false,
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with different values, ss before concat
{
LayerTransformation::createParamsU8I8(),
true,
true,
false,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }},
{ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}
}
},
// FQ with zero-point, ss before concat, ss after concat
{
LayerTransformation::createParamsU8I8(),
true,
true,
true,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} }
},
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f} },
{ngraph::element::f32, {}, { 0.01f }},
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f} }},
{ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f, 0.005f, 0.005f} }}
}
},
// not multi channels concat, ss before concat, ss after concat
{
LayerTransformation::createParamsU8I8(),
false,
true,
true,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} }
},
{
{ 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f} },
{ 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f} },
{ngraph::element::f32, { 85 }, { 0.015f } },
ngraph::element::u8,
ngraph::element::u8,
{ngraph::element::f32, { 85 }, { 0.015f } },
{ngraph::element::f32, { 85 }, { 0.015f } }
}
},
};

const std::vector<ngraph::Shape> shapes = {
{ 1, 4, 9, 9 },
{ 4, 4, 9, 9 }
};

INSTANTIATE_TEST_CASE_P(
smoke_LPT,
ConcatWithStridedSliceTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
ConcatWithStridedSliceTransformation::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class ConcatFunction {
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2);

static std::shared_ptr<ngraph::Function> getOriginalWithStridedSlice(
const ngraph::element::Type precision,
const ngraph::Shape inputShape,
const FakeQuantizeOnData& fq1,
const FakeQuantizeOnData& fq2,
const bool ssBeforeConcat,
const bool ssAfterConcat);

static std::shared_ptr<ngraph::Function> getOriginalWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
Expand Down Expand Up @@ -151,6 +159,19 @@ class ConcatFunction {
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationOperations2);

static std::shared_ptr<ngraph::Function> getReferenceWithStridedSlice(
const ngraph::element::Type inputPrecision,
const ngraph::Shape inputShape,
const FakeQuantizeOnData& fq1,
const FakeQuantizeOnData& fq2,
const DequantizationOperations& deqBefore,
const ngraph::element::Type precisionBeforeConcat,
const ngraph::element::Type precisionAfterConcat,
const bool ssBeforeConcat,
const bool ssAfterConcat,
const DequantizationOperations& deqAfter1,
const DequantizationOperations& deqAfter2);

static std::shared_ptr<ngraph::Function> getReferenceWithDifferentPrecisionOnChilds(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
Expand Down
Loading

0 comments on commit 15cf049

Please sign in to comment.