From 48526d837f420d562808f793a7c644e1d9b28f68 Mon Sep 17 00:00:00 2001 From: eshoguli Date: Wed, 12 Jun 2024 13:53:56 +0100 Subject: [PATCH 1/2] [LPT] ReduceSum: zero point support --- .../src/reduce_sum.cpp | 5 +- .../tests/reduce_sum_transformation.cpp | 65 +++++++++++++++++++ src/core/src/op/multiply.cpp | 2 +- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/src/common/low_precision_transformations/src/reduce_sum.cpp b/src/common/low_precision_transformations/src/reduce_sum.cpp index 63534ddef290e9..cab168e55631a3 100644 --- a/src/common/low_precision_transformations/src/reduce_sum.cpp +++ b/src/common/low_precision_transformations/src/reduce_sum.cpp @@ -68,7 +68,10 @@ void ReduceSumTransformation::changeDequantizationValues( } // (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s - const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast(reductionSize) }); + const auto reductionSizeConstant = ov::opset1::Constant::create( + dequantization.subtractConstant->get_element_type(), + Shape{}, + { static_cast(reductionSize) }); const auto result = fold(dequantization.subtractConstant, reductionSizeConstant); replace_node(dequantization.subtractConstant, result); diff --git a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp index 8faa9a0fdd6074..5bb09b8984fce7 100644 --- a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp +++ b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp @@ -30,6 +30,7 @@ using namespace ov::builder::subgraph; class ReduceSumTransformation : public ReduceTransformation { void SetUp() override { ReduceTransformation::SetUp(); + const auto transformationParams = std::get<1>(GetParam()).params; SimpleLowPrecisionTransformer transform; @@ -336,4 +337,68 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(reduceSumTransformationTestValues)), ReduceSumTransformation::getTestCaseName); } // namespace testValues2 + +namespace testValues3 { + const std::vector inputShapes = { + {4, 3, 16, 16} + }; + + const std::vector reduceSumTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + {0}, + true, + { + ov::element::u8, + { + {ov::element::f32}, + {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + }, + { + ov::element::u8, + {}, + ov::element::f32, + { + {}, + {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + } + }, + { + LayerTransformation::createParamsU8I8(), + {0}, + true, + { + ov::element::i8, + { + {ov::element::f32}, + {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + }, + { + ov::element::i8, + {}, + ov::element::f32, + { + {}, + {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} + } + } + } + }; + + INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + ReduceSumTransformation, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(reduceSumTransformationTestValues)), + ReduceSumTransformation::getTestCaseName); +} // namespace testValues3 + } // namespace diff --git a/src/core/src/op/multiply.cpp b/src/core/src/op/multiply.cpp index fa3ef518c03202..88dbd347d46edf 100644 --- a/src/core/src/op/multiply.cpp +++ b/src/core/src/op/multiply.cpp @@ -51,7 +51,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const this, outputs, inputs, - OV_PP_ET_LIST(f32, f64, i32, i64, u32, u64), + OV_PP_ET_LIST(f32, f64, i8, i32, i64, u8, u32, u64), multiply::Evaluate, inputs[0].get_element_type(), inputs[0], From 4f443071d08392f2fe6c3f453c339f0427c682b4 Mon Sep 17 00:00:00 2001 From: eshoguli Date: Wed, 12 Jun 2024 16:44:13 +0100 Subject: [PATCH 2/2] comments --- .../src/reduce_sum.cpp | 20 +++++++++++++------ .../tests/reduce_sum_transformation.cpp | 9 ++++----- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/common/low_precision_transformations/src/reduce_sum.cpp b/src/common/low_precision_transformations/src/reduce_sum.cpp index cab168e55631a3..155d64b508043e 100644 --- a/src/common/low_precision_transformations/src/reduce_sum.cpp +++ b/src/common/low_precision_transformations/src/reduce_sum.cpp @@ -68,14 +68,22 @@ void ReduceSumTransformation::changeDequantizationValues( } // (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s - const auto reductionSizeConstant = ov::opset1::Constant::create( - dequantization.subtractConstant->get_element_type(), - Shape{}, - { static_cast(reductionSize) }); - const auto result = fold(dequantization.subtractConstant, reductionSizeConstant); + const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast(reductionSize) }); + OPENVINO_ASSERT(deqPrecision == dequantization.subtract->get_input_element_type(0), + "dequantization precision ", deqPrecision, + " differs from zero point 0 input ", dequantization.subtract->get_input_element_type(0)); + const auto result = fold( + foldConvert(dequantization.subtractConstant, deqPrecision), + reductionSizeConstant); + + replace_node( + dequantization.subtractConvert != nullptr ? + std::dynamic_pointer_cast(dequantization.subtractConvert) : + dequantization.subtractConstant, + result); - replace_node(dequantization.subtractConstant, result); dequantization.subtractConstant = ov::as_type_ptr(result); + dequantization.subtractConvert = nullptr; } } diff --git a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp index 5bb09b8984fce7..6f306d24db746e 100644 --- a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp +++ b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp @@ -30,7 +30,6 @@ using namespace ov::builder::subgraph; class ReduceSumTransformation : public ReduceTransformation { void SetUp() override { ReduceTransformation::SetUp(); - const auto transformationParams = std::get<1>(GetParam()).params; SimpleLowPrecisionTransformer transform; @@ -352,7 +351,7 @@ namespace testValues3 { ov::element::u8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -362,7 +361,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } } @@ -375,7 +374,7 @@ namespace testValues3 { ov::element::i8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -385,7 +384,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }