diff --git a/src/common/low_precision_transformations/src/multiply.cpp b/src/common/low_precision_transformations/src/multiply.cpp index abf14224657655..525799993ccf17 100644 --- a/src/common/low_precision_transformations/src/multiply.cpp +++ b/src/common/low_precision_transformations/src/multiply.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -44,7 +44,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: return false; } - // TODO: normalizeDequantization + fold_fake_quantizes + foldDequantization <= ??? NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); @@ -65,12 +64,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: fold_fake_quantizes(multiply, 1ul); const auto dequantization1 = NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions); - if (dequantization1.multiplyConstant == nullptr) { - return false; - } - const auto dequantization2 = NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions); - if (dequantization2.multiplyConstant == nullptr) { + if ((dequantization1.multiplyConstant == nullptr) && (dequantization2.multiplyConstant == nullptr)) { return false; } @@ -79,7 +74,50 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: // X1` = X1 - SH1 // X2` = X2 - SH2 // SC1' = SC1 * SC2 + + if ((dequantization1.empty() && (ov::is_type(dequantization1.data.get_node()))) || + (dequantization2.empty() && (ov::is_type(dequantization2.data.get_node())))) { + // one input is constant + auto new_scales_values = fold( + dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant, + dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant); + + if (!ov::is_type(new_scales_values)) { + return false; + } + + const Output in1 = dequantization1.empty() ? + new_scales_values : + dequantization1.subtract == nullptr ? + dequantization1.data : + NetworkHelper::optimizeSubtract(dequantization1.subtract); + + const Output in2 = dequantization2.empty() ? + new_scales_values : + dequantization2.subtract == nullptr ? + dequantization2.data : + NetworkHelper::optimizeSubtract(dequantization2.subtract); + + auto const new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) && + (in2.get_element_type() == multiply->get_output_element_type(0)) ? + std::make_shared(in1, in2) : + std::make_shared>( + std::vector{ deqPrecision, deqPrecision }, + std::vector{ multiply->get_output_element_type(0) }, + ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(), + ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get()); + + replace_node(multiply, new_multiply); + updateOutput(context, new_multiply, multiply); + + return true; + } + auto new_scales_values = fold(dequantization1.multiplyConstant, dequantization2.multiplyConstant); + if (!ov::is_type(new_scales_values)) { + return false; + } + const Output in1 = dequantization1.subtract == nullptr ? dequantization1.data : NetworkHelper::optimizeSubtract(dequantization1.subtract); @@ -89,19 +127,23 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: NetworkHelper::optimizeSubtract(dequantization2.subtract); // in1 & in2 can have different input types - auto const new_multiply = std::make_shared>( - std::vector{ element::f32, element::f32 }, - std::vector{ element::f32 }, - ov::op::TemporaryReplaceOutputType(in1, element::f32).get(), - ov::op::TemporaryReplaceOutputType(in2, element::f32).get()); + const auto new_multiply = (in1.get_element_type() == deqPrecision) && + (in2.get_element_type() == deqPrecision) ? + std::make_shared(in1, in2) : + std::make_shared>( + std::vector{ deqPrecision, deqPrecision }, + std::vector{ deqPrecision }, + ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(), + ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get()); NetworkHelper::copyInfo(multiply, newMultiply); - auto new_scales = new_multiply->get_output_element_type(0) != multiply->get_output_element_type(0) ? + auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) && + (new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ? + std::make_shared(new_multiply, new_scales_values) : std::make_shared>( ov::opset1::Multiply(new_multiply, new_scales_values), - multiply->get_output_element_type(0)) : - std::make_shared(new_multiply, new_scales_values); + multiply->get_output_element_type(0)); replace_node(multiply, new_scales); updateOutput(context, new_scales, multiply); diff --git a/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp b/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp index 431c4459a4c57b..f2459620019351 100644 --- a/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp +++ b/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp @@ -22,7 +22,7 @@ #include "low_precision/interpolate.hpp" #include "low_precision/mat_mul.hpp" #include "low_precision/max_pool.hpp" -#include "low_precision/multiply.hpp" +#include "low_precision/multiply_partial.hpp" #include "low_precision/mvn.hpp" #include "low_precision/network_helper.hpp" #include "low_precision/normalize_l2.hpp" @@ -361,7 +361,7 @@ TEST(LPT, AvoidDequantizationToShapeOfPropagationMultiplyTransformation) { auto f = std::make_shared(ResultVector{result1, result2}, ParameterVector{input1, input2}); pass::Manager m; - m.register_pass(); + m.register_pass(); m.run_passes(f); auto dqBeforeShapeOf = ov::pass::low_precision::NetworkHelper::getDequantization(result2->get_input_node_shared_ptr(0)); diff --git a/src/common/low_precision_transformations/tests/multiply_transformation.cpp b/src/common/low_precision_transformations/tests/multiply_transformation.cpp index 162163477590ab..b51125fab0c296 100644 --- a/src/common/low_precision_transformations/tests/multiply_transformation.cpp +++ b/src/common/low_precision_transformations/tests/multiply_transformation.cpp @@ -20,6 +20,8 @@ #include "simple_low_precision_transformer.hpp" #include "lpt_ngraph_functions/multiply_function.hpp" +#include "ngraph/pass/serialize.hpp" + namespace { using namespace testing; using namespace ov; @@ -34,7 +36,12 @@ class MultiplyBranch { }; inline std::ostream& operator<<(std::ostream& out, const MultiplyBranch& branch) { - return out << "_" << branch.constant << "_" << branch.input_precision << "_" << branch.dequantization; + if (branch.constant.empty()) { + out << "_input=" << branch.input_precision; + } else { + out << "_constant=" << branch.constant; + } + return out << "_" << branch.dequantization; } class MultiplyValues { @@ -155,16 +162,22 @@ class MultiplyTransformation : public LayerTransformation, public testing::WithP // low precision has to be defined by tests parameters static void update_input_precisions(const std::pair& input_precisions, - MultiplyTransformationTestValues& test_values) { + MultiplyTransformationTestValues& test_values) { const auto update_values = [](const std::pair& input_precisions, MultiplyValues& values) { - if (values.branch1.input_precision == MultiplyTransformationTestValues::input_precision) { - values.branch1.input_precision = input_precisions.first; - } - - if (values.branch2.input_precision == MultiplyTransformationTestValues::input_precision) { - values.branch2.input_precision = input_precisions.second; - } + const auto update_branch = [&input_precisions](MultiplyBranch& branch) { + if (branch.input_precision == MultiplyTransformationTestValues::input_precision) { + branch.input_precision = input_precisions.first; + } + + if (!branch.constant.empty() && + (branch.constant.outPrecision == MultiplyTransformationTestValues::input_precision)) { + branch.constant.outPrecision = input_precisions.first; + } + }; + + update_branch(values.branch1); + update_branch(values.branch2); }; update_values(input_precisions, test_values.actual); @@ -190,9 +203,11 @@ const std::vector> input_precisi { ov::element::i8, ov::element::i8 }, { ov::element::u8, ov::element::i8 }, { ov::element::i8, ov::element::u8 }, + { ov::element::f32, ov::element::f32 }, + { ov::element::f16, ov::element::f16 }, }; -// PartialShape inputShape; +namespace broadcast_no { const std::vector> input_shapes = { {{ 1, 3, 8, 16 }, { 1, 3, 8, 16 }}, {{ 1, 3, 8, 16 }, { 1, 3, 1, 1 }}, @@ -237,6 +252,62 @@ const std::vector multiplyTransformationTestVa } }, + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {{ 280.f, 350.f, 420.f }, ov::element::f32, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {} + } + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {ov::element::f32, { 3.f }, { 7.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + } + }, + { + { + {{ 280.f, 350.f, 420.f }, ov::element::f32, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + } + } + }, + { LayerTransformation::createParamsU8I8(), { @@ -363,4 +434,332 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(input_precisions), ::testing::ValuesIn(multiplyTransformationTestValues)), MultiplyTransformation::getTestCaseName); -} // namespace +} // namespace broadcast_no + +namespace broadcast_right { +const std::vector> input_shapes = { + {{ 1, 3, 8, 16 }, { 1, 1, 1, 1 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 1.f, 2.f, 3.f }}, {{ 10.f, 11.f, 12.f }}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 1.f, 2.f, 3.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 77.f, 84.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_right + +namespace broadcast_left { +const std::vector> input_shapes = { + {{ 1, 1, 1, 1 }, { 1, 3, 8, 16 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 3.f, 4.f, 5.f }}, {{ 7.f, 8.f, 9.f }}} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f, 4.f, 5.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 80.f, 90.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_left + +} // namespace \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp index 26846b5f97cb62..2088d4db87696a 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp @@ -11,7 +11,7 @@ #include #include -#include "lpt_ngraph_functions/multiply_function.hpp" +#include "lpt_ngraph_functions/multiply_partial_function.hpp" #include "ngraph_functions/subgraph_builders.hpp" @@ -56,7 +56,7 @@ void MultiplyTransformation::SetUp() { MultiplyTestValues param; std::tie(precision, inputShape, targetDevice, param) = this->GetParam(); - function = ngraph::builder::subgraph::MultiplyFunction::getOriginal( + function = ngraph::builder::subgraph::MultiplyPartialFunction::get( precision, inputShape, param.broadcast1, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp index 0c74ed64adadc8..81fbac96b989dc 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp @@ -4,6 +4,8 @@ #include "lpt_ngraph_functions/multiply_function.hpp" +#include + #include #include #include "ngraph_functions/subgraph_builders.hpp" @@ -18,6 +20,7 @@ namespace ngraph { namespace builder { namespace subgraph { +namespace multiply_function { struct BranchNodes { std::shared_ptr input; std::shared_ptr dequantization; @@ -34,10 +37,11 @@ BranchNodes makeBranch(const MultiplyBranch& branch) { const auto dequantization = makeDequantization(parent, branch.dequantization); return {parent, dequantization}; } +} // namespace multiply_function std::shared_ptr MultiplyFunction::get(const element::Type model_precision, const MultiplyValues& actualValues) { - const BranchNodes branchNodes1 = makeBranch(actualValues.branch1); - const BranchNodes branchNodes2 = makeBranch(actualValues.branch2); + const auto branchNodes1 = multiply_function::makeBranch(actualValues.branch1); + const auto branchNodes2 = multiply_function::makeBranch(actualValues.branch2); // branchNodes1.dequantization & branchNodes2.dequantization can have different input types std::shared_ptr parent = std::make_shared>( diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp index 8e37a0dff00169..e41d340a634d61 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp @@ -4,6 +4,8 @@ #include "lpt_ngraph_functions/multiply_partial_function.hpp" +#include + #include #include #include "ngraph_functions/subgraph_builders.hpp" @@ -18,6 +20,7 @@ namespace ngraph { namespace builder { namespace subgraph { +namespace multiply_partial_function { struct BranchNodes { std::shared_ptr input; std::shared_ptr dequantization; @@ -34,6 +37,7 @@ BranchNodes getBranch(const MultiplyPartialBranch& branch) { const auto dequantization = makeDequantization(parent, branch.dequantization); return {parent, dequantization}; } +} // namespace multiply_partial_function std::shared_ptr MultiplyPartialFunction::get( const element::Type precision, @@ -45,8 +49,8 @@ std::shared_ptr MultiplyPartialFunction::get( branch2Structure.precisionBeforeDequantization = precision; branch2Structure.dequantization.multiply.outPrecision = precision; - const BranchNodes branchNodes1 = getBranch(actualValues.branch1); - const BranchNodes branchNodes2 = getBranch(actualValues.branch2); + const auto branchNodes1 = multiply_partial_function::getBranch(actualValues.branch1); + const auto branchNodes2 = multiply_partial_function::getBranch(actualValues.branch2); auto multiplyOriginal = opset1::Multiply( ov::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),