diff --git a/inference-engine/src/low_precision_transformations/src/move_fake_quantize.cpp b/inference-engine/src/low_precision_transformations/src/move_fake_quantize.cpp index 627650e6232ee5..7f933185340089 100644 --- a/inference-engine/src/low_precision_transformations/src/move_fake_quantize.cpp +++ b/inference-engine/src/low_precision_transformations/src/move_fake_quantize.cpp @@ -35,20 +35,31 @@ MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(p bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { auto fq = m.get_match_root(); - auto relu = fq->get_input_node_shared_ptr(0); - auto concat = relu->get_input_node_shared_ptr(0); auto result = *fq->output(0).get_target_inputs().begin(); - auto input1 = concat->get_input_node_shared_ptr(0); - auto input2 = concat->get_input_node_shared_ptr(1); - auto relu1 = std::make_shared(input1->output(0)); - auto relu2 = std::make_shared(input2->output(0)); - auto fq1 = std::make_shared(relu1, + auto operation = fq->get_input_node_shared_ptr(0); + auto type = operation->get_type_name(); + std::shared_ptr concat, fq1input, fq2input; + if (strcmp(type, "Concat") == 0) { + concat = operation; + fq1input = operation->get_input_node_shared_ptr(0); + fq2input = operation->get_input_node_shared_ptr(1); + } + else { + concat = operation->get_input_node_shared_ptr(0); + auto input1 = concat->get_input_node_shared_ptr(0); + auto input2 = concat->get_input_node_shared_ptr(1); + if (strcmp(type, "Relu") == 0) { + fq1input = std::make_shared(input1->output(0)); + fq2input = std::make_shared(input2->output(0)); + } + } + auto fq1 = std::make_shared(fq1input, fq->get_input_node_shared_ptr(1), fq->get_input_node_shared_ptr(2), fq->get_input_node_shared_ptr(3), fq->get_input_node_shared_ptr(4), as_type_ptr(fq)->get_levels()); - auto fq2 = std::make_shared(relu2, + auto fq2 = std::make_shared(fq2input, fq->get_input_node_shared_ptr(1), fq->get_input_node_shared_ptr(2), fq->get_input_node_shared_ptr(3), diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/move_fake_quantize_for_concat_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/move_fake_quantize_for_concat_transformation.cpp index a1f2b15a84fe24..8badaef6416f08 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/move_fake_quantize_for_concat_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/move_fake_quantize_for_concat_transformation.cpp @@ -43,51 +43,61 @@ namespace { class MoveFakeQuantizeActualValues { public: - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1; - ngraph::builder::subgraph::DequantizationOperations::Convert convert1; - ngraph::builder::subgraph::DequantizationOperations dequantization1; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2; - ngraph::builder::subgraph::DequantizationOperations::Convert convert2; - ngraph::builder::subgraph::DequantizationOperations dequantization2; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize3; - ngraph::builder::subgraph::DequantizationOperations::Convert convert3; - ngraph::builder::subgraph::DequantizationOperations dequantization3; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; //before1 + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; //before1 + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2; + std::string operation; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; // after + ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; }; inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeActualValues& values) { return out << "_" << - values.fakeQuantize1 << "_" << - values.convert1.outPrecision << "_" << - values.dequantization1 << "_" << - values.fakeQuantize2 << "_" << - values.convert2.outPrecision << "_" << - values.dequantization2; + values.fakeQuantizeBefore1 << "_" << + values.convertBefore1.outPrecision << "_" << + values.dequantizationBefore1 << "_" << + values.fakeQuantizeBefore2 << "_" << + values.convertBefore2.outPrecision << "_" << + values.dequantizationBefore2 << "_" << + values.operation << "_" << + values.fakeQuantizeAfter << "_" << + values.convertAfter.outPrecision << "_" << + values.dequantizationAfter; } class MoveFakeQuantizeResultValues { public: - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize1; - ngraph::builder::subgraph::DequantizationOperations::Convert convert1; - ngraph::builder::subgraph::DequantizationOperations dequantization1; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize2; - ngraph::builder::subgraph::DequantizationOperations::Convert convert2; - ngraph::builder::subgraph::DequantizationOperations dequantization2; - ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize3; - ngraph::builder::subgraph::DequantizationOperations::Convert convert3; - ngraph::builder::subgraph::DequantizationOperations dequantization3; - ngraph::element::Type precisionAfterOperation; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore1; + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore1; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore1; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeBefore2; + ngraph::builder::subgraph::DequantizationOperations::Convert convertBefore2; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore2; + std::string operation; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantizeAfter; + ngraph::builder::subgraph::DequantizationOperations::Convert convertAfter; ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; + ngraph::element::Type precisionAfterOperation; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfterNotFQ; }; inline std::ostream& operator<<(std::ostream& out, const MoveFakeQuantizeResultValues& values) { return out << "_" << - values.fakeQuantize1 << "_" << - values.convert1.outPrecision << "_" << - values.dequantization1 << "_" << - values.fakeQuantize2 << "_" << - values.convert2.outPrecision << "_" << - values.dequantization2 << "_" << - values.dequantizationAfter; + values.fakeQuantizeBefore1 << "_" << + values.convertBefore1.outPrecision << "_" << + values.dequantizationBefore1 << "_" << + values.fakeQuantizeBefore2 << "_" << + values.convertBefore2.outPrecision << "_" << + values.dequantizationBefore2 << "_" << + values.operation << "_" << + values.fakeQuantizeAfter << "_" << + values.convertAfter << "_" << + values.dequantizationAfter << "_" << + values.dequantizationAfterNotFQ; } class MoveFakeQuantizeTestValues { @@ -139,24 +149,25 @@ class MoveFakeQuantize : public LayerTransformation, public testing::WithParamIn // dequantization output precision depends on input precision // to avoid huge amount of tests cases let's define dequantization output precision as input precision - if (!testValues.actual.dequantization1.multiply.empty()) { - testValues.actual.dequantization1.multiply.outPrecision = precision; + if (!testValues.actual.dequantizationBefore1.multiply.empty()) { + testValues.actual.dequantizationBefore1.multiply.outPrecision = precision; } - if (!testValues.actual.dequantization2.multiply.empty()) { - testValues.actual.dequantization2.multiply.outPrecision = precision; + if (!testValues.actual.dequantizationBefore2.multiply.empty()) { + testValues.actual.dequantizationBefore2.multiply.outPrecision = precision; } actualFunction = ngraph::builder::subgraph::MoveFakeQuantize::get( precision, shape, - testValues.actual.fakeQuantize1, - testValues.actual.convert1, - testValues.actual.dequantization1, - testValues.actual.fakeQuantize2, - testValues.actual.convert2, - testValues.actual.dequantization2, - testValues.actual.fakeQuantize3, - testValues.actual.convert3, - testValues.actual.dequantization3, + testValues.actual.fakeQuantizeBefore1, + testValues.actual.convertBefore1, + testValues.actual.dequantizationBefore1, + testValues.actual.fakeQuantizeBefore2, + testValues.actual.convertBefore2, + testValues.actual.dequantizationBefore2, + testValues.actual.operation, + testValues.actual.fakeQuantizeAfter, + testValues.actual.convertAfter, + testValues.actual.dequantizationAfter, ngraph::element::undefined, {}, testValues.axis); @@ -196,15 +207,16 @@ class MoveFakeQuantize : public LayerTransformation, public testing::WithParamIn referenceFunction = ngraph::builder::subgraph::MoveFakeQuantize::get( precision, shape, - testValues.result.fakeQuantize1, - testValues.result.convert1, - testValues.result.dequantization1, - testValues.result.fakeQuantize2, - testValues.result.convert2, - testValues.result.dequantization2, - testValues.result.fakeQuantize3, - testValues.result.convert3, - testValues.result.dequantization3, + testValues.result.fakeQuantizeBefore1, + testValues.result.convertBefore1, + testValues.result.dequantizationBefore1, + testValues.result.fakeQuantizeBefore2, + testValues.result.convertBefore2, + testValues.result.dequantizationBefore2, + testValues.result.operation, + testValues.result.fakeQuantizeAfter, + testValues.result.convertAfter, + testValues.result.dequantizationAfter, testValues.result.precisionAfterOperation, {}, testValues.axis); @@ -270,6 +282,38 @@ const std::vector testValues = { {}, {}, {}, + "", + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, + {} + }, + { + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, + {}, + { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, + {}, + "", + {}, + {}, + {}, + }, + false, + false + }, + { + LayerTransformation::createParamsU8I8(), + false, + 1, + { + {}, + {}, + {}, + {}, + {}, + {}, + "relu", { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, {}, {} @@ -281,13 +325,14 @@ const std::vector testValues = { { 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, {}, {}, + "relu", {}, {}, {}, }, false, false - }, + } }; INSTANTIATE_TEST_SUITE_P( smoke_LPT, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp index efe2e5b429f5de..c02045d8803f0f 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/move_fake_quantize_function.hpp @@ -26,6 +26,7 @@ class MoveFakeQuantize { const FakeQuantizeOnDataWithConstant& fakeQuantize2, const DequantizationOperations::Convert& convert2, const DequantizationOperations& dequantization2, + const std::string operation, const FakeQuantizeOnDataWithConstant& fakeQuantize3, const DequantizationOperations::Convert& convert3, const DequantizationOperations& dequantization3, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp index 6ee2fc75b7d009..baf66cc57ec4ce 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/move_fake_quantize_function.cpp @@ -28,6 +28,7 @@ std::shared_ptr MoveFakeQuantize::get( const FakeQuantizeOnDataWithConstant& fqOnData2, const DequantizationOperations::Convert& convert2, const DequantizationOperations& dequantization2, + const std::string operation, const FakeQuantizeOnDataWithConstant& fqOnData3, const DequantizationOperations::Convert& convert3, const DequantizationOperations& dequantization3, @@ -43,9 +44,18 @@ std::shared_ptr MoveFakeQuantize::get( if (fqOnData3.empty()) { - auto relu1 = std::make_shared(input1->output(0)); - auto relu2 = std::make_shared(input2->output(0)); - std::shared_ptr parent1 = makeFakeQuantizeTypeRelaxed(relu1, inputPrecision, fqOnData1); + std::shared_ptr parent1, parent2; + if (operation == "relu") { + auto relu1 = std::make_shared(input1->output(0)); + auto relu2 = std::make_shared(input2->output(0)); + parent1 = makeFakeQuantizeTypeRelaxed(relu1, inputPrecision, fqOnData1); + parent2 = makeFakeQuantizeTypeRelaxed(relu2, inputPrecision, fqOnData2); + } + else { + parent1 = makeFakeQuantizeTypeRelaxed(input1, inputPrecision, fqOnData1); + parent2 = makeFakeQuantizeTypeRelaxed(input1, inputPrecision, fqOnData2); + } + if (!convert1.empty()) { parent1 = std::make_shared(parent1, convert1.outPrecision); } @@ -53,7 +63,6 @@ std::shared_ptr MoveFakeQuantize::get( parent1 = makeDequantization(parent1, dequantization1); } - std::shared_ptr parent2 = makeFakeQuantizeTypeRelaxed(relu2, inputPrecision, fqOnData2); if (!convert2.empty()) { parent2 = std::make_shared(parent2, convert2.outPrecision); } @@ -81,8 +90,14 @@ std::shared_ptr MoveFakeQuantize::get( rtInfo["Variant::std::string"] = std::make_shared>("concat"); concat->set_friendly_name("output"); - auto relu = std::make_shared(concat->output(0)); - std::shared_ptr fq = makeFakeQuantize(relu, inputPrecision, fqOnData3); + std::shared_ptr fq; + if (operation == "relu") { + auto relu = std::make_shared(concat->output(0)); + fq = makeFakeQuantize(relu, inputPrecision, fqOnData3); + } + else { + fq = makeFakeQuantize(concat, inputPrecision, fqOnData3); + } ngraph::ResultVector results{ std::make_shared(fq) }; std::shared_ptr function = std::make_shared(