Skip to content

Commit

Permalink
added correct handling of dynamic shapes in nGraph, added unit-tests …
Browse files Browse the repository at this point in the history
…for dynamic cases, fixed dump typos in MO, replaced axis type from int -> int64_t
  • Loading branch information
pavel-esir committed Dec 4, 2020
1 parent 5f6bc73 commit e692101
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class GatherElementsFrontExtractor(FrontExtractorOp):
op = 'GatherElements'
enabled = False
enabled = True

@classmethod
def extract(cls, node):
Expand Down
5 changes: 3 additions & 2 deletions model-optimizer/extensions/ops/gatherelements.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

class GatherElements(Op):
op = 'GatherElements'
enabled = True

def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
Expand All @@ -30,13 +29,15 @@ def __init__(self, graph: Graph, attrs: dict):
'infer': self.infer,
'in_ports_count': 2,
'out_ports_count': 1,
'axis': 0,
}, attrs)

def backend_attrs(self):
return ['axis']

@staticmethod
def infer(node: Node):
node.in_port(1).data.get_shape()
indices_shape = node.in_port(1).data.get_shape()
node.out_port(0).data.set_shape(indices_shape)

# todo: add value_inference
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/op/gather_elements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace ngraph
/// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const int axis);
const int64_t axis);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
Expand All @@ -48,7 +48,7 @@ namespace ngraph

size_t get_axis() const { return m_axis; }
private:
int m_axis;
int64_t m_axis;
};
}
}
Expand Down
53 changes: 33 additions & 20 deletions ngraph/core/src/op/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ NGRAPH_RTTI_DEFINITION(op::v6::GatherElements, "GatherElements", 6);

op::v6::GatherElements::GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const int axis)
const int64_t axis)
: Op({data, indices})
, m_axis(axis)
{
Expand All @@ -53,7 +53,7 @@ void op::v6::GatherElements::validate_and_infer_types()

if (data_rank.is_static())
{
int data_rank_size = data_rank.get_length();
int64_t data_rank_size = data_rank.get_length();

NODE_VALIDATION_CHECK(this, data_rank_size > 1, "data rank must be greater than 1.");

Expand All @@ -74,7 +74,7 @@ void op::v6::GatherElements::validate_and_infer_types()
this, indices_rank.get_length() > 1, "indices rank must be greater than 1.");
}

if (data_rank.is_static() && indices_rank.is_static())
if (indices_rank.is_static() && data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
data_rank.get_length() == indices_rank.get_length(),
Expand All @@ -83,28 +83,41 @@ void op::v6::GatherElements::validate_and_infer_types()
" and ",
indices_rank.get_length());

if (data_pshape.is_static() && indices_pshape.is_static())
PartialShape out_shape_info(indices_pshape);

for (int i = 0; i < indices_rank.get_length(); i++)
{
// check if PartialShapes of data and indices are consistent
for (int i = 0; i < data_rank.get_length(); i++)
if (i != m_axis && data_pshape[i].is_static() && indices_pshape[i].is_static())
{
NODE_VALIDATION_CHECK(
this,
data_pshape[i] == indices_pshape[i],
"Sizes ",
data_pshape[i],
" and ",
indices_pshape[i],
" on axis ",
i,
" do not match. data and indices must have equal shapes except for axis ",
m_axis);
}
// if indices.size of the current axis is unknown try to retrieve it from data
// if data_shape = {4, 4, ?} indices_shape = {1, ?, 5} and axis = 0
// make optimistic assumptions that data size along 2nd dim is also 5
else if (indices_pshape[i].is_dynamic() && data_pshape[i].is_static())
{
if (i != m_axis)
NODE_VALIDATION_CHECK(
this,
data_pshape[i] == indices_pshape[i],
"Sizes ",
data_pshape[i],
" and ",
indices_pshape[i],
" on axis ",
i,
" do not match. data and indices must have equal shapes except for axis ",
m_axis);
out_shape_info[i] = data_pshape[i];
}
}
set_output_type(0, data_type, out_shape_info);
}
else // if (indices_rank.is_dynamic() || data_rank.is_dynamic())
{
// if at least one input has dynamic rank pass PartialShape of the other input
// in the optimistic scenario this input will have at least static rank
// in the worse scenario passing both PartialShapes is equivalent
set_output_type(0, data_type, data_rank.is_dynamic() ? indices_pshape : data_pshape);
}

set_output_type(0, data_type, indices_pshape);
}

bool op::v6::GatherElements::visit_attributes(AttributeVisitor& visitor)
Expand Down
81 changes: 71 additions & 10 deletions ngraph/test/type_prop/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TEST(type_prop, gather_elements_3D_axis_0)
{
Shape data_shape{3, 3, 10000};
Shape indices_shape{300, 3, 10000};
int axis = 0;
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
Expand All @@ -63,7 +63,7 @@ TEST(type_prop, gather_elements_3D_axis_2)
{
Shape data_shape{300, 3, 10};
Shape indices_shape{300, 3, 10000};
int axis = 2;
int64_t axis = 2;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
Expand All @@ -75,21 +75,57 @@ TEST(type_prop, gather_elements_4D_axis_minus_1)
{
Shape data_shape{300, 3, 10, 1};
Shape indices_shape{300, 3, 10, 33333};
int axis = -1;
int64_t axis = -1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_nonfloat_data_type_int64_indices)
{
Shape data_shape{300, 3, 10, 1};
Shape indices_shape{300, 3, 10, 33333};
int64_t axis = -1;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_dynamic_consistent_shapes)
{
PartialShape data_shape{4, 4, Dimension::dynamic()};
PartialShape indices_shape{1, Dimension::dynamic(), 5};
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_shape(), Shape({1, 4, 5}));
}

TEST(type_prop, gather_elements_dynamic_out_shape)
{
PartialShape data_shape{4, 4, Dimension::dynamic()};
PartialShape indices_shape{1, Dimension::dynamic(), Dimension::dynamic()};
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::i8);
ASSERT_EQ(GE->get_output_partial_shape(0), PartialShape({1, 4, Dimension::dynamic()}));
}

// --------------------- Negative tests ------------------------------

TEST(type_prop, gather_elements_type_inconsistency)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = 1;
int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape);

Expand All @@ -114,7 +150,7 @@ TEST(type_prop, gather_elements_data_rank_check)
{
Shape data_shape{3};
Shape indices_shape{2, 333};
int axis = 0;
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

Expand All @@ -138,7 +174,7 @@ TEST(type_prop, gather_elements_out_of_bounds_axis)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = -33;
int64_t axis = -33;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

Expand All @@ -162,7 +198,7 @@ TEST(type_prop, gather_elements_indices_rank_check)
{
Shape data_shape{3, 3};
Shape indices_shape{333};
int axis = 0;
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

Expand All @@ -186,7 +222,7 @@ TEST(type_prop, gather_elements_rank_consistency_check)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 3, 3333};
int axis = 0;
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

Expand All @@ -210,7 +246,7 @@ TEST(type_prop, gather_elements_shapes_inconsistency)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = 1;
int64_t axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

Expand All @@ -227,6 +263,31 @@ TEST(type_prop, gather_elements_shapes_inconsistency)
}
catch (...)
{
FAIL() << "Deduced shape check failed for unexpected reason";
FAIL() << "Shape inconsistency check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_dynamic_inconsistent_shapes)
{
PartialShape data_shape{4, 2, 4, Dimension::dynamic()};
PartialShape indices_shape{1, 3, Dimension::dynamic(), 5};
int64_t axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::i8, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i64, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check for dynamic PartialShape failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("data and indices must have equal shapes except for axis "));
}
catch (...)
{
FAIL() << "Shape inconsistency check for dynamic PartialShape failed for unexpected reason";
}
}

0 comments on commit e692101

Please sign in to comment.