diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_strided_slice_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_strided_slice_transformation.cpp new file mode 100644 index 00000000000000..fbe05b43133357 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_strided_slice_transformation.cpp @@ -0,0 +1,284 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "layer_transformation.hpp" + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "lpt_ngraph_functions/concat_function.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" +#include "simple_low_precision_transformer.hpp" + +using namespace testing; +using namespace ngraph; +using namespace ngraph::pass; + +namespace { + +class ConcatTransformationActualValues { +public: + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1; + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2; +}; + +inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationActualValues& values) { + return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2; +} + +class ConcatTransformationResultValues { +public: + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1; + ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; + ngraph::element::Type precisionBeforeConcat; + ngraph::element::Type precisionAfterConcat; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfter1; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfter2; +}; + +inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationResultValues& values) { + return out << "_" << + values.fakeQuantize1 << "_" << + values.fakeQuantize2 << "_" << + values.dequantizationAfter1 << "_" << + values.dequantizationAfter2; +} + +class ConcatTransformationTestValues { +public: + ngraph::pass::low_precision::LayerTransformation::Params params; + bool multiChannels; + bool ssBeforeConcat; + bool ssAfterConcat; + ConcatTransformationActualValues actual; + ConcatTransformationResultValues result; +}; + +inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) { + return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result; +} + +typedef std::tuple < + ngraph::element::Type, + ngraph::Shape, + ConcatTransformationTestValues +> ConcatTransformationParams; + +class ConcatWithStridedSliceTransformation : public LayerTransformation, public testing::WithParamInterface { +public: + void SetUp() override { + const ngraph::element::Type precision = std::get<0>(GetParam()); + const ngraph::Shape shape = std::get<1>(GetParam()); + ConcatTransformationTestValues testValues = std::get<2>(GetParam()); + + actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithStridedSlice( + precision, + shape, + testValues.actual.fakeQuantize1, + testValues.actual.fakeQuantize2, + testValues.ssBeforeConcat, + testValues.ssAfterConcat); + + SimpleLowPrecisionTransformer transform; + if (testValues.multiChannels) { + transform.add(testValues.params); + } else { + transform.add(testValues.params); + } + transform.add(testValues.params); + transform.add(testValues.params); + transform.transform(actualFunction); + + referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithStridedSlice( + precision, + shape, + testValues.result.fakeQuantize1, + testValues.result.fakeQuantize2, + testValues.result.dequantizationBefore, + testValues.result.precisionBeforeConcat, + testValues.result.precisionAfterConcat, + testValues.ssBeforeConcat, + testValues.ssAfterConcat, + testValues.result.dequantizationAfter1, + testValues.result.dequantizationAfter2); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ngraph::element::Type precision = std::get<0>(obj.param); + const ngraph::Shape shape = std::get<1>(obj.param); + const ConcatTransformationTestValues testValues = std::get<2>(obj.param); + + std::ostringstream result; + result << + LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" << + (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << + (testValues.ssBeforeConcat ? "SS_before_concat_" : "") << + (testValues.ssAfterConcat ? "SS_after_cancat_" : "") << + testValues.actual << "_" << + testValues.result << "_"; + return result.str(); + } +}; + +TEST_P(ConcatWithStridedSliceTransformation, CompareFunctions) { + actualFunction->validate_nodes_and_infer_types(); + auto res = compare_functions(referenceFunction, actualFunction, true); + ASSERT_TRUE(res.first) << res.second; +} + +const std::vector precisions = { + ngraph::element::f32, + // ngraph::element::f16 +}; + +const std::vector testValues = { + // FQ with the same values, ss before concat, ss after concat + { + LayerTransformation::createParamsU8I8(), + true, + true, + true, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} } + }, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, + {ngraph::element::f32, {}, { 0.01f }}, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, {}, { 0.01f }}, + {ngraph::element::f32, {}, { 0.01f }} + } + }, + // FQ with different values, ss before concat, ss after concat + { + LayerTransformation::createParamsU8I8(), + true, + true, + true, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } + }, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, + {ngraph::element::f32, {}, { 0.01f }}, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f} }}, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} + } + }, + // FQ with different values, ss after concat + { + LayerTransformation::createParamsU8I8(), + true, + false, + true, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } + }, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, + {ngraph::element::f32, {}, { 0.01f }}, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f} }}, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} + } + }, + // FQ with different values, ss before concat + { + LayerTransformation::createParamsU8I8(), + true, + true, + false, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} } + }, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} }, + { 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} }, + {ngraph::element::f32, {}, { 0.01f }}, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }}, + {ngraph::element::f32, {}, { {0.01f, 0.01f, 0.1f, 0.1f, 0.1f, 0.1f} }} + } + }, + // FQ with zero-point, ss before concat, ss after concat + { + LayerTransformation::createParamsU8I8(), + true, + true, + true, + { + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, {}, {1.275f}, {2.55f}, {1.275f}, {2.55f} } + }, + { + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} }, + { 256ul, {}, {1.275f}, {2.55f}, {0.f}, {255.f} }, + {ngraph::element::f32, {}, { 0.01f }}, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f} }}, + {ngraph::element::f32, { {0.f, 0.f, -255.f, -255.f, -255.f, -255.f} }, { {0.01f, 0.01f, 0.005f, 0.005f, 0.005f, 0.005f} }} + } + }, + // not multi channels concat, ss before concat, ss after concat + { + LayerTransformation::createParamsU8I8(), + false, + true, + true, + { + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} } + }, + { + { 256ul, {}, {0.f}, {2.55f}, {85.f}, {255.f} }, + { 256ul, {}, {-1.28f}, {1.27f}, {0.f}, {170.f} }, + {ngraph::element::f32, { 85 }, { 0.015f } }, + ngraph::element::u8, + ngraph::element::u8, + {ngraph::element::f32, { 85 }, { 0.015f } }, + {ngraph::element::f32, { 85 }, { 0.015f } } + } + }, +}; + +const std::vector shapes = { + { 1, 4, 9, 9 }, + { 4, 4, 9, 9 } +}; + +INSTANTIATE_TEST_CASE_P( + smoke_LPT, + ConcatWithStridedSliceTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(shapes), + ::testing::ValuesIn(testValues)), + ConcatWithStridedSliceTransformation::getTestCaseName); +} // namespace diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp index c6a0e8cee98a9d..df3ff860bf10ce 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp @@ -62,6 +62,14 @@ class ConcatFunction { const FakeQuantizeOnData& fqOnData1, const FakeQuantizeOnData& fqOnData2); + static std::shared_ptr getOriginalWithStridedSlice( + const ngraph::element::Type precision, + const ngraph::Shape inputShape, + const FakeQuantizeOnData& fq1, + const FakeQuantizeOnData& fq2, + const bool ssBeforeConcat, + const bool ssAfterConcat); + static std::shared_ptr getOriginalWithDifferentPrecisionOnChilds( const ngraph::element::Type precision, const ngraph::Shape& inputShape, @@ -151,6 +159,19 @@ class ConcatFunction { const DequantizationOperations& dequantizationOperations1, const DequantizationOperations& dequantizationOperations2); + static std::shared_ptr getReferenceWithStridedSlice( + const ngraph::element::Type inputPrecision, + const ngraph::Shape inputShape, + const FakeQuantizeOnData& fq1, + const FakeQuantizeOnData& fq2, + const DequantizationOperations& deqBefore, + const ngraph::element::Type precisionBeforeConcat, + const ngraph::element::Type precisionAfterConcat, + const bool ssBeforeConcat, + const bool ssAfterConcat, + const DequantizationOperations& deqAfter1, + const DequantizationOperations& deqAfter2); + static std::shared_ptr getReferenceWithDifferentPrecisionOnChilds( const ngraph::element::Type precision, const ngraph::Shape& inputShape, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp index d8c95466b56f9d..c8747b4d3da18b 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp @@ -375,6 +375,121 @@ std::shared_ptr ConcatFunction::getOriginalSelectionWithInterm return function; } +/* +(SS) - optional + + Input + / + FQ + / \ + (SS) Clamp + | | + | FQ + \ / + Concat + /\ + / \ + (SS) MaxPool +*/ + +std::shared_ptr ConcatFunction::getOriginalWithStridedSlice( + const ngraph::element::Type precision, + const ngraph::Shape inputShape, + const FakeQuantizeOnData& fq1, + const FakeQuantizeOnData& fq2, + const bool ssBeforeConcat, + const bool ssAfterConcat) { + const auto input = std::make_shared(precision, inputShape); + input->set_friendly_name("input"); + const auto fakeQuantize1 = makeFakeQuantize(input, precision, fq1); + fakeQuantize1->set_friendly_name("FakeQuantize_1"); + + std::shared_ptr parent1 = fakeQuantize1; + + if (ssBeforeConcat) { + const auto beginParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ inputShape.size() }, + std::vector(inputShape.size(), 0)); + + const auto endParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ inputShape.size() }, + std::vector{ inputShape[0], inputShape[1] - 2ul, inputShape[2], inputShape[3] }); + + const std::vector beginMask{ 1, 0, 1, 1 }; + const std::vector endMask{ 1, 0, 1, 1 }; + + parent1 = std::make_shared(parent1, beginParam, endParam, beginMask, endMask); + parent1->set_friendly_name("StridedSlice_1"); + } + + const auto clamp = std::make_shared(fakeQuantize1, 0.0, 6.0); + clamp->set_friendly_name("Clamp"); + const auto fakeQuantize2 = makeFakeQuantize(clamp, precision, fq2); + fakeQuantize2->set_friendly_name("FakeQuantize_2"); + + const auto concat = std::make_shared(NodeVector{ parent1, fakeQuantize2 }, 1); + concat->set_friendly_name("Concat"); + + + ngraph::ResultVector results; + if (ssAfterConcat) { + const auto concatShape = concat->get_output_shape(0); + const auto beginParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ concatShape.size() }, + std::vector(concatShape.size(), 0)); + + const auto endParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ concatShape.size() }, + std::vector{ concatShape[0], concatShape[1] - 2ul, concatShape[2], concatShape[3] }); + + const std::vector beginMask{ 1, 0, 1, 1 }; + const std::vector endMask{ 1, 0, 1, 1 }; + + const auto stridedSlice = std::make_shared(concat, beginParam, endParam, beginMask, endMask); + stridedSlice->set_friendly_name("StridedSlice_2"); + + const auto result1 = std::make_shared(stridedSlice); + result1->set_friendly_name("Result_1"); + results.push_back(result1); + } else { + const auto result1 = std::make_shared(concat); + result1->set_friendly_name("Result_1"); + results.push_back(result1); + } + + const std::vector kernel = { 3, 3 }; + const std::vector stride = { 1, 1 }; + const std::vector padBegin = { 0, 0 }; + const std::vector padEnd = { 0, 0 }; + const ngraph::op::PadType padType = ngraph::op::PadType::NOTSET; + const ngraph::op::RoundingType roundingType = ngraph::op::RoundingType::FLOOR; + + const auto maxPool = std::make_shared( + concat, + stride, + padBegin, + padEnd, + kernel, + roundingType, + padType); + maxPool->set_friendly_name("MaxPool"); + + const auto result2 = std::make_shared(maxPool); + result2->set_friendly_name("Result_2"); + results.push_back(result2); + + std::shared_ptr function = std::make_shared( + results, + ngraph::ParameterVector{ input }, + "ConcatWithDifferentChildsTransformation"); + + return function; +} + std::shared_ptr ConcatFunction::getOriginalWithDifferentPrecisionOnChilds( const ngraph::element::Type precision, const ngraph::Shape& inputShape, @@ -985,6 +1100,117 @@ std::shared_ptr ConcatFunction::getReferenceSelectionWithInter return function; } +std::shared_ptr ConcatFunction::getReferenceWithStridedSlice( + const ngraph::element::Type inputPrecision, + const ngraph::Shape inputShape, + const FakeQuantizeOnData& fq1, + const FakeQuantizeOnData& fq2, + const DequantizationOperations& deqBefore, + const ngraph::element::Type precisionBeforeConcat, + const ngraph::element::Type precisionAfterConcat, + const bool ssBeforeConcat, + const bool ssAfterConcat, + const DequantizationOperations& deqAfter1, + const DequantizationOperations& deqAfter2) { + const auto input = std::make_shared(inputPrecision, inputShape); + input->set_friendly_name("input1"); + + const auto fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fq1); + low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize1, precisionBeforeConcat); + fakeQuantize1->set_friendly_name("FakeQuantize_1"); + + std::shared_ptr parent1 = fakeQuantize1; + + if (ssBeforeConcat) { + const auto beginParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ inputShape.size() }, + std::vector(inputShape.size(), 0)); + + const auto endParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ inputShape.size() }, + std::vector{ inputShape[0], inputShape[1] - 2ul, inputShape[2], inputShape[3] }); + + const std::vector beginMask{ 1, 0, 1, 1 }; + const std::vector endMask{ 1, 0, 1, 1 }; + + parent1 = std::make_shared(parent1, beginParam, endParam, beginMask, endMask); + parent1->set_friendly_name("StridedSlice_1"); + } + + const auto dequantizationBefore = makeDequantization(fakeQuantize1, deqBefore); + const auto clamp = std::make_shared(dequantizationBefore, 0.0, 6.0); + clamp->set_friendly_name("Clamp"); + + const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(clamp, inputPrecision, fq2); + low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize2, precisionBeforeConcat); + fakeQuantize2->set_friendly_name("FakeQuantize_2"); + + const auto concat = std::make_shared(NodeVector{ parent1, fakeQuantize2 }, 1); + concat->set_friendly_name("Concat"); + + ngraph::ResultVector results; + if (ssAfterConcat) { + const auto concatShape = concat->get_output_shape(0); + const auto beginParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ concatShape.size() }, + std::vector(concatShape.size(), 0)); + + const auto endParam = ngraph::op::Constant::create( + ngraph::element::i64, + ngraph::Shape{ concatShape.size() }, + std::vector{ concatShape[0], concatShape[1] - 2ul, concatShape[2], concatShape[3] }); + + const std::vector beginMask{ 1, 0, 1, 1 }; + const std::vector endMask{ 1, 0, 1, 1 }; + + const auto stridedSlice = std::make_shared(concat, beginParam, endParam, beginMask, endMask); + stridedSlice->set_friendly_name("StridedSlice_2"); + + const auto dequantizationAfter1 = makeDequantization(stridedSlice, deqAfter1); + const auto result1 = std::make_shared(dequantizationAfter1); + result1->set_friendly_name("Result_1"); + results.push_back(result1); + } else { + const auto dequantizationAfter1 = makeDequantization(concat, deqAfter1); + const auto result1 = std::make_shared(dequantizationAfter1); + result1->set_friendly_name("Result_1"); + results.push_back(result1); + } + + const std::vector kernel = { 3, 3 }; + const std::vector stride = { 1, 1 }; + const std::vector padBegin = { 0, 0 }; + const std::vector padEnd = { 0, 0 }; + const ngraph::op::PadType padType = ngraph::op::PadType::NOTSET; + const ngraph::op::RoundingType roundingType = ngraph::op::RoundingType::FLOOR; + + const auto maxPool = std::make_shared( + concat, + stride, + padBegin, + padEnd, + kernel, + roundingType, + padType); + maxPool->set_friendly_name("MaxPool"); + + const auto dequantizationAfter2 = makeDequantization(maxPool, deqAfter2); + + const auto result2 = std::make_shared(dequantizationAfter2); + result2->set_friendly_name("Result_2"); + results.push_back(result2); + + std::shared_ptr function = std::make_shared( + results, + ngraph::ParameterVector{ input }, + "ConcatWithDifferentChildsTransformation"); + + return function; +} + std::shared_ptr ConcatFunction::getReferenceWithDifferentPrecisionOnChilds( const ngraph::element::Type precision, const ngraph::Shape& inputShape,