Skip to content

Commit

Permalink
[LPT] ReduceSum: zero point support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 12, 2024
1 parent 107869e commit 48526d8
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/common/low_precision_transformations/src/reduce_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ void ReduceSumTransformation::changeDequantizationValues(
}

// (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s
const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast<float>(reductionSize) });
const auto reductionSizeConstant = ov::opset1::Constant::create(
dequantization.subtractConstant->get_element_type(),
Shape{},
{ static_cast<float>(reductionSize) });
const auto result = fold<ov::opset1::Multiply>(dequantization.subtractConstant, reductionSizeConstant);

replace_node(dequantization.subtractConstant, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using namespace ov::builder::subgraph;
class ReduceSumTransformation : public ReduceTransformation<ov::op::v1::ReduceSum> {
void SetUp() override {
ReduceTransformation::SetUp();

const auto transformationParams = std::get<1>(GetParam()).params;

SimpleLowPrecisionTransformer transform;
Expand Down Expand Up @@ -336,4 +337,68 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(reduceSumTransformationTestValues)),
ReduceSumTransformation::getTestCaseName);
} // namespace testValues2

namespace testValues3 {
const std::vector<ov::PartialShape> inputShapes = {
{4, 3, 16, 16}
};

const std::vector<ReduceTransformationTestValues> reduceSumTransformationTestValues = {
{
LayerTransformation::createParamsU8I8(),
{0},
true,
{
ov::element::u8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
{
ov::element::u8,
{},
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
},
{
LayerTransformation::createParamsU8I8(),
{0},
true,
{
ov::element::i8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
{
ov::element::i8,
{},
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
}
};

INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
ReduceSumTransformation,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(reduceSumTransformationTestValues)),
ReduceSumTransformation::getTestCaseName);
} // namespace testValues3

} // namespace
2 changes: 1 addition & 1 deletion src/core/src/op/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const
this,
outputs,
inputs,
OV_PP_ET_LIST(f32, f64, i32, i64, u32, u64),
OV_PP_ET_LIST(f32, f64, i8, i32, i64, u8, u32, u64),
multiply::Evaluate,
inputs[0].get_element_type(),
inputs[0],
Expand Down

0 comments on commit 48526d8

Please sign in to comment.