From 9f5267baea3a89d97e94f13facde475fe2f27182 Mon Sep 17 00:00:00 2001 From: Vladimir Zinoviev Date: Fri, 28 May 2021 14:27:45 +0300 Subject: [PATCH] Tests improvement (#5704) * [LPT] Test: concat with convolution neighbor and convolution after * [LPT] elementwise fuse to FakeQuantize * [LPT] plugin test build fix Co-authored-by: Edward Shogulin --- .../concat_with_neighbors_transformation.cpp | 84 +++++++-- ...ltiply_to_fake_quantize_transformation.cpp | 30 ++- ...btract_to_fake_quantize_transformation.cpp | 42 ++++- ...at_with_neighbors_graph_transformation.cpp | 4 +- .../common/dequantization_operations.hpp | 4 +- .../lpt_ngraph_functions/concat_function.hpp | 8 +- .../src/concat_function.cpp | 174 ++++++++++++++---- 7 files changed, 280 insertions(+), 66 deletions(-) diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_neighbors_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_neighbors_transformation.cpp index 3008272bbfaea0..c099b523b39140 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_neighbors_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/concat_with_neighbors_transformation.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include "common_test_utils/ngraph_test_utils.hpp" #include "lpt_ngraph_functions/concat_function.hpp" @@ -65,6 +66,8 @@ class ConcatTransformationTestValues { bool multiChannels; ConcatTransformationActualValues actual; ConcatTransformationResultValues result; + std::string neighborType; + std::string additionalLayer; }; inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) { @@ -89,15 +92,25 @@ class ConcatWithNeighborsTransformation : public LayerTransformation, public tes shape, testValues.actual.fakeQuantize1, testValues.actual.fakeQuantize2, - testValues.actual.fakeQuantize3); + testValues.actual.fakeQuantize3, + testValues.neighborType, + testValues.additionalLayer); - SimpleLowPrecisionTransformer transform; + SimpleLowPrecisionTransformer transformBranchSpecific; if (testValues.multiChannels) { - transform.add(testValues.params); + transformBranchSpecific.add(testValues.params); } else { - transform.add(testValues.params); + transformBranchSpecific.add(testValues.params); + } + if (testValues.additionalLayer == "convolution" || testValues.neighborType == "convolution") { + transformBranchSpecific.add(testValues.params); + } + transformBranchSpecific.transform(actualFunction); + if (testValues.additionalLayer == "convolution" || testValues.neighborType == "convolution") { + SimpleLowPrecisionTransformer transformConvolution; + transformConvolution.add(testValues.params); + transformConvolution.transform(actualFunction); } - transform.transform(actualFunction); referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithNeighbors( precision, @@ -109,7 +122,9 @@ class ConcatWithNeighborsTransformation : public LayerTransformation, public tes testValues.result.dequantizationBefore, testValues.result.precisionAfterOp, testValues.result.dequantizationAfter1, - testValues.result.dequantizationAfter2); + testValues.result.dequantizationAfter2, + testValues.neighborType, + testValues.additionalLayer); } static std::string getTestCaseName(testing::TestParamInfo obj) { @@ -157,7 +172,9 @@ const std::vector testValues = { ngraph::element::u8, { ngraph::element::f32, {}, { 0.01f } }, { ngraph::element::f32, {}, { 0.01f } } - } + }, + "concat", + "" }, // U8: concat multi channels { @@ -177,7 +194,9 @@ const std::vector testValues = { ngraph::element::u8, { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }, { ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 0.00333f, 0.00333f, 0.00333f }} } - } + }, + "concat", + "" }, // U8: concat multi channels with subtract { @@ -197,7 +216,9 @@ const std::vector testValues = { ngraph::element::u8, { ngraph::element::f32, {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }, { ngraph::element::f32, { -255.f }, { 0.005f } } - } + }, + "concat", + "" }, // I8: concat { @@ -217,7 +238,9 @@ const std::vector testValues = { ngraph::element::i8, { ngraph::element::f32, {}, { 0.01f } }, { ngraph::element::f32, {}, { 0.01f } } - } + }, + "concat", + "" }, // I8: concat multi channels { @@ -237,7 +260,9 @@ const std::vector testValues = { ngraph::element::i8, { ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} }, { ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 0.00333f, 0.00333f, 0.00333f }} } - } + }, + "concat", + "" }, // mixed: U8 + I8: concat multi channels { @@ -257,7 +282,9 @@ const std::vector testValues = { ngraph::element::u8, { ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }, { ngraph::element::f32, { 128.f }, { 0.01f } } - } + }, + "concat", + "" }, // not update precisions { @@ -277,7 +304,38 @@ const std::vector testValues = { ngraph::element::f32, { {}, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } }, { {}, { 128.f }, { 0.01f } } - } + }, + "concat", + "" + }, + // convolution neighbor and additional layer + // different precisions on FQ, u8 have to be chosen + { + LayerTransformation::createParamsU8I8(), + true, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, + { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-12.8f}, {12.7f} }, + {} + }, + { + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {128.f}, {154.f} }, + { 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} }, + {}, + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + { + {}, + {{ 128.f, 128.f, 128.f, 128.f, 128.f, 128.f }, ngraph::element::f32, { 1, 6, 1, 1 }, false}, + {{0.1f}, ngraph::element::f32, { 1, 1, 1, 1 } } }, + { + {}, + {{128.f, 128.f, 128.f}, ngraph::element::f32, { 1, 3, 1, 1 }, false}, + {{0.1f}, ngraph::element::f32, { 1, 1, 1, 1 } } } + }, + "convolution", + "convolution" }, }; diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_multiply_to_fake_quantize_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_multiply_to_fake_quantize_transformation.cpp index 6d2873007ff465..bf0ce1b4484b84 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_multiply_to_fake_quantize_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_multiply_to_fake_quantize_transformation.cpp @@ -44,11 +44,20 @@ class FuseMultiplyToFakeQuantizeTransformationTestValues { Expected expected; }; +typedef std::tuple FuseMultiplyToFakeQuantizeTransformationTestParams; + class FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation, - public testing::WithParamInterface { + public testing::WithParamInterface { public: void SetUp() override { - const FuseMultiplyToFakeQuantizeTransformationTestValues testValues = GetParam(); + const size_t quantizationLevel = std::get<0>(GetParam()); + FuseMultiplyToFakeQuantizeTransformationTestValues testValues = std::get<1>(GetParam()); + if (!testValues.actual.fakeQuantizeOnData.empty()) { + testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData.empty()) { + testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } actualFunction = ngraph::builder::subgraph::FuseMultiplyToFakeQuantizeFunction::get( testValues.inputShape, @@ -65,8 +74,15 @@ class FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation, testValues.expected.dequantization); } - static std::string getTestCaseName(testing::TestParamInfo obj) { - const FuseMultiplyToFakeQuantizeTransformationTestValues testValues = obj.param; + static std::string getTestCaseName(testing::TestParamInfo obj) { + const size_t quantizationLevel = std::get<0>(obj.param); + FuseMultiplyToFakeQuantizeTransformationTestValues testValues = std::get<1>(obj.param); + if (!testValues.actual.fakeQuantizeOnData.empty()) { + testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData.empty()) { + testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } std::ostringstream result; result << testValues.params.updatePrecisions << "_" << @@ -83,6 +99,8 @@ TEST_P(FuseMultiplyToFakeQuantizeTransformation, CompareFunctions) { ASSERT_TRUE(res.first) << res.second; } +std::vector quantizationLevels = { 256ul, 128ul }; + const std::vector testValues = { { Shape{1, 3, 16, 16}, @@ -125,7 +143,9 @@ const std::vector testValues INSTANTIATE_TEST_CASE_P( smoke_LPT, FuseMultiplyToFakeQuantizeTransformation, - ::testing::ValuesIn(testValues), + ::testing::Combine( + ::testing::ValuesIn(quantizationLevels), + ::testing::ValuesIn(testValues)), FuseMultiplyToFakeQuantizeTransformation::getTestCaseName); } // namespace diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_subtract_to_fake_quantize_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_subtract_to_fake_quantize_transformation.cpp index df725ae5c2379e..94016eb98f8705 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_subtract_to_fake_quantize_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/fuse_subtract_to_fake_quantize_transformation.cpp @@ -49,11 +49,26 @@ class FuseSubtractToFakeQuantizeTransformationTestValues { Expected expected; }; +typedef std::tuple FuseSubtractToFakeQuantizeTransformationTestParams; + class FuseSubtractToFakeQuantizeTransformation : public LayerTransformation, - public testing::WithParamInterface { + public testing::WithParamInterface { public: void SetUp() override { - const FuseSubtractToFakeQuantizeTransformationTestValues testValues = GetParam(); + const size_t quantizationLevel = get<0>(GetParam()); + FuseSubtractToFakeQuantizeTransformationTestValues testValues = get<1>(GetParam()); + if (!testValues.actual.fakeQuantizeOnData.empty()) { + testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.actual.fakeQuantizeOnData2.empty()) { + testValues.actual.fakeQuantizeOnData2.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData.empty()) { + testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData2.empty()) { + testValues.expected.fakeQuantizeOnData2.quantizationLevel = quantizationLevel; + } actualFunction = testValues.actual.fakeQuantizeOnData2.empty() ? ngraph::builder::subgraph::FuseSubtractToFakeQuantizeFunction::get( @@ -84,8 +99,21 @@ class FuseSubtractToFakeQuantizeTransformation : public LayerTransformation, testValues.expected.dequantization2); } - static std::string getTestCaseName(testing::TestParamInfo obj) { - const FuseSubtractToFakeQuantizeTransformationTestValues testValues = obj.param; + static std::string getTestCaseName(testing::TestParamInfo obj) { + const size_t quantizationLevel = get<0>(obj.param); + FuseSubtractToFakeQuantizeTransformationTestValues testValues = get<1>(obj.param); + if (!testValues.actual.fakeQuantizeOnData.empty()) { + testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.actual.fakeQuantizeOnData2.empty()) { + testValues.actual.fakeQuantizeOnData2.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData.empty()) { + testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel; + } + if (!testValues.expected.fakeQuantizeOnData2.empty()) { + testValues.expected.fakeQuantizeOnData2.quantizationLevel = quantizationLevel; + } std::ostringstream result; result << testValues.params.updatePrecisions << "_" << @@ -104,6 +132,8 @@ TEST_P(FuseSubtractToFakeQuantizeTransformation, CompareFunctions) { ASSERT_TRUE(res.first) << res.second; } +std::vector quantizationLevels = { 256ul, 128ul }; + const std::vector testValues = { { Shape{1, 3, 16, 16}, @@ -190,7 +220,9 @@ const std::vector testValues INSTANTIATE_TEST_CASE_P( smoke_LPT, FuseSubtractToFakeQuantizeTransformation, - ::testing::ValuesIn(testValues), + ::testing::Combine( + ::testing::ValuesIn(quantizationLevels), + ::testing::ValuesIn(testValues)), FuseSubtractToFakeQuantizeTransformation::getTestCaseName); } // namespace diff --git a/inference-engine/tests/functional/plugin/shared/src/low_precision_transformations/concat_with_neighbors_graph_transformation.cpp b/inference-engine/tests/functional/plugin/shared/src/low_precision_transformations/concat_with_neighbors_graph_transformation.cpp index d5d0d21a6db910..912982c2ea6847 100644 --- a/inference-engine/tests/functional/plugin/shared/src/low_precision_transformations/concat_with_neighbors_graph_transformation.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/low_precision_transformations/concat_with_neighbors_graph_transformation.cpp @@ -52,7 +52,9 @@ void ConcatWithNeighborsGraphTransformation::SetUp() { inputShape, { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} }, { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 2.f} }, - { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 3.f} }); + { 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 3.f} }, + "concat", + ""); validate(); } diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp index db9702ab631022..79ef40d7e3da5f 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/common/dequantization_operations.hpp @@ -50,7 +50,9 @@ class DequantizationOperations { bool operator==(const Subtract& value) const noexcept { return equal(value); } - + void erase() { + isEmpty = true; + } Subtract& setConstantPrecision(const ngraph::element::Type& precision); std::vector values; diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp index 4d0c7c249e7e01..c0c1686ca5521c 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/concat_function.hpp @@ -40,7 +40,9 @@ class ConcatFunction { const ngraph::Shape& inputShape, const FakeQuantizeOnData& fqOnData1, const FakeQuantizeOnData& fqOnData2, - const FakeQuantizeOnData& fqOnData3); + const FakeQuantizeOnData& fqOnData3, + const std::string& neighborType, + const std::string& additionalLayer); static std::shared_ptr getOriginalWithIntermediate( const ngraph::element::Type precision, @@ -128,7 +130,9 @@ class ConcatFunction { const DequantizationOperations& dequantizationBefore, const ngraph::element::Type precisionAfterOperation, const DequantizationOperations& dequantizationOperations1, - const DequantizationOperations& dequantizationOperations2); + const DequantizationOperations& dequantizationOperations2, + const std::string& neighborType, + const std::string& additionalLayer); static std::shared_ptr getReferenceWithIntermediate( const ngraph::element::Type precision, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp index 64357d96aeb03e..15108abb73e3c8 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/concat_function.cpp @@ -115,7 +115,9 @@ std::shared_ptr ConcatFunction::getOriginalWithNeighbors( const ngraph::Shape& inputShape, const FakeQuantizeOnData& fqOnData1, const FakeQuantizeOnData& fqOnData2, - const FakeQuantizeOnData& fqOnData3) { + const FakeQuantizeOnData& fqOnData3, + const std::string& neighborType, + const std::string& additionalLayer) { const auto input1 = std::make_shared(precision, ngraph::Shape(inputShape)); input1->set_friendly_name("input1"); const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1); @@ -126,11 +128,6 @@ std::shared_ptr ConcatFunction::getOriginalWithNeighbors( const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2); fakeQuantize2->set_friendly_name("fakeQuantize2"); - const auto input3 = std::make_shared(precision, ngraph::Shape(inputShape)); - input3->set_friendly_name("input3"); - const auto fakeQuantize3 = makeFakeQuantize(input3, precision, fqOnData3); - fakeQuantize3->set_friendly_name("fakeQuantize3"); - const auto concat1 = std::make_shared( ngraph::OutputVector { fakeQuantize1->output(0), fakeQuantize2->output(0) }, 1ull); @@ -139,22 +136,62 @@ std::shared_ptr ConcatFunction::getOriginalWithNeighbors( auto& rtInfo1 = concat1->get_rt_info(); rtInfo1["Variant::std::string"] = std::make_shared>("concat1"); - const auto concat2 = std::make_shared( - ngraph::OutputVector { fakeQuantize2->output(0), fakeQuantize3->output(0) }, - 1ull); - concat2->set_friendly_name("concat2"); - - auto& rtInfo2 = concat2->get_rt_info(); - rtInfo2["Variant::std::string"] = std::make_shared>("concat2"); + ngraph::ParameterVector inputs{input1, input2}; + + ngraph::ResultVector results { }; + if (additionalLayer == "convolution") { + auto convShape = inputShape; + convShape[1] += convShape[1]; + convShape[0] = convShape[1] * 2; + convShape[2] = convShape[3] = 1; + auto convolutionAddition = std::make_shared( + concat1, + std::make_shared( + std::make_shared(opset1::Constant::create(element::i8, convShape, {1}), element::f32), + opset1::Constant::create(element::f32, Shape{}, {1})), + ngraph::Strides{ 1, 1 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::Strides{ 1, 1 }); + convolutionAddition->set_friendly_name("convolution_addition"); + results.push_back(std::make_shared(convolutionAddition)); + } + if (neighborType == "concat") { + const auto input3 = std::make_shared(precision, ngraph::Shape(inputShape)); + input3->set_friendly_name("input3"); + const auto fakeQuantize3 = makeFakeQuantize(input3, precision, fqOnData3); + fakeQuantize3->set_friendly_name("fakeQuantize3"); + inputs.push_back(input3); + + const auto concat2 = std::make_shared( + ngraph::OutputVector { fakeQuantize2->output(0), fakeQuantize3->output(0) }, + 1ull); + concat2->set_friendly_name("concat2"); + auto& rtInfo2 = concat2->get_rt_info(); + rtInfo2["Variant::std::string"] = std::make_shared>("concat2"); + results.push_back(std::make_shared(concat1)); + results.push_back(std::make_shared(concat2)); + } else if (neighborType == "convolution") { + auto convShape = inputShape; + convShape[0] = convShape[1] * 2; + convShape[2] = convShape[3] = 1; + auto convolutionNeighbor = std::make_shared( + fakeQuantize2, + std::make_shared( + std::make_shared(opset1::Constant::create(element::i8, convShape, {1}), element::f32), + opset1::Constant::create(element::f32, Shape{}, {1})), + ngraph::Strides{ 1, 1 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::Strides{ 1, 1 }); + convolutionNeighbor->set_friendly_name("convolution_neighbor"); + results.push_back(std::make_shared(convolutionNeighbor)); + } - const ngraph::ResultVector results { - std::make_shared(concat1), - std::make_shared(concat2) - }; std::shared_ptr function = std::make_shared( results, - ngraph::ParameterVector { input1, input2, input3 }, + inputs, "ConcatWithNeighborsTransformation"); return function; @@ -808,7 +845,9 @@ std::shared_ptr ConcatFunction::getReferenceWithNeighbors( const DequantizationOperations& dequantizationBefore, const ngraph::element::Type precisionAfterOperation, const DequantizationOperations& dequantizationOperations1, - const DequantizationOperations& dequantizationOperations2) { + const DequantizationOperations& dequantizationOperations2, + const std::string& neighborType, + const std::string& additionalLayer) { const auto input1 = std::make_shared(precision, inputShape); input1->set_friendly_name("input1"); @@ -825,14 +864,6 @@ std::shared_ptr ConcatFunction::getReferenceWithNeighbors( fakeQuantize2->set_friendly_name("fakeQuantize2"); const auto deqBefore2 = makeDequantization(fakeQuantize2, dequantizationBefore); - const auto input3 = std::make_shared(precision, inputShape); - input3->set_friendly_name("input3"); - - const auto fakeQuantize3 = makeFakeQuantizeTypeRelaxed(input3, precision, fqOnData3); - low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize3, precisionBeforeOp); - fakeQuantize3->set_friendly_name("fakeQuantize3"); - const auto deqBefore3 = makeDequantization(fakeQuantize3, dequantizationBefore); - const auto concat1 = std::make_shared( ngraph::OutputVector { deqBefore1, deqBefore2 }, 1ull); @@ -841,19 +872,84 @@ std::shared_ptr ConcatFunction::getReferenceWithNeighbors( auto& rtInfo1 = concat1->get_rt_info(); rtInfo1["Variant::std::string"] = std::make_shared>("concat1"); - const auto concat2 = std::make_shared( - ngraph::OutputVector { deqBefore2, deqBefore3 }, - 1ull); - concat2->set_friendly_name("concat2"); - - auto& rtInfo2 = concat2->get_rt_info(); - rtInfo2["Variant::std::string"] = std::make_shared>("concat2"); + ngraph::ParameterVector inputs{input1, input2}; + std::shared_ptr mainBranch = concat1; + std::string output_name1 = "concat1"; + auto deqCopy1 = dequantizationOperations1; + if (additionalLayer == "convolution") { + if (!deqCopy1.subtract.empty()) { + DequantizationOperations deqSubtract; + deqSubtract.subtract = deqCopy1.subtract; + mainBranch = makeDequantization(mainBranch, deqSubtract); + deqCopy1.subtract.erase(); + } + auto convShape = inputShape; + convShape[1] += convShape[1]; + convShape[0] = convShape[1] * 2; + convShape[2] = convShape[3] = 1; + auto convolutionAddition = std::make_shared>( + element::TypeVector{ element::f32, element::f32 }, + element::TypeVector{ element::f32 }, + op::TemporaryReplaceOutputType(mainBranch, element::f32).get(), + op::TemporaryReplaceOutputType(opset1::Constant::create(element::i8, convShape, {1}), element::f32).get(), + ngraph::Strides{ 1, 1 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::Strides{ 1, 1 }); + convolutionAddition->set_friendly_name("convolution_addition"); + output_name1 = "convolution_addition"; + mainBranch = convolutionAddition; + } + std::shared_ptr neighbor = fakeQuantize2; + auto deqCopy2 = dequantizationOperations2; + std::string output_name2 = "concat2"; + if (neighborType == "concat") { + const auto input3 = std::make_shared(precision, inputShape); + input3->set_friendly_name("input3"); + inputs.push_back(input3); + + const auto fakeQuantize3 = makeFakeQuantizeTypeRelaxed(input3, precision, fqOnData3); + low_precision::NetworkHelper::setOutDataPrecisionForTypeRelaxed(fakeQuantize3, precisionBeforeOp); + fakeQuantize3->set_friendly_name("fakeQuantize3"); + const auto deqBefore3 = makeDequantization(fakeQuantize3, dequantizationBefore); + + const auto concat2 = std::make_shared( + ngraph::OutputVector { deqBefore2, deqBefore3 }, + 1ull); + concat2->set_friendly_name("concat2"); + auto& rtInfo2 = concat2->get_rt_info(); + rtInfo2["Variant::std::string"] = std::make_shared>("concat2"); + + neighbor = concat2; + } else if (neighborType == "convolution") { + if (!deqCopy2.subtract.empty()) { + DequantizationOperations deqSubtract; + deqSubtract.subtract = deqCopy2.subtract; + neighbor = makeDequantization(neighbor, deqSubtract); + deqCopy2.subtract.erase(); + } + auto convShape = inputShape; + convShape[0] = convShape[1] * 2; + convShape[2] = convShape[3] = 1; + auto convolutionNeighbor = std::make_shared>( + element::TypeVector{ element::f32, element::f32 }, + element::TypeVector{ element::f32 }, + op::TemporaryReplaceOutputType(neighbor, element::f32).get(), + op::TemporaryReplaceOutputType(opset1::Constant::create(element::i8, convShape, {1}), element::f32).get(), + ngraph::Strides{ 1, 1 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::CoordinateDiff{ 0, 0 }, + ngraph::Strides{ 1, 1 }); + convolutionNeighbor->set_friendly_name("convolution_neighbor"); + output_name2 = "convolution_neighbor"; + neighbor = convolutionNeighbor; + } - const std::shared_ptr lastDequantization1 = makeDequantization(concat1, dequantizationOperations1); - lastDequantization1->set_friendly_name("concat1"); + const std::shared_ptr lastDequantization1 = makeDequantization(mainBranch, deqCopy1); + lastDequantization1->set_friendly_name(output_name1); - const std::shared_ptr lastDequantization2 = makeDequantization(concat2, dequantizationOperations2); - lastDequantization2->set_friendly_name("concat2"); + const std::shared_ptr lastDequantization2 = makeDequantization(neighbor, deqCopy2); + lastDequantization2->set_friendly_name(output_name2); const ngraph::ResultVector results { std::make_shared(lastDequantization1), @@ -862,7 +958,7 @@ std::shared_ptr ConcatFunction::getReferenceWithNeighbors( std::shared_ptr function = std::make_shared( results, - ngraph::ParameterVector { input1, input2, input3 }, + inputs, "ConcatWithNeighborsTransformation"); return function;