From 79cba9aa7bf2f25a6b28855859a72ac4ecb74d0a Mon Sep 17 00:00:00 2001 From: eshoguli Date: Wed, 27 Sep 2023 22:23:46 +0100 Subject: [PATCH] [LPT] DisableCleanup usage --- .../include/low_precision/low_precision.hpp | 9 +++ .../src/low_precision.cpp | 4 ++ .../src/multiply.cpp | 62 +++++++++---------- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/common/low_precision_transformations/include/low_precision/low_precision.hpp b/src/common/low_precision_transformations/include/low_precision/low_precision.hpp index 9236113c731052..d3fb4b7ec55834 100644 --- a/src/common/low_precision_transformations/include/low_precision/low_precision.hpp +++ b/src/common/low_precision_transformations/include/low_precision/low_precision.hpp @@ -71,9 +71,18 @@ class ov::pass::low_precision::LowPrecision : public ov::pass::ModelPass { static bool isFunctionQuantized(const std::shared_ptr& model); static bool isFQLevelsPresent(const std::shared_ptr& model, const std::set& levels); + template + std::shared_ptr add_main(Args&&... args) { + const auto tr = std::make_shared(std::forward(args)...); + added_main.push_back(tr); + return tr; + } + protected: std::vector precisionRestrictions; std::vector quantizationRestrictions; // remove LayerTransformation::Params params; + + std::vector> added_main; }; diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index df02353dd67c2b..553e78adbadd37 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -273,6 +273,10 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptradd_matcher(tr); + } + std::shared_ptr cleanup = manager.register_pass(); ADD_MATCHER(cleanup, EliminateFakeQuantizeTransformation, params) ADD_MATCHER(cleanup, FoldConvertTransformation, params) diff --git a/src/common/low_precision_transformations/src/multiply.cpp b/src/common/low_precision_transformations/src/multiply.cpp index 3a150e76a7e789..cc654d3deff706 100644 --- a/src/common/low_precision_transformations/src/multiply.cpp +++ b/src/common/low_precision_transformations/src/multiply.cpp @@ -15,6 +15,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "low_precision/common/ie_lpt_exception.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" @@ -61,31 +62,36 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: // before: y = (deq_scales1 * (x1 - zero_point1)) * (deq_scales2 * (x2 - zero_point2)) // after : y = deq_scales1 * deq_scales2 * (x1 - zero_point1) * (x2 - zero_point2) - if ((dequantization1.empty() && (ov::is_type(dequantization1.data.get_node()))) || - (dequantization2.empty() && (ov::is_type(dequantization2.data.get_node())))) { - // one input is constant - auto new_scales_values = fold( - dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant, - dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant); + auto new_scales_values = fold( + dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant, + dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant); - if (!ov::is_type(new_scales_values)) { - return false; + if (!ov::is_type(new_scales_values)) { + return false; + } + + const auto init_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output { + if (dequantization.empty()) { + return new_scales_values; } - const auto create_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output { - if (dequantization.empty()) { - return new_scales_values; - } + if (dequantization.subtract == nullptr) { + return dequantization.data; + } - if (dequantization.subtract == nullptr) { - return dequantization.data; - } + const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract); + if (subtract != nullptr) { + DisableCleanupAttribute::create(subtract); + } - const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract); - return subtract == nullptr ? dequantization.data : subtract; - }; - const Output in1 = create_input(dequantization1); - const Output in2 = create_input(dequantization2); + return subtract == nullptr ? dequantization.data : subtract; + }; + + if ((dequantization1.empty() && (ov::is_type(dequantization1.data.get_node()))) || + (dequantization2.empty() && (ov::is_type(dequantization2.data.get_node())))) { + // one input is constant + const Output in1 = init_input(dequantization1); + const Output in2 = init_input(dequantization2); const auto new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) && (in2.get_element_type() == multiply->get_output_element_type(0)) ? @@ -102,18 +108,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: return true; } - auto new_scales_values = fold(dequantization1.multiplyConstant, dequantization2.multiplyConstant); - if (!ov::is_type(new_scales_values)) { - return false; - } - - const Output in1 = dequantization1.subtract == nullptr ? - dequantization1.data : - NetworkHelper::optimizeSubtract(dequantization1.subtract); - - const Output in2 = dequantization2.subtract == nullptr ? - dequantization2.data : - NetworkHelper::optimizeSubtract(dequantization2.subtract); + Output in1 = init_input(dequantization1); + Output in2 = init_input(dequantization2); // in1 & in2 can have different input types const auto new_multiply = (in1.get_element_type() == deqPrecision) && @@ -125,6 +121,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(), ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get()); + DisableCleanupAttribute::create(new_multiply); + auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) && (new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ? std::make_shared(new_multiply, new_scales_values) :