diff --git a/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp b/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp index ae2ada651ef18d..b08d154c17cbab 100644 --- a/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp +++ b/inference-engine/tests/functional/inference_engine/lp_transformations/split_transformation.cpp @@ -42,6 +42,7 @@ class SplitTransformationTestValues { ngraph::pass::low_precision::LayerTransformation::Params params; Actual actual; Expected expected; + bool addUnsupportedConcat; }; inline std::ostream& operator<<(std::ostream& os, @@ -74,7 +75,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara testValues.actual.precisionBeforeDequantization, testValues.actual.dequantization, testValues.splitedAxis, - testValues.numSplits); + testValues.numSplits, + testValues.addUnsupportedConcat); SimpleLowPrecisionTransformer transformer; transformer.add(testValues.params.setSupportAsymmetricQuantization(true)); @@ -88,7 +90,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara testValues.expected.precisionAfterOperation, testValues.expected.dequantizationAfter, testValues.splitedAxis, - testValues.numSplits); + testValues.numSplits, + testValues.addUnsupportedConcat); } static std::string getTestCaseName(testing::TestParamInfo obj) { @@ -429,6 +432,22 @@ const std::vector testValues = { } } }, + // issue #56781: unsupported Concat after Split + { + ngraph::Shape({ 1, 4, 3, 3 }), std::int64_t{2}, size_t{3}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {3.f}} + }, + { + ngraph::element::u8, + {{ngraph::element::f32}, {128.f}, {3.f}}, + ngraph::element::f32, + {} + }, + true + }, // no dequantization { ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2}, diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp index 46fa7cb9d61e4a..de4c8c089bbfff 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/split_function.hpp @@ -25,7 +25,8 @@ class SplitFunction { const ngraph::element::Type precisionBeforeDequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization, const int64_t splitedAxis, - const size_t numSplits); + const size_t numSplits, + const bool addUnsupportedConcat = false); static std::shared_ptr getOriginal( const ngraph::element::Type originalFunctionPrecision, @@ -42,7 +43,8 @@ class SplitFunction { const ngraph::element::Type precisionAfterOperation, const std::vector& dequantizationAfter, const int64_t splitedAxis, - const size_t numSplit); + const size_t numSplit, + const bool addUnsupportedConcat = false); }; } // namespace subgraph } // namespace builder diff --git a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp index fe2e797cd32a32..43279a257d6595 100644 --- a/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp +++ b/inference-engine/tests/ngraph_helpers/lpt_ngraph_functions/src/split_function.cpp @@ -23,7 +23,8 @@ std::shared_ptr SplitFunction::getOriginal( const ngraph::element::Type precisionBeforeDequantization, const ngraph::builder::subgraph::DequantizationOperations& dequantization, const int64_t splitedAxis, - const size_t numSplits) { + const size_t numSplits, + const bool addUnsupportedConcat) { const std::shared_ptr input = std::make_shared( precisionBeforeDequantization, ngraph::Shape(inputShape)); @@ -35,8 +36,14 @@ std::shared_ptr SplitFunction::getOriginal( const std::shared_ptr split = std::make_shared(dequantizationOp, constant, numSplits); ngraph::ResultVector results; - for (size_t i = 0; i < numSplits; ++i) { - results.push_back(std::make_shared(split->output(i))); + + if (addUnsupportedConcat) { + const auto concat = std::make_shared(split->outputs(), 2ul); + results.push_back(std::make_shared(concat)); + } else { + for (size_t i = 0; i < numSplits; ++i) { + results.push_back(std::make_shared(split->output(i))); + } } return std::make_shared(results, ngraph::ParameterVector{ input }, "SplitFunction"); } @@ -77,7 +84,8 @@ std::shared_ptr SplitFunction::getReference( const ngraph::element::Type precisionAfterOperation, const std::vector& dequantizationAfter, const int64_t splitedAxis, - const size_t numSplit) { + const size_t numSplit, + const bool addUnsupportedConcat) { const std::shared_ptr input = std::make_shared( inputPrecision, ngraph::Shape(inputShape)); @@ -89,17 +97,23 @@ std::shared_ptr SplitFunction::getReference( split = std::make_shared(deqBefore, constant, numSplit); ngraph::ResultVector results; - for (size_t i = 0; i < numSplit; ++i) { - if (!dequantizationAfter.empty()) { - auto dequantizationStructure = dequantizationAfter[i]; - if (!dequantizationStructure.multiply.empty()) { - dequantizationStructure.multiply.outPrecision = precision; + if (addUnsupportedConcat) { + const auto concat = std::make_shared(split->outputs(), 2ul); + results.push_back(std::make_shared(concat)); + } else { + for (size_t i = 0; i < numSplit; ++i) { + if (!dequantizationAfter.empty()) { + auto dequantizationStructure = dequantizationAfter[i]; + if (!dequantizationStructure.multiply.empty()) { + dequantizationStructure.multiply.outPrecision = precision; + } + results.push_back(std::make_shared(makeDequantization(split->output(i), dequantizationAfter[i]))); + } else { + results.push_back(std::make_shared(split->output(i))); } - results.push_back(std::make_shared(makeDequantization(split->output(i), dequantizationAfter[i]))); - } else { - results.push_back(std::make_shared(split->output(i))); } } + return std::make_shared(results, ngraph::ParameterVector{ input }, "SplitTransformation"); }