Skip to content

Commit

Permalink
finally implemented all ngraph shape_inference unit-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Dec 3, 2020
1 parent 47f5795 commit 0d61833
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 31 deletions.
1 change: 1 addition & 0 deletions docs/MO_DG/prepare_model/Supported_Frameworks_Layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ Standard ONNX\* operators:
| Floor | No |
| GRU | No |
| Gather | No |
| GatherElements | only with positive indices |
| GatherND | No |
| GatherTree | No |
| Gemm | No |
Expand Down
10 changes: 4 additions & 6 deletions model-optimizer/extensions/ops/gatherelements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs
from mo.ops.op import Op


class GatherElements(Op):
Expand All @@ -40,5 +37,6 @@ def backend_attrs(self):

@staticmethod
def infer(node: Node):
indices_shape = node.in_port(1).data.get_shape()
node.out_port(0).data.set_shape(indices_shape)
node.in_port(1).data.get_shape()

# todo: add value_inference
4 changes: 2 additions & 2 deletions ngraph/core/include/ngraph/op/gather_elements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace ngraph
/// \param axis specifies axis along which indices are specified
GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const size_t axis);
const int axis);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
Expand All @@ -48,7 +48,7 @@ namespace ngraph

size_t get_axis() const { return m_axis; }
private:
size_t m_axis;
int m_axis;
};
}
}
Expand Down
12 changes: 6 additions & 6 deletions ngraph/core/src/op/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ NGRAPH_RTTI_DEFINITION(op::v6::GatherElements, "GatherElements", 6);

op::v6::GatherElements::GatherElements(const Output<Node>& data,
const Output<Node>& indices,
const size_t axis)
const int axis)
: Op({data, indices})
, m_axis(axis)
{
Expand All @@ -53,16 +53,16 @@ void op::v6::GatherElements::validate_and_infer_types()

if (data_rank.is_static())
{
auto data_rank_size = data_rank.get_length();
int data_rank_size = data_rank.get_length();

NODE_VALIDATION_CHECK(this, data_rank_size > 1, "Data rank must be greater than 1.");
NODE_VALIDATION_CHECK(this, data_rank_size > 1, "data rank must be greater than 1.");

if (m_axis < 0)
{
NODE_VALIDATION_CHECK(
this,
-data_rank_size < m_axis < data_rank_size - 1,
"axis must be within interval (-data.rank, data.rank - 1. But instead Got: ",
(-data_rank_size < m_axis) && (m_axis < data_rank_size - 1),
"axis must be within interval (-data.rank, data.rank - 1). But instead Got: ",
m_axis);
m_axis = data_rank_size + m_axis;
}
Expand All @@ -71,7 +71,7 @@ void op::v6::GatherElements::validate_and_infer_types()
if (indices_rank.is_static())
{
NODE_VALIDATION_CHECK(
this, indices_rank.get_length() > 1, "Indices rank must be greater that 1.");
this, indices_rank.get_length() > 1, "indices rank must be greater than 1.");
}

if (data_rank.is_static() && indices_rank.is_static())
Expand Down
151 changes: 134 additions & 17 deletions ngraph/test/type_prop/gather_elements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using namespace ngraph;

// ------------------------------ V6 ------------------------------

TEST(type_prop, gather_elements_2d_axis_0)
TEST(type_prop, gather_elements_2D_axis_0)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 3};
Expand All @@ -35,7 +35,7 @@ TEST(type_prop, gather_elements_2d_axis_0)
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_2d_axis_1)
TEST(type_prop, gather_elements_2D_axis_1)
{
Shape data_shape{3, 3};
Shape indices_shape{3, 1};
Expand All @@ -47,7 +47,7 @@ TEST(type_prop, gather_elements_2d_axis_1)
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_3d_axis_0)
TEST(type_prop, gather_elements_3D_axis_0)
{
Shape data_shape{3, 3, 10000};
Shape indices_shape{300, 3, 10000};
Expand All @@ -59,57 +59,174 @@ TEST(type_prop, gather_elements_3d_axis_0)
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_3D_axis_2)
{
Shape data_shape{300, 3, 10};
Shape indices_shape{300, 3, 10000};
int axis = 2;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape);
}

TEST(type_prop, gather_elements_4D_axis_minus_1)
{
Shape data_shape{300, 3, 10, 1};
Shape indices_shape{300, 3, 10, 33333};
int axis = -1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
ASSERT_EQ(GE->get_element_type(), element::Type_t::f32);
ASSERT_EQ(GE->get_shape(), indices_shape);
}

// --------------------- Negative tests ------------------------------
TEST(type_prop, gather_elements_shapes_inconsistency)

TEST(type_prop, gather_elements_type_inconsistency)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "Shape inconsistency check failed";
FAIL() << "the indices tensor type check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("data and indices must have equal shapes except for axis "));
error.what(), std::string("indices mush be of int32 or int64 type. But instead got"));
}
catch (...)
{
FAIL() << "Deduced shape check failed for unexpected reason";
FAIL() << "type check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_type_inconsistency)
TEST(type_prop, gather_elements_data_rank_check)
{
Shape data_shape{3};
Shape indices_shape{2, 333};
int axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "data rank check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("data rank must be greater than 1"));
}
catch (...)
{
FAIL() << "data rank check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_out_of_bounds_axis)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = -33;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "axis out of bounds check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("axis must be within interval"));
}
catch (...)
{
FAIL() << "axis out of bounds check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_indices_rank_check)
{
Shape data_shape{3, 3};
Shape indices_shape{333};
int axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "indices rank check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("indices rank must be greater than 1"));
}
catch (...)
{
FAIL() << "indices rank check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_rank_consistency_check)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 3, 3333};
int axis = 0;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "rank consistency check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("data and indices rank must be equal"));
}
catch (...)
{
FAIL() << "rank consistency check failed for unexpected reason";
}
}

TEST(type_prop, gather_elements_shapes_inconsistency)
{
Shape data_shape{3, 3};
Shape indices_shape{2, 1};
int axis = 1;
auto D = make_shared<op::Parameter>(element::Type_t::f32, data_shape);
auto I = make_shared<op::Parameter>(element::Type_t::u32, indices_shape);
auto I = make_shared<op::Parameter>(element::Type_t::i32, indices_shape);

try
{
auto GE = make_shared<op::v6::GatherElements>(D, I, axis);
// Should have thrown, so fail if it didn't
FAIL() << "the indices tensor type check failed";
FAIL() << "Shape inconsistency check failed";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("indices mush be of int32 or int64 type. But instead got"));
error.what(), std::string("data and indices must have equal shapes except for axis "));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
FAIL() << "Deduced shape check failed for unexpected reason";
}
}

// negative tests
// axis out of bounds
// rank check

0 comments on commit 0d61833

Please sign in to comment.