From cb3a247967f061e99297f028aca5f280761b12e5 Mon Sep 17 00:00:00 2001 From: eshoguli Date: Wed, 12 Jun 2024 16:44:13 +0100 Subject: [PATCH] comments --- .../src/reduce_sum.cpp | 22 ++++++++++++++----- .../tests/reduce_sum_transformation.cpp | 9 ++++---- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/common/low_precision_transformations/src/reduce_sum.cpp b/src/common/low_precision_transformations/src/reduce_sum.cpp index cab168e55631a3..5e6e484a895ee0 100644 --- a/src/common/low_precision_transformations/src/reduce_sum.cpp +++ b/src/common/low_precision_transformations/src/reduce_sum.cpp @@ -68,14 +68,24 @@ void ReduceSumTransformation::changeDequantizationValues( } // (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s - const auto reductionSizeConstant = ov::opset1::Constant::create( - dequantization.subtractConstant->get_element_type(), - Shape{}, - { static_cast(reductionSize) }); - const auto result = fold(dequantization.subtractConstant, reductionSizeConstant); + const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast(reductionSize) }); + assert(deqPrecision == dequantization.subtract->get_input_element_type(0)); + const auto result = fold( + dequantization.subtractConstant->get_element_type() == deqPrecision ? + dequantization.subtractConstant : + std::dynamic_pointer_cast(foldConvert(dequantization.subtractConstant, deqPrecision)), + reductionSizeConstant); + + replace_node( + dequantization.subtractConvert != nullptr ? + std::dynamic_pointer_cast(dequantization.subtractConvert) : + dequantization.subtractConstant, + result); - replace_node(dequantization.subtractConstant, result); dequantization.subtractConstant = ov::as_type_ptr(result); + if (dequantization.subtractConvert != nullptr) { + dequantization.subtractConvert = nullptr; + } } } diff --git a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp index 5bb09b8984fce7..6f306d24db746e 100644 --- a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp +++ b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp @@ -30,7 +30,6 @@ using namespace ov::builder::subgraph; class ReduceSumTransformation : public ReduceTransformation { void SetUp() override { ReduceTransformation::SetUp(); - const auto transformationParams = std::get<1>(GetParam()).params; SimpleLowPrecisionTransformer transform; @@ -352,7 +351,7 @@ namespace testValues3 { ov::element::u8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -362,7 +361,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } } @@ -375,7 +374,7 @@ namespace testValues3 { ov::element::i8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -385,7 +384,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }