From cd8999d43b1cf5425244624f38ab5f75dac87865 Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Wed, 8 Mar 2023 08:21:58 +0100 Subject: [PATCH] Fix tile shape inference when repeats got dynamic shape (#15792) * fix shape infer when repeats got dynamic shape * Dynamic output shape when repeats dim is dynamic --- .../include/tile_shape_inference.hpp | 7 ++++--- src/core/tests/type_prop/tile.cpp | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/core/shape_inference/include/tile_shape_inference.hpp b/src/core/shape_inference/include/tile_shape_inference.hpp index 91fe47878dbaa6..abc220b999ccd3 100644 --- a/src/core/shape_inference/include/tile_shape_inference.hpp +++ b/src/core/shape_inference/include/tile_shape_inference.hpp @@ -21,7 +21,8 @@ std::vector shape_infer(const Tile* op, NODE_VALIDATION_CHECK(op, input_shapes.size() == 2); const auto& repeats_shape = input_shapes[1]; - NODE_VALIDATION_CHECK(op, repeats_shape.rank().compatible(1), "Tile repeats must be of rank 1"); + const auto& repeats_rank = repeats_shape.rank(); + NODE_VALIDATION_CHECK(op, repeats_rank.compatible(1), "Tile repeats must be of rank 1"); const auto& arg_shape = input_shapes[0]; T output_shape; @@ -51,8 +52,8 @@ std::vector shape_infer(const Tile* op, rep_it, std::back_inserter(output_shape), std::multiplies()); - } else if (arg_rank.is_static() && repeats_shape[0].is_static()) { - // unknown repeats but shape is 1-D static, any dim can be repeated (add missing dimension) + } else if (arg_rank.is_static() && repeats_rank.is_static() && repeats_shape[0].is_static()) { + // unknown repeats any dim can be repeated (add missing dimension) output_shape.resize(std::max(arg_rank.get_length(), repeats_shape[0].get_length())); } else { // can't deduce shape, set default value diff --git a/src/core/tests/type_prop/tile.cpp b/src/core/tests/type_prop/tile.cpp index bb97e30ef52e02..9d14df58d3e2e6 100644 --- a/src/core/tests/type_prop/tile.cpp +++ b/src/core/tests/type_prop/tile.cpp @@ -139,6 +139,24 @@ TEST_F(TypePropTileTest, preserve_partial_values_and_labels) { ElementsAre(ov::no_label, ov::no_label, ov::no_label, 23, 24)); } +TEST_F(TypePropTileTest, repeats_has_dynamic_shape) { + const auto data = make_shared(element::f32, PartialShape{1, 3, 10, 2, 5}); + const auto repeats = make_shared(element::i32, PartialShape::dynamic()); + + const auto op = make_op(data, repeats); + + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic()); +} + +TEST_F(TypePropTileTest, repeats_has_interval_shape) { + const auto data = make_shared(element::f32, PartialShape{1, 3, 10, 2, 5}); + const auto repeats = make_shared(element::i32, PartialShape{{3, 10}}); + + const auto op = make_op(data, repeats); + + EXPECT_EQ(op->get_output_partial_shape(0), PartialShape::dynamic()); +} + using TileTestParam = std::tuple, PartialShape>; class TileTest : public TypePropTileTest, public WithParamInterface {