diff --git a/src/common/low_precision_transformations/src/add.cpp b/src/common/low_precision_transformations/src/add.cpp index 6a66c077cf0de9..1ba6f6598be247 100644 --- a/src/common/low_precision_transformations/src/add.cpp +++ b/src/common/low_precision_transformations/src/add.cpp @@ -108,9 +108,6 @@ bool AddTransformation::transform(TransformationContext& context, ov::pass::patt return false; } - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, defaultPrecisions, 0)); - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(op, defaultPrecisions, 1)); - std::shared_ptr addNode = NetworkHelper::separateInStandaloneBranch(op, defaultPrecisions); std::shared_ptr add = ov::as_type_ptr(addNode); diff --git a/src/common/low_precision_transformations/src/concat.cpp b/src/common/low_precision_transformations/src/concat.cpp index a287a0ff27d1f8..ae5e4615daa5ca 100644 --- a/src/common/low_precision_transformations/src/concat.cpp +++ b/src/common/low_precision_transformations/src/concat.cpp @@ -40,11 +40,11 @@ ConcatTransformation::ConcatTransformation(const Params& params) : LayerTransfor } bool ConcatTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher &m) { - std::shared_ptr concat = ov::as_type_ptr(m.get_match_root()); - if (!canBeTransformed(context, concat)) { + if (!canBeTransformed(context, m.get_match_root())) { return false; } + const auto concat = ov::as_type_ptr(NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions)); std::vector layerDequantizations; layerDequantizations.reserve(concat->get_input_size()); for (size_t parentIndex = 0ul; parentIndex < concat->get_input_size(); parentIndex++) { diff --git a/src/common/low_precision_transformations/src/multiply.cpp b/src/common/low_precision_transformations/src/multiply.cpp index cc654d3deff706..4c1f3c073febcf 100644 --- a/src/common/low_precision_transformations/src/multiply.cpp +++ b/src/common/low_precision_transformations/src/multiply.cpp @@ -49,9 +49,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions); decomposeFakeQuantizeForWeightsPath(multiply); - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); - const auto dequantization1 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 0); const auto dequantization2 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 1); diff --git a/src/common/low_precision_transformations/src/multiply_partial.cpp b/src/common/low_precision_transformations/src/multiply_partial.cpp index d4c29890988c8a..c0760d4b1f1c01 100644 --- a/src/common/low_precision_transformations/src/multiply_partial.cpp +++ b/src/common/low_precision_transformations/src/multiply_partial.cpp @@ -45,9 +45,6 @@ bool MultiplyPartialTransformation::transform(TransformationContext& context, ov return false; } - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); - NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); - multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions); auto newMultiply = multiply; diff --git a/src/common/low_precision_transformations/src/mvn.cpp b/src/common/low_precision_transformations/src/mvn.cpp index 701114d2a4a490..10ea85152e4f1f 100644 --- a/src/common/low_precision_transformations/src/mvn.cpp +++ b/src/common/low_precision_transformations/src/mvn.cpp @@ -123,11 +123,7 @@ bool MVNTransformation::transform(TransformationContext &context, ov::pass::patt return false; } - std::shared_ptr mvn = ov::as_type_ptr(operation); - if (!mvn) { - mvn = ov::as_type_ptr(operation); - } - + const auto mvn = NetworkHelper::separateInStandaloneBranch(operation, defaultPrecisions); bool normalizeVariance; if (ov::is_type(mvn)) { normalizeVariance = ov::as_type_ptr(mvn)->get_normalize_variance(); diff --git a/src/common/low_precision_transformations/src/network_helper.cpp b/src/common/low_precision_transformations/src/network_helper.cpp index 2b26567f7a8307..95958bd67f2910 100644 --- a/src/common/low_precision_transformations/src/network_helper.cpp +++ b/src/common/low_precision_transformations/src/network_helper.cpp @@ -585,8 +585,12 @@ FakeQuantizeDequantization NetworkHelper::foldDequantization(const std::shared_p std::shared_ptr NetworkHelper::separateInStandaloneBranch(std::shared_ptr node, const std::vector& defaultPrecisions) { - FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(node, defaultPrecisions); - if (dequantization.isShared() && !dequantization.empty()) { + auto inputs = node->input_values(); + auto separate_branch = [&](size_t input_idx) { + const auto dequantization = NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(node, defaultPrecisions, input_idx)); + if (dequantization.empty() || !dequantization.isShared()) + return false; + Output parent = dequantization.data; if (dequantization.convert != nullptr) { auto convert = dequantization.convert->clone_with_new_inputs({ parent }); @@ -619,22 +623,27 @@ std::shared_ptr NetworkHelper::separateInStandaloneBranch(std::shared_ parent = multiply->output(0); } - std::vector> inputs = node->input_values(); const auto originalParent = dequantization.multiply ? dequantization.multiply->shared_from_this() : dequantization.subtract->shared_from_this(); const size_t inputIndex = NetworkHelper::getChildInputIndex(originalParent, node); inputs[inputIndex] = parent; - const std::shared_ptr newNode = node->clone_with_new_inputs(inputs); - copy_runtime_info(node, newNode); - replace_node(node, newNode); - newNode->set_friendly_name(node->get_friendly_name()); + return true; + }; - return newNode; - } + bool branch_separation_happened = false; + for (size_t i = 0; i < node->get_input_size(); ++i) + branch_separation_happened |= separate_branch(i); + + if (!branch_separation_happened) + return node; - return node; + const auto newNode = node->clone_with_new_inputs(inputs); + copy_runtime_info(node, newNode); + replace_node(node, newNode); + newNode->set_friendly_name(node->get_friendly_name()); + return newNode; } std::shared_ptr NetworkHelper::fuseConvert(const std::shared_ptr& fakeQuantize) { diff --git a/src/common/low_precision_transformations/src/reduce_base_transformation.cpp b/src/common/low_precision_transformations/src/reduce_base_transformation.cpp index e7fa01611f8743..ef34cae049b58c 100644 --- a/src/common/low_precision_transformations/src/reduce_base_transformation.cpp +++ b/src/common/low_precision_transformations/src/reduce_base_transformation.cpp @@ -22,7 +22,7 @@ bool ReduceBaseTransformation::transform(TransformationContext& context, ov::pas } const auto reduce = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions); - auto dequantization = NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(reduce, defaultPrecisions)); + auto dequantization = NetworkHelper::getDequantization(reduce, defaultPrecisions); // prepare dequantization to propagate changeDequantizationValues(reduce, dequantization); diff --git a/src/common/low_precision_transformations/tests/separate_in_standalone_branch_transformation.cpp b/src/common/low_precision_transformations/tests/separate_in_standalone_branch_transformation.cpp index b60880133c0e9e..291bca278b9bf5 100644 --- a/src/common/low_precision_transformations/tests/separate_in_standalone_branch_transformation.cpp +++ b/src/common/low_precision_transformations/tests/separate_in_standalone_branch_transformation.cpp @@ -85,7 +85,13 @@ class SeparateInStandaloneBranchTransformation : const auto createReferenceFunction = []( const ov::element::Type precision, const ov::Shape& inputShape, - const ov::builder::subgraph::DequantizationOperations& dequantization) -> std::shared_ptr { + ov::builder::subgraph::DequantizationOperations dequantization) -> std::shared_ptr { + // Note: separateInStandaloneBranch normalizes dequantization so constant indexes become equal to 1 + if (!dequantization.subtract.empty()) + dequantization.subtract.constantIndex = 1; + if (!dequantization.multiply.empty()) + dequantization.multiply.constantIndex = 1; + const std::shared_ptr input = std::make_shared(precision, inputShape); const auto relu = std::make_shared(input); @@ -118,7 +124,7 @@ class SeparateInStandaloneBranchTransformation : std::tie(shapes, testValues) = obj.param; std::stringstream ss; - ss << shapes << "_" << "_" << testValues; + ss << shapes << "_" << testValues; return ss.str(); } }; @@ -133,7 +139,6 @@ TEST_P(SeparateInStandaloneBranchTransformation, CompareFunctions) { const std::vector shapes = { { 1, 3, 9, 9 }, - { 4, 3, 9, 9 } }; std::vector testValues = { @@ -155,7 +160,16 @@ std::vector testValues = { { {127.f}, ov::element::f32, {}, true, 1ul, ov::element::u8, true}, { 0.02f } } - } + }, + { + LayerTransformation::createParamsU8U8(), + ov::element::u8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 0ul}, + { {0.02f}, ov::element::f32, {}, false, 0ul } + } + }, }; INSTANTIATE_TEST_SUITE_P(