Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 12, 2024
1 parent 48526d8 commit 56f4c52
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
22 changes: 16 additions & 6 deletions src/common/low_precision_transformations/src/reduce_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,24 @@ void ReduceSumTransformation::changeDequantizationValues(
}

// (a1 - s) + (a2 - s) + ... + (an - s) = (a1 + a2 + ... + an) - n * s
const auto reductionSizeConstant = ov::opset1::Constant::create(
dequantization.subtractConstant->get_element_type(),
Shape{},
{ static_cast<float>(reductionSize) });
const auto result = fold<ov::opset1::Multiply>(dequantization.subtractConstant, reductionSizeConstant);
const auto reductionSizeConstant = ov::opset1::Constant::create(deqPrecision, Shape{}, { static_cast<float>(reductionSize) });
assert(deqPrecision == dequantization.subtract->get_input_element_type(0));
const auto result = fold<ov::opset1::Multiply>(
dequantization.subtractConstant->get_element_type() == deqPrecision ?
dequantization.subtractConstant :
std::dynamic_pointer_cast<ov::Node>(foldConvert(dequantization.subtractConstant, deqPrecision)),
reductionSizeConstant);

replace_node(
dequantization.subtractConvert != nullptr ?
std::dynamic_pointer_cast<ov::Node>(dequantization.subtractConvert) :
dequantization.subtractConstant,
result);

replace_node(dequantization.subtractConstant, result);
dequantization.subtractConstant = ov::as_type_ptr<ov::opset1::Constant>(result);
if (dequantization.subtractConvert != nullptr) {
dequantization.subtractConvert = nullptr;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,21 @@ class ReduceSumTransformation : public ReduceTransformation<ov::op::v1::ReduceSu
void SetUp() override {
ReduceTransformation::SetUp();

ov::pass::Serialize("test.original.xml", "test.original.bin").run_on_model(actualFunction);

const auto transformationParams = std::get<1>(GetParam()).params;

SimpleLowPrecisionTransformer transform;
transform.add<ov::pass::low_precision::ReduceSumTransformation, ov::op::v1::ReduceSum>(transformationParams);
transform.transform(actualFunction);

ov::pass::Serialize("test.actual.xml", "test.actual.bin").run_on_model(actualFunction);
}
};

TEST_P(ReduceSumTransformation, CompareFunctions) {
ov::pass::Serialize("test.reference.xml", "test.reference.bin").run_on_model(referenceFunction);

actualFunction->validate_nodes_and_infer_types();
auto res = compare_functions(actualFunction, referenceFunction, true, true, false);
ASSERT_TRUE(res.first) << res.second;
Expand Down Expand Up @@ -352,7 +358,7 @@ namespace testValues3 {
ov::element::u8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
Expand All @@ -362,7 +368,7 @@ namespace testValues3 {
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::u8, true},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
Expand All @@ -375,7 +381,7 @@ namespace testValues3 {
ov::element::i8,
{
{ov::element::f32},
{{2.f, 4.f, 6.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{40.f, 80.f, 120.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
},
Expand All @@ -385,7 +391,7 @@ namespace testValues3 {
ov::element::f32,
{
{},
{{8.f, 16.f, 24.f}, ov::element::f32, {1, 3, 1, 1}, false, 1ul, ov::element::i8, true},
{{160.f, 320.f, 480.f}, ov::element::f32, {1, 3, 1, 1}},
{{0.1f, 1.f, 10.f}, ov::element::f32, {1, 3, 1, 1}}
}
}
Expand Down

0 comments on commit 56f4c52

Please sign in to comment.