From e31bcf54a5de2beeaa415df3734c9acd912745b3 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 13 Jun 2024 11:56:06 +0100 Subject: [PATCH] [LPT] ReduceSum: zero point support (#24977) ### Details: - *[LPT] ReduceSum: zero point support* ### Tickets: - *CVS-142256* --- .../src/reduce_sum.cpp | 15 ++++- .../tests/reduce_sum_transformation.cpp | 64 +++++++++++++++++++ src/core/src/op/multiply.cpp | 2 +- 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/src/common/low_precision_transformations/src/reduce_sum.cpp b/src/common/low_precision_transformations/src/reduce_sum.cpp index 63534ddef290e9..155d64b508043e 100644 --- a/src/common/low_precision_transformations/src/reduce_sum.cpp +++ b/src/common/low_precision_transformations/src/reduce_sum.cpp @@ -69,10 +69,21 @@ void ReduceSumTransformation::changeDequantizationValues( // (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast(reductionSize) }); - const auto result = fold(dequantization.subtractConstant, reductionSizeConstant); + OPENVINO_ASSERT(deqPrecision == dequantization.subtract->get_input_element_type(0), + "dequantization precision ", deqPrecision, + " differs from zero point 0 input ", dequantization.subtract->get_input_element_type(0)); + const auto result = fold( + foldConvert(dequantization.subtractConstant, deqPrecision), + reductionSizeConstant); + + replace_node( + dequantization.subtractConvert != nullptr ? + std::dynamic_pointer_cast(dequantization.subtractConvert) : + dequantization.subtractConstant, + result); - replace_node(dequantization.subtractConstant, result); dequantization.subtractConstant = ov::as_type_ptr(result); + dequantization.subtractConvert = nullptr; } } diff --git a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp index 8faa9a0fdd6074..6f306d24db746e 100644 --- a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp +++ b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp @@ -336,4 +336,68 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(reduceSumTransformationTestValues)), ReduceSumTransformation::getTestCaseName); } // namespace testValues2 + +namespace testValues3 { + const std::vector inputShapes = { + {4, 3, 16, 16} + }; + + const std::vector reduceSumTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + {0}, + true, + { + ov::element::u8, + { + {ov::element::f32}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + }, + { + ov::element::u8, + {}, + ov::element::f32, + { + {}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + } + }, + { + LayerTransformation::createParamsU8I8(), + {0}, + true, + { + ov::element::i8, + { + {ov::element::f32}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + }, + { + ov::element::i8, + {}, + ov::element::f32, + { + {}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + } + } + }; + + INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + ReduceSumTransformation, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(reduceSumTransformationTestValues)), + ReduceSumTransformation::getTestCaseName); +} // namespace testValues3 + } // namespace diff --git a/src/core/src/op/multiply.cpp b/src/core/src/op/multiply.cpp index fa3ef518c03202..88dbd347d46edf 100644 --- a/src/core/src/op/multiply.cpp +++ b/src/core/src/op/multiply.cpp @@ -51,7 +51,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const this, outputs, inputs, - OV_PP_ET_LIST(f32, f64, i32, i64, u32, u64), + OV_PP_ET_LIST(f32, f64, i8, i32, i64, u8, u32, u64), multiply::Evaluate, inputs[0].get_element_type(), inputs[0],