Skip to content

Commit

Permalink
fix expected rank for dynamic axis
Browse files Browse the repository at this point in the history
  • Loading branch information
barnasm1 committed Oct 28, 2024
1 parent e7a641f commit 6450f72
Showing 1 changed file with 77 additions and 14 deletions.
91 changes: 77 additions & 14 deletions src/core/tests/type_prop/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,42 @@ TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_1D_two_elem_static_s
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic());
}

TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_one) {
TEST(SqueezelOperatorV0, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_one) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{2, 1, 4});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(2));
}

TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_one) {
TEST(SqueezelOperatorV15, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_one) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{2, 1, 4});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({2, 3}));
}

TEST(SqueezelOperatorV0, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_one) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{2, 1, 4});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(2));
}

TEST(SqueezelOperatorV15, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_one) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{2, 1, 4});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({2, 3}));
}

TYPED_TEST(SqueezelOperator, squeeze_data_scalar_param_axes_1D_single_elem_static_shape) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
Expand Down Expand Up @@ -127,24 +145,42 @@ TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_1D_two_elem_static_s
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic());
}

TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
TEST(SqueezelOperatorV0, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{1, 2, 1, 3, 1});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(4));
}

TYPED_TEST(SqueezelOperator, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_more) {
TEST(SqueezelOperatorV15, squeeze_data_static_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{1, 2, 1, 3, 1});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({4, 5}));
}

TEST(SqueezelOperatorV0, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{1, 2, 1, 3, 1});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(4));
}

TEST(SqueezelOperatorV15, squeeze_data_static_param_axes_scalar_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{1, 2, 1, 3, 1});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({4, 5}));
}

TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_1D_two_elem_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{2});
Expand All @@ -154,24 +190,42 @@ TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_1D_two_elem_static_
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic());
}

TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
TEST(SqueezelOperatorV0, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(3));
}

TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_scalar_static_shape_squeezable_dims_more) {
TEST(SqueezelOperatorV15, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({3, 4}));
}

TEST(SqueezelOperatorV0, squeeze_data_dynamic_param_axes_scalar_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(3));
}

TEST(SqueezelOperatorV15, squeeze_data_dynamic_param_axes_scalar_static_shape_squeezable_dims_more) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({3, 4}));
}

TYPED_TEST(SqueezelOperator, squeeze_data_dyamic_param_axes_1D_two_elem_static_shape_squeezable_dims_one) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{2, -1, 4});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{2});
Expand All @@ -190,15 +244,24 @@ TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_1D_three_elem_stati
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic());
}

TYPED_TEST(SqueezelOperator, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_less) {
TEST(SqueezelOperatorV0, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_less) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = this->make_op(param, axes_node);
const auto squeeze = std::make_shared<ov::op::v0::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic(3));
}

TEST(SqueezelOperatorV15, squeeze_data_dynamic_param_axes_1D_single_elem_static_shape_squeezable_dims_less) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, PartialShape{-1, {2, 8}, {1, 3}, {4, -1}});
const auto axes_node = std::make_shared<ov::op::v0::Parameter>(element::u64, PartialShape{1});
const auto squeeze = std::make_shared<ov::op::v15::Squeeze>(param, axes_node);

EXPECT_EQ(squeeze->get_element_type(), element::f32);
EXPECT_EQ(squeeze->get_output_partial_shape(0), PartialShape::dynamic({3, 4}));
}

using SqueezeTypePropTestParam = std::tuple<PartialShape, // Input shape
std::vector<int64_t>, // Squeeze axis
PartialShape // Expected shape
Expand Down

0 comments on commit 6450f72

Please sign in to comment.