Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 12, 2024
1 parent 48526d8 commit cb3a247
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
22 changes: 16 additions & 6 deletions src/common/low_precision_transformations/src/reduce_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,24 @@ void ReduceSumTransformation::changeDequantizationValues(
}

// (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s
const auto reductionSizeConstant = ov::opset1::Constant::create(
dequantization.subtractConstant->get_element_type(),
Shape{},
{ static_cast<float>(reductionSize) });
const auto result = fold<ov::opset1::Multiply>(dequantization.subtractConstant, reductionSizeConstant);
const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast<float>(reductionSize) });
assert(deqPrecision == dequantization.subtract->get_input_element_type(0));
const auto result = fold<ov::opset1::Multiply>(
dequantization.subtractConstant->get_element_type() == deqPrecision ?
dequantization.subtractConstant :
std::dynamic_pointer_cast<ov::Node>(foldConvert(dequantization.subtractConstant, deqPrecision)),
reductionSizeConstant);

replace_node(
dequantization.subtractConvert != nullptr ?
std::dynamic_pointer_cast<ov::Node>(dequantization.subtractConvert) :
dequantization.subtractConstant,
result);

replace_node(dequantization.subtractConstant, result);
dequantization.subtractConstant = ov::as_type_ptr<ov::opset1::Constant>(result);
if (dequantization.subtractConvert != nullptr) {
dequantization.subtractConvert = nullptr;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ using namespace ov::builder::subgraph;
class ReduceSumTransformation : public ReduceTransformation<ov::op::v1::ReduceSum> {
void SetUp() override {
ReduceTransformation::SetUp();

const auto transformationParams = std::get<1>(GetParam()).params;

SimpleLowPrecisionTransformer transform;
Expand Down Expand Up @@ -352,7 +351,7 @@ namespace testValues3 {
ov::element::u8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
Expand All @@ -362,7 +361,7 @@ namespace testValues3 {
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
Expand All @@ -375,7 +374,7 @@ namespace testValues3 {
ov::element::i8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
Expand All @@ -385,7 +384,7 @@ namespace testValues3 {
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
Expand Down

0 comments on commit cb3a247

Please sign in to comment.