Skip to content

Commit

Permalink
Fix tile shape inference when repeats got dynamic shape (#15792)
Browse files Browse the repository at this point in the history
* fix shape infer when repeats got dynamic shape

* Dynamic output shape when repeats dim is dynamic
  • Loading branch information
praasz authored Mar 8, 2023
1 parent 7c8dc76 commit cd8999d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/core/shape_inference/include/tile_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ std::vector<T> shape_infer(const Tile* op,
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);

const auto& repeats_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op, repeats_shape.rank().compatible(1), "Tile repeats must be of rank 1");
const auto& repeats_rank = repeats_shape.rank();
NODE_VALIDATION_CHECK(op, repeats_rank.compatible(1), "Tile repeats must be of rank 1");

const auto& arg_shape = input_shapes[0];
T output_shape;
Expand Down Expand Up @@ -51,8 +52,8 @@ std::vector<T> shape_infer(const Tile* op,
rep_it,
std::back_inserter(output_shape),
std::multiplies<TDim>());
} else if (arg_rank.is_static() && repeats_shape[0].is_static()) {
// unknown repeats but shape is 1-D static, any dim can be repeated (add missing dimension)
} else if (arg_rank.is_static() && repeats_rank.is_static() && repeats_shape[0].is_static()) {
// unknown repeats any dim can be repeated (add missing dimension)
output_shape.resize(std::max<size_t>(arg_rank.get_length(), repeats_shape[0].get_length()));
} else {
// can't deduce shape, set default value
Expand Down
18 changes: 18 additions & 0 deletions src/core/tests/type_prop/tile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,24 @@ TEST_F(TypePropTileTest, preserve_partial_values_and_labels) {
ElementsAre(ov::no_label, ov::no_label, ov::no_label, 23, 24));
}

TEST_F(TypePropTileTest, repeats_has_dynamic_shape) {
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 10, 2, 5});
const auto repeats = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());

const auto op = make_op(data, repeats);

EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
}

TEST_F(TypePropTileTest, repeats_has_interval_shape) {
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 10, 2, 5});
const auto repeats = make_shared<op::Parameter>(element::i32, PartialShape{{3, 10}});

const auto op = make_op(data, repeats);

EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic());
}

using TileTestParam = std::tuple<PartialShape, std::vector<int64_t>, PartialShape>;

class TileTest : public TypePropTileTest, public WithParamInterface<TileTestParam> {
Expand Down

0 comments on commit cd8999d

Please sign in to comment.