diff --git a/src/common/low_precision_transformations/src/multiply.cpp b/src/common/low_precision_transformations/src/multiply.cpp index abf14224657655..bebd8970a79333 100644 --- a/src/common/low_precision_transformations/src/multiply.cpp +++ b/src/common/low_precision_transformations/src/multiply.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -44,7 +44,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass: return false; } - // TODO: normalizeDequantization + fold_fake_quantizes + foldDequantization <= ??? NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); diff --git a/src/common/low_precision_transformations/tests/multiply_transformation.cpp b/src/common/low_precision_transformations/tests/multiply_transformation.cpp index 162163477590ab..aa8bc894a0e55a 100644 --- a/src/common/low_precision_transformations/tests/multiply_transformation.cpp +++ b/src/common/low_precision_transformations/tests/multiply_transformation.cpp @@ -20,6 +20,8 @@ #include "simple_low_precision_transformer.hpp" #include "lpt_ngraph_functions/multiply_function.hpp" +#include "ngraph/pass/serialize.hpp" + namespace { using namespace testing; using namespace ov; @@ -109,13 +111,22 @@ class MultiplyTransformation : public LayerTransformation, public testing::WithP model_precision, to_multiply_values(testParams.actual)); + ngraph::pass::Serialize("svg/test.original.xml", "svg/test.original.bin").run_on_model(actualFunction); + ngraph::pass::VisualizeTree("svg/test.original.svg").run_on_model(actualFunction); + SimpleLowPrecisionTransformer transform; transform.add(testParams.transformationParams); transform.transform(actualFunction); + ngraph::pass::Serialize("svg/test.actual.xml", "svg/test.actual.bin").run_on_model(actualFunction); + ngraph::pass::VisualizeTree("svg/test.actual.svg").run_on_model(actualFunction); + referenceFunction = MultiplyFunction::get( model_precision, to_multiply_values(testParams.expected)); + + ngraph::pass::Serialize("svg/test.reference.xml", "svg/test.reference.bin").run_on_model(referenceFunction); + ngraph::pass::VisualizeTree("svg/test.reference.svg").run_on_model(referenceFunction); } static std::string getTestCaseName(testing::TestParamInfo obj) { @@ -192,7 +203,7 @@ const std::vector> input_precisi { ov::element::i8, ov::element::u8 }, }; -// PartialShape inputShape; +namespace broadcast_no { const std::vector> input_shapes = { {{ 1, 3, 8, 16 }, { 1, 3, 8, 16 }}, {{ 1, 3, 8, 16 }, { 1, 3, 1, 1 }}, @@ -363,4 +374,332 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(input_precisions), ::testing::ValuesIn(multiplyTransformationTestValues)), MultiplyTransformation::getTestCaseName); -} // namespace +} // namespace broadcast_no + +namespace broadcast_right { +const std::vector> input_shapes = { + {{ 1, 3, 8, 16 }, { 1, 1, 1, 1 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 1.f, 2.f, 3.f }}, {{ 10.f, 11.f, 12.f }}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 1.f, 2.f, 3.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 77.f, 84.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_right + +namespace broadcast_left { +const std::vector> input_shapes = { + {{ 1, 1, 1, 1 }, { 1, 3, 8, 16 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 3.f, 4.f, 5.f }}, {{ 7.f, 8.f, 9.f }}} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f, 4.f, 5.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 80.f, 90.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_left + +} // namespace \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp index 26846b5f97cb62..2088d4db87696a 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp @@ -11,7 +11,7 @@ #include #include -#include "lpt_ngraph_functions/multiply_function.hpp" +#include "lpt_ngraph_functions/multiply_partial_function.hpp" #include "ngraph_functions/subgraph_builders.hpp" @@ -56,7 +56,7 @@ void MultiplyTransformation::SetUp() { MultiplyTestValues param; std::tie(precision, inputShape, targetDevice, param) = this->GetParam(); - function = ngraph::builder::subgraph::MultiplyFunction::getOriginal( + function = ngraph::builder::subgraph::MultiplyPartialFunction::get( precision, inputShape, param.broadcast1, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp index 0c74ed64adadc8..f92f4fcd3fd31a 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp @@ -4,6 +4,8 @@ #include "lpt_ngraph_functions/multiply_function.hpp" +#include + #include #include #include "ngraph_functions/subgraph_builders.hpp" @@ -18,6 +20,7 @@ namespace ngraph { namespace builder { namespace subgraph { +namespace multiply_function { struct BranchNodes { std::shared_ptr input; std::shared_ptr dequantization; @@ -34,6 +37,9 @@ BranchNodes makeBranch(const MultiplyBranch& branch) { const auto dequantization = makeDequantization(parent, branch.dequantization); return {parent, dequantization}; } +} // namespace multiply_function + +using namespace multiply_function; std::shared_ptr MultiplyFunction::get(const element::Type model_precision, const MultiplyValues& actualValues) { const BranchNodes branchNodes1 = makeBranch(actualValues.branch1); diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp index 8e37a0dff00169..d11faaa44a1709 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp @@ -4,6 +4,8 @@ #include "lpt_ngraph_functions/multiply_partial_function.hpp" +#include + #include #include #include "ngraph_functions/subgraph_builders.hpp" @@ -18,6 +20,7 @@ namespace ngraph { namespace builder { namespace subgraph { +namespace multiply_partial_function { struct BranchNodes { std::shared_ptr input; std::shared_ptr dequantization; @@ -34,6 +37,9 @@ BranchNodes getBranch(const MultiplyPartialBranch& branch) { const auto dequantization = makeDequantization(parent, branch.dequantization); return {parent, dequantization}; } +} // namespace multiply_partial_function + +using namespace multiply_partial_function; std::shared_ptr MultiplyPartialFunction::get( const element::Type precision,