diff --git a/inference-engine/src/low_precision_transformations/src/split.cpp b/inference-engine/src/low_precision_transformations/src/split.cpp index c6a1f4b1df1b0f..486111dd73778a 100644 --- a/inference-engine/src/low_precision_transformations/src/split.cpp +++ b/inference-engine/src/low_precision_transformations/src/split.cpp @@ -130,7 +130,20 @@ bool SplitTransformation::isPrecisionPreserved(std::shared_ptr layer) cons } bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { - return (!NetworkHelper::getDequantization(layer).empty()) && LayerTransformation::canBeTransformed(context, layer); + if (!LayerTransformation::canBeTransformed(context, layer) || NetworkHelper::getDequantization(layer).empty()) { + return false; + } + + const auto consumers = NetworkHelper::consumers(layer); + const auto concat = as_type_ptr(consumers[0]); + + // WA to avoid propagation of dequantization if after Split all consumers are the same unsupported Concat + if (concat && concat->get_axis() != 1ul) { + const size_t id = consumers[0]->get_instance_id(); + return std::any_of(consumers.begin(), consumers.end(), [&](const std::shared_ptr& node) { return node->get_instance_id() != id; }); + } + + return true; } } // namespace low_precision diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp index ae2ada651ef18d..b08d154c17cbab 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp @@ -42,6 +42,7 @@ class SplitTransformationTestValues { ngraph::pass::low_precision::LayerTransformation::Params params; Actual actual; Expected expected; + bool addUnsupportedConcat; }; inline std::ostream& operator<<(std::ostream& os, @@ -74,7 +75,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara testValues.actual.precisionBeforeDequantization, testValues.actual.dequantization, testValues.splitedAxis, - testValues.numSplits); + testValues.numSplits, + testValues.addUnsupportedConcat); SimpleLowPrecisionTransformer transformer; transformer.add(testValues.params.setSupportAsymmetricQuantization(true)); @@ -88,7 +90,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara testValues.expected.precisionAfterOperation, testValues.expected.dequantizationAfter, testValues.splitedAxis, - testValues.numSplits); + testValues.numSplits, + testValues.addUnsupportedConcat); } static std::string getTestCaseName(testing::TestParamInfo obj) { @@ -429,6 +432,22 @@ const std::vector testValues = { } } }, + // issue #56781: unsupported Concat after Split + { + ngraph::Shape({ 1, 4, 3, 3 }), std::int64_t{2}, size_t{3}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {3.f}} + }, + { + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {3.f}}, + ngraph::element::f32, + {} + }, + true + }, // no dequantization { ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2}, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp index 46fa7cb9d61e4a..de4c8c089bbfff 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp @@ -25,7 +25,8 @@ class SplitFunction { const ngraph::element::Type precisionBeforeDequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization, const int64_t splitedAxis, - const size_t numSplits); + const size_t numSplits, + const bool addUnsupportedConcat = false); static std::shared_ptr getOriginal( const ngraph::element::Type originalFunctionPrecision, @@ -42,7 +43,8 @@ class SplitFunction { const ngraph::element::Type precisionAfterOperation, const std::vector& dequantizationAfter, const int64_t splitedAxis, - const size_t numSplit); + const size_t numSplit, + const bool addUnsupportedConcat = false); }; } // namespace subgraph } // namespace builder diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp index fe2e797cd32a32..43279a257d6595 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp @@ -23,7 +23,8 @@ std::shared_ptr SplitFunction::getOriginal( const ngraph::element::Type precisionBeforeDequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization, const int64_t splitedAxis, - const size_t numSplits) { + const size_t numSplits, + const bool addUnsupportedConcat) { const std::shared_ptr input = std::make_shared( precisionBeforeDequantization, ngraph::Shape(inputShape)); @@ -35,8 +36,14 @@ std::shared_ptr SplitFunction::getOriginal( const std::shared_ptr split = std::make_shared(dequantizationOp, constant, numSplits); ngraph::ResultVector results; - for (size_t i = 0; i < numSplits; ++i) { - results.push_back(std::make_shared(split->output(i))); + + if (addUnsupportedConcat) { + const auto concat = std::make_shared(split->outputs(), 2ul); + results.push_back(std::make_shared(concat)); + } else { + for (size_t i = 0; i < numSplits; ++i) { + results.push_back(std::make_shared(split->output(i))); + } } return std::make_shared(results, ngraph::ParameterVector{ input }, "SplitFunction"); } @@ -77,7 +84,8 @@ std::shared_ptr SplitFunction::getReference( const ngraph::element::Type precisionAfterOperation, const std::vector& dequantizationAfter, const int64_t splitedAxis, - const size_t numSplit) { + const size_t numSplit, + const bool addUnsupportedConcat) { const std::shared_ptr input = std::make_shared( inputPrecision, ngraph::Shape(inputShape)); @@ -89,17 +97,23 @@ std::shared_ptr SplitFunction::getReference( split = std::make_shared(deqBefore, constant, numSplit); ngraph::ResultVector results; - for (size_t i = 0; i < numSplit; ++i) { - if (!dequantizationAfter.empty()) { - auto dequantizationStructure = dequantizationAfter[i]; - if (!dequantizationStructure.multiply.empty()) { - dequantizationStructure.multiply.outPrecision = precision; + if (addUnsupportedConcat) { + const auto concat = std::make_shared(split->outputs(), 2ul); + results.push_back(std::make_shared(concat)); + } else { + for (size_t i = 0; i < numSplit; ++i) { + if (!dequantizationAfter.empty()) { + auto dequantizationStructure = dequantizationAfter[i]; + if (!dequantizationStructure.multiply.empty()) { + dequantizationStructure.multiply.outPrecision = precision; + } + results.push_back(std::make_shared(makeDequantization(split->output(i), dequantizationAfter[i]))); + } else { + results.push_back(std::make_shared(split->output(i))); } - results.push_back(std::make_shared(makeDequantization(split->output(i), dequantizationAfter[i]))); - } else { - results.push_back(std::make_shared(split->output(i))); } } + return std::make_shared(results, ngraph::ParameterVector{ input }, "SplitTransformation"); }