Skip to content

Commit

Permalink
AlignEltwiseInputRanks - fix when output rank is less than constant r…
Browse files Browse the repository at this point in the history
…ank (openvinotoolkit#17895)

Fixes an issue when AlignEltwiseInputRanks is applied on FakeQuantize with
scalar as a first input and input/output low/high being Shape{1} constants.
In such case FakeQuantize output is still a scalar, so the difference
between output rank and input/output low/high rank is negative.

Ticket: CVS-112454
  • Loading branch information
mateusztabaka authored and alvoron committed Jun 21, 2023
1 parent 34debea commit e533f0e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ ov::pass::AlignEltwiseInputRanks::AlignEltwiseInputRanks() {
return false;
}

const auto rank = node->get_output_partial_shape(0).size();
const auto rank = static_cast<int64_t>(node->get_output_partial_shape(0).size());

for (size_t i = 0; i < node->get_input_size(); i++) {
auto const_node = as_type<opset8::Constant>(node->get_input_node_ptr(i));
if (const_node == nullptr)
continue;
const auto& const_shape = const_node->get_shape();
auto diff = rank - const_shape.size();
auto diff = rank - static_cast<int64_t>(const_shape.size());
if (diff > 0) {
Shape new_shape = const_shape;
new_shape.insert(new_shape.begin(), diff, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using namespace ngraph;

using AlignEltwiseInputRanksParams = std::tuple<PartialShape, Shape, Shape, bool>;

class AlignEltwiseInputRanksTest : public testing::WithParamInterface<AlignEltwiseInputRanksParams>,
public TransformationTestsF {};
class AlignEltwiseInputRanksTestP : public testing::WithParamInterface<AlignEltwiseInputRanksParams>,
public TransformationTestsF {};

TEST_P(AlignEltwiseInputRanksTest, FusionTest) {
TEST_P(AlignEltwiseInputRanksTestP, FusionTest) {
auto params = GetParam();
const auto& input_shape = std::get<0>(params);
auto const_shape = std::get<1>(params);
Expand Down Expand Up @@ -77,4 +77,20 @@ static std::vector<AlignEltwiseInputRanksParams> params = {
AlignEltwiseInputRanksParams(Shape{}, {2, 3, 4}, {}, false),
};

INSTANTIATE_TEST_SUITE_P(TransformationTests, AlignEltwiseInputRanksTest, ::testing::ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(TransformationTests, AlignEltwiseInputRanksTestP, ::testing::ValuesIn(params));

class AlignEltwiseInputRanksTestF : public TransformationTestsF {};

TEST_F(AlignEltwiseInputRanksTestF, NegativeFakeQuantizeWithScalarFirstInput) {
{
auto data = op::Constant::create(element::f32, Shape{}, {10});
auto low = op::Constant::create(element::f32, Shape{1}, {0});
auto high = op::Constant::create(element::f32, Shape{1}, {20});
auto fq = std::make_shared<opset8::FakeQuantize>(data, low, high, low, high, 256);
function = std::make_shared<Function>(fq->outputs());

manager.register_pass<ov::pass::AlignEltwiseInputRanks>();
}

comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

0 comments on commit e533f0e

Please sign in to comment.