Skip to content

Commit

Permalink
Add a check for negative axis attr out of bounds (#4543)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bartosz Lesniewski authored Mar 15, 2021
1 parent 24fb09e commit 635ffc7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
8 changes: 6 additions & 2 deletions ngraph/core/src/op/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ void op::Concat::validate_and_infer_types()
}
auto concat_axis = get_concatenation_axis();
NODE_VALIDATION_CHECK(this,
concat_axis < this_input_rank.get_length(),
concat_axis < this_input_rank.get_length() && concat_axis >= 0,
"Concatenation axis (",
concat_axis,
") is out of bounds for ",
") is out of bounds [",
-this_input_rank.get_length(),
", ",
this_input_rank.get_length() - 1,
"] for ",
"argument ",
i,
", which has shape ",
Expand Down
34 changes: 34 additions & 0 deletions ngraph/test/type_prop/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,37 @@ TEST(type_prop, concat_partial_all_static_with_concat_axis_static_dims_incompati
FAIL() << "Deduced type check failed for unexpected reason";
}
}

TEST(type_prop, concat_partial_negative_axis_correct)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{3, 2, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{7, 2, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});

auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -3);

ASSERT_EQ(c->get_element_type(), element::f32);
ASSERT_EQ(c->get_shape(), (Shape{12, 2, 4}));
}

TEST(type_prop, concat_partial_negative_axis_incorrect)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 2, 4});

try
{
auto c = make_shared<op::Concat>(NodeVector{param0, param1, param2}, -4);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect negative axis value not detected (out of bounds)";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (-1) is out of bounds"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}

0 comments on commit 635ffc7

Please sign in to comment.