diff --git a/src/common/low_precision_transformations/src/reduce_sum.cpp b/src/common/low_precision_transformations/src/reduce_sum.cpp index cab168e55631a3..5e6e484a895ee0 100644 --- a/src/common/low_precision_transformations/src/reduce_sum.cpp +++ b/src/common/low_precision_transformations/src/reduce_sum.cpp @@ -68,14 +68,24 @@ void ReduceSumTransformation::changeDequantizationValues( } // (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s - const auto reductionSizeConstant = ov::opset1::Constant::create( - dequantization.subtractConstant->get_element_type(), - Shape{}, - { static_cast(reductionSize) }); - const auto result = fold(dequantization.subtractConstant, reductionSizeConstant); + const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast(reductionSize) }); + assert(deqPrecision == dequantization.subtract->get_input_element_type(0)); + const auto result = fold( + dequantization.subtractConstant->get_element_type() == deqPrecision ? + dequantization.subtractConstant : + std::dynamic_pointer_cast(foldConvert(dequantization.subtractConstant, deqPrecision)), + reductionSizeConstant); + + replace_node( + dequantization.subtractConvert != nullptr ? + std::dynamic_pointer_cast(dequantization.subtractConvert) : + dequantization.subtractConstant, + result); - replace_node(dequantization.subtractConstant, result); dequantization.subtractConstant = ov::as_type_ptr(result); + if (dequantization.subtractConvert != nullptr) { + dequantization.subtractConvert = nullptr; + } } } diff --git a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp index 5bb09b8984fce7..a9759330a9151e 100644 --- a/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp +++ b/src/common/low_precision_transformations/tests/reduce_sum_transformation.cpp @@ -31,15 +31,21 @@ class ReduceSumTransformation : public ReduceTransformation(GetParam()).params; SimpleLowPrecisionTransformer transform; transform.add(transformationParams); transform.transform(actualFunction); + + ov::pass::Serialize("test.actual.xml", "test.actual.bin").run_on_model(actualFunction); } }; TEST_P(ReduceSumTransformation, CompareFunctions) { + ov::pass::Serialize("test.reference.xml", "test.reference.bin").run_on_model(referenceFunction); + actualFunction->validate_nodes_and_infer_types(); auto res = compare_functions(actualFunction, referenceFunction, true, true, false); ASSERT_TRUE(res.first) << res.second; @@ -352,7 +358,7 @@ namespace testValues3 { ov::element::u8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -362,7 +368,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } } @@ -375,7 +381,7 @@ namespace testValues3 { ov::element::i8, { {ov::element::f32}, - {{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }, @@ -385,7 +391,7 @@ namespace testValues3 { ov::element::f32, { {}, - {{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true}, + {{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}}, {{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}} } }