From 20ee134b8a829da410b537c2b0dbb6b7655dd3e2 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 5 Sep 2024 12:11:48 +0200 Subject: [PATCH] [LPT] Handle scale with convert in MatMulTransformation (#26398) ### Details: This PR handles Matmul dequantization subgraphs with convert on scales. In general, LPT don't expect Convert on dequantization scale: the convert is usually folded at the previous transformation pipeline steps. However, such subgraphs may arise on Matmul weights (in CPU plugin) since decompression handling logic keeps these subgraphs unchanged. Since decompression handling logic concerns only MatMuls, the changes in MatMul LPT is enough. ### Tickets: - *CVS-151589* --- .../src/mat_mul.cpp | 13 ++++ .../mat_mul_with_constant_transformation.cpp | 16 +++++ .../common/dequantization_operations.hpp | 6 +- .../ov_lpt_models/src/common/builders.cpp | 71 +++++++------------ .../src/common/dequantization_operations.cpp | 20 +++++- 5 files changed, 78 insertions(+), 48 deletions(-) diff --git a/src/common/low_precision_transformations/src/mat_mul.cpp b/src/common/low_precision_transformations/src/mat_mul.cpp index 9155e9bf877783..15afe2408cc459 100644 --- a/src/common/low_precision_transformations/src/mat_mul.cpp +++ b/src/common/low_precision_transformations/src/mat_mul.cpp @@ -222,6 +222,19 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context return false; } + // WA: LPT don't expect Convert on dequantization scale: the convert is usually folded at the previous transformation pipeline steps. + // However, such subgraphs may arise on Matmul weights since decompression handling logic keeps these subgraphs unchanged + // TODO: remove this logic when + // 1. CompressedMatmul is implemented + // 2. Or convert on scales is supported across the whole LPT pipeline + if (auto mul = ov::as_type_ptr(layer->get_input_node_shared_ptr(1))) { + if (auto convert = ov::as_type_ptr(mul->get_input_node_shared_ptr(1))) { + if (auto constant = ov::as_type_ptr(convert->get_input_node_shared_ptr(0))) { + auto new_constant = foldConvert(constant, convert->get_destination_type()); + ov::replace_node_update_name(convert, new_constant); + } + } + } const auto dequantization2 = NetworkHelper::getDequantization(layer, defaultPrecisions, 1); if (!dequantization2.empty()) { if ((updatePrecisions && !dequantization2.isLowPrecision())) { diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp index b93dd32cb70f8d..a51a5a37362e5c 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp @@ -65,6 +65,22 @@ std::vector testValues = { "FullyConnected", "u8" }, + // 4D with Dq on weights, with convert on scales + { + { 1, 1, 3, 4 }, + { 256ul, {{1, 1, 1}, {1, 1, 1}, {1, 3, 1}, {1, 3, 1}}, {0.f}, {255.f}, {0.f, 0.f, 0.f}, {255.f, 25.5f, 255.f} }, + { std::vector(4 * 2, 2.f), ov::element::i8, ov::Shape{ 2, 4 } }, + {}, + { + ov::element::f32, + {}, + ov::builder::subgraph::DequantizationOperations::Multiply({0.1f, 0.01f}, ov::element::f32, ov::Shape{ 2, 1 }) + .setConstantPrecision(ov::element::f16) + .setAddConvert(true) + }, + "FullyConnected", + "u8" + }, // 3D with the same values { { 1, 3, 4 }, diff --git a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/common/dequantization_operations.hpp b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/common/dequantization_operations.hpp index 716f00f36c364d..0fc95cb3c0f2f5 100644 --- a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/common/dequantization_operations.hpp +++ b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/common/dequantization_operations.hpp @@ -54,6 +54,7 @@ class DequantizationOperations { isEmpty = true; } Subtract& setConstantPrecision(const ov::element::Type& precision); + Subtract& setAddConvert(bool value); std::vector values; ov::element::Type outPrecision = ov::element::undefined; @@ -81,13 +82,15 @@ class DequantizationOperations { const ov::Shape& constantShape, const bool toRemove = false, const size_t constantIndex = 1ul, - const ov::element::Type constantPrecision = ov::element::undefined); + const ov::element::Type constantPrecision = ov::element::undefined, + const bool addConvert = false); bool empty() const noexcept; bool equal(const DequantizationOperations::Multiply& value) const noexcept; bool operator==(const Multiply& value) const noexcept { return equal(value); } Multiply& setConstantPrecision(const ov::element::Type& precision); + Multiply& setAddConvert(bool value); std::vector values; ov::element::Type outPrecision = ov::element::undefined; @@ -95,6 +98,7 @@ class DequantizationOperations { bool constantShapeIsDefined = false; size_t constantIndex = 1ul; ov::element::Type constantPrecision = ov::element::undefined; + bool addConvert = false; private: bool isEmpty; diff --git a/src/tests/ov_helpers/ov_lpt_models/src/common/builders.cpp b/src/tests/ov_helpers/ov_lpt_models/src/common/builders.cpp index 60f009753f2721..19f6a59161a4b8 100644 --- a/src/tests/ov_helpers/ov_lpt_models/src/common/builders.cpp +++ b/src/tests/ov_helpers/ov_lpt_models/src/common/builders.cpp @@ -81,24 +81,13 @@ std::shared_ptr makeDequantization( (((dequantizationOperations.subtract.constantPrecision == ov::element::undefined) || (dequantizationOperations.subtract.constantPrecision == parent.get_element_type())) || dequantizationOperations.subtract.addConvert)) { - subtract = dequantizationOperations.subtract.constantIndex == 1ul ? - std::make_shared(parent, subtractConst) : - subtract = std::make_shared(subtractConst, parent); + subtract = std::make_shared(leftBranchParent, rightBranchParent); } else { - if (dequantizationOperations.subtract.constantIndex == 1ul) { - subtract = std::make_shared>( - std::vector{ov::element::f32, ov::element::f32}, - std::vector{ov::element::f32}, - ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get(), - ov::op::TemporaryReplaceOutputType(subtractConst, ov::element::f32).get()); - } else { - subtract = std::make_shared>( - std::vector{ov::element::f32, ov::element::f32}, - std::vector{ov::element::f32}, - ov::op::TemporaryReplaceOutputType(subtractConst, ov::element::f32).get(), - ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get()); - } - + subtract = std::make_shared>( + std::vector{ov::element::f32, ov::element::f32}, + std::vector{ov::element::f32}, + ov::op::TemporaryReplaceOutputType(leftBranchParent, ov::element::f32).get(), + ov::op::TemporaryReplaceOutputType(rightBranchParent, ov::element::f32).get()); ov::pass::low_precision::NetworkHelper::setOutDataPrecision(subtract, dequantizationOperations.subtract.outPrecision); } @@ -143,38 +132,32 @@ std::shared_ptr makeMultiply(const ov::Output& parent, const Dequant } } + std::shared_ptr constant = std::make_shared( + multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision : parent.get_element_type(), + shape, + values); + if (multiply.addConvert) { + constant = std::make_shared( + constant, + multiply.outPrecision == ov::element::undefined ? parent.get_element_type() : multiply.outPrecision); + } + + ov::Output leftBranchParent = multiply.constantIndex == 1 ? parent : constant; + ov::Output rightBranchParent = multiply.constantIndex == 1 ? constant : parent; + std::shared_ptr newMultiply; if (((multiply.outPrecision == ov::element::undefined) || (multiply.outPrecision == parent.get_element_type())) && ((multiply.constantPrecision == ov::element::undefined) || - (multiply.constantPrecision == parent.get_element_type()))) { - const std::shared_ptr constant = std::make_shared( - multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision - : parent.get_element_type(), - shape, - values); - - newMultiply = multiply.constantIndex == 1ul ? - std::make_shared(parent, constant) : - std::make_shared(constant, parent); + (multiply.constantPrecision == parent.get_element_type())) || + multiply.addConvert) { + newMultiply = std::make_shared(leftBranchParent, rightBranchParent); } else { - const std::shared_ptr constant = std::make_shared( - multiply.constantPrecision != ov::element::undefined ? multiply.constantPrecision - : parent.get_element_type(), - shape, - values); - // TODO: use templates - newMultiply = multiply.constantIndex == 1ul - ? std::make_shared>( - std::vector{ov::element::f32, ov::element::f32}, - std::vector{multiply.outPrecision}, - ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get(), - ov::op::TemporaryReplaceOutputType(constant, ov::element::f32).get()) - : std::make_shared>( - std::vector{ov::element::f32, ov::element::f32}, - std::vector{multiply.outPrecision}, - ov::op::TemporaryReplaceOutputType(constant, ov::element::f32).get(), - ov::op::TemporaryReplaceOutputType(parent, ov::element::f32).get()); + newMultiply = std::make_shared>( + std::vector{ov::element::f32, ov::element::f32}, + std::vector{multiply.outPrecision}, + ov::op::TemporaryReplaceOutputType(leftBranchParent, ov::element::f32).get(), + ov::op::TemporaryReplaceOutputType(rightBranchParent, ov::element::f32).get()); } return newMultiply; diff --git a/src/tests/ov_helpers/ov_lpt_models/src/common/dequantization_operations.cpp b/src/tests/ov_helpers/ov_lpt_models/src/common/dequantization_operations.cpp index 8677028aefe378..95e8a1dab2cd6b 100644 --- a/src/tests/ov_helpers/ov_lpt_models/src/common/dequantization_operations.cpp +++ b/src/tests/ov_helpers/ov_lpt_models/src/common/dequantization_operations.cpp @@ -96,6 +96,11 @@ DequantizationOperations::Subtract& DequantizationOperations::Subtract::setConst return *this; } +DequantizationOperations::Subtract& DequantizationOperations::Subtract::setAddConvert(bool value) { + addConvert = value; + return *this; +} + DequantizationOperations::Multiply::Multiply() : isEmpty(true), outPrecision(ov::element::undefined), @@ -129,13 +134,15 @@ DequantizationOperations::Multiply::Multiply( const ov::Shape& constantShape, const bool toRemove, const size_t constantIndex, - ov::element::Type constantPrecision) : + ov::element::Type constantPrecision, + const bool addConvert) : isEmpty(false), values(values), outPrecision(outPrecision), constantShape(constantShape), constantIndex(constantIndex), constantPrecision(constantPrecision), + addConvert(addConvert), constantShapeIsDefined(true) { } @@ -166,6 +173,11 @@ DequantizationOperations::Multiply& DequantizationOperations::Multiply::setConst return *this; } +DequantizationOperations::Multiply& DequantizationOperations::Multiply::setAddConvert(bool value) { + addConvert = value; + return *this; +} + DequantizationOperations::DequantizationOperations() {} DequantizationOperations::DequantizationOperations( @@ -179,9 +191,11 @@ DequantizationOperations::DequantizationOperations( void DequantizationOperations::setPrecision(const ov::element::Type& type) noexcept { convert.outPrecision = type; - subtract.constantPrecision = type; + if (!subtract.addConvert) + subtract.constantPrecision = type; subtract.outPrecision = type; - multiply.constantPrecision = type; + if (!multiply.addConvert) + multiply.constantPrecision = type; multiply.outPrecision = type; }