Skip to content

Commit

Permalink
[LPT] ReduceSum: zero point support (openvinotoolkit#24977)
Browse files Browse the repository at this point in the history
### Details:
 - *[LPT] ReduceSum: zero point support*

### Tickets:
 - *CVS-142256*
  • Loading branch information
eshoguli authored and allnes committed Jun 26, 2024
1 parent 5428336 commit e31bcf5
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
15 changes: 13 additions & 2 deletions src/common/low_precision_transformations/src/reduce_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,21 @@ void ReduceSumTransformation::changeDequantizationValues(

// (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s
const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast<float>(reductionSize) });
const auto result = fold<ov::opset1::Multiply>(dequantization.subtractConstant, reductionSizeConstant);
OPENVINO_ASSERT(deqPrecision == dequantization.subtract->get_input_element_type(0),
"dequantization precision ", deqPrecision,
" differs from zero point 0 input ", dequantization.subtract->get_input_element_type(0));
const auto result = fold<ov::opset1::Multiply>(
foldConvert(dequantization.subtractConstant, deqPrecision),
reductionSizeConstant);

replace_node(
dequantization.subtractConvert != nullptr ?
std::dynamic_pointer_cast<ov::Node>(dequantization.subtractConvert) :
dequantization.subtractConstant,
result);

replace_node(dequantization.subtractConstant, result);
dequantization.subtractConstant = ov::as_type_ptr<ov::opset1::Constant>(result);
dequantization.subtractConvert = nullptr;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,68 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(reduceSumTransformationTestValues)),
ReduceSumTransformation::getTestCaseName);
} // namespace testValues2

namespace testValues3 {
const std::vector<ov::PartialShape> inputShapes = {
{4, 3, 16, 16}
};

const std::vector<ReduceTransformationTestValues> reduceSumTransformationTestValues = {
{
LayerTransformation::createParamsU8I8(),
{0},
true,
{
ov::element::u8,
{
{ov::element::f32},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
{
ov::element::u8,
{},
ov::element::f32,
{
{},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
},
{
LayerTransformation::createParamsU8I8(),
{0},
true,
{
ov::element::i8,
{
{ov::element::f32},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
{
ov::element::i8,
{},
ov::element::f32,
{
{},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
}
};

INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
ReduceSumTransformation,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(reduceSumTransformationTestValues)),
ReduceSumTransformation::getTestCaseName);
} // namespace testValues3

} // namespace
2 changes: 1 addition & 1 deletion src/core/src/op/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const
this,
outputs,
inputs,
OV_PP_ET_LIST(f32, f64, i32, i64, u32, u64),
OV_PP_ET_LIST(f32, f64, i8, i32, i64, u8, u32, u64),
multiply::Evaluate,
inputs[0].get_element_type(),
inputs[0],
Expand Down

0 comments on commit e31bcf5

Please sign in to comment.