diff --git a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_strided_slice.cpp b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_strided_slice.cpp index 6e7fcd4492c9c5..214998696328ad 100644 --- a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_strided_slice.cpp +++ b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_strided_slice.cpp @@ -14,6 +14,7 @@ #include "ngraph/opsets/opset3.hpp" #include #include +#include namespace vpu { @@ -32,10 +33,16 @@ std::shared_ptr calculate_output_shape( const ngraph::AxisSet & begin_mask, const ngraph::AxisSet & end_mask, const ngraph::Output & input_shape) { - const auto shape_type = input_shape.get_element_type(); + const auto& shape_type = input_shape.get_element_type(); + + VPU_THROW_UNLESS(begin.size() == end.size() && begin.size() == strides.size(), + "Begin, end and strides inputs must be of the same size, but {}, {} and {} given accordingly", begin.size(), end.size(), strides.size()); + const auto inputShapeRank = input_shape.get_partial_shape()[0].get_length(); + VPU_THROW_UNLESS(inputShapeRank >= begin.size(), + "Input shape rank must not be less than begin/end/strides size, but {} and {} given accordingly", inputShapeRank, begin.size()); ngraph::OutputVector output_dimensions; - for (int64_t axis = 0; axis < input_shape.get_partial_shape()[0].get_length(); ++axis) { + for (int64_t axis = 0; axis < begin.size(); ++axis) { auto lb = begin[axis], ub = end[axis], stride = strides[axis]; ngraph::Output lower_bound = ngraph::opset3::Constant::create(shape_type, {1}, {lb}); @@ -99,6 +106,22 @@ std::shared_ptr calculate_output_shape( } output_dimensions.push_back(output_dimension); } + + if (output_dimensions.size() < inputShapeRank) { + std::vector indices(inputShapeRank - output_dimensions.size()); + std::iota(indices.begin(), indices.end(), static_cast(output_dimensions.size())); + + const auto tail = std::make_shared( + input_shape, + ngraph::opset3::Constant::create(ngraph::element::i64, {indices.size()}, indices), + ngraph::opset3::Constant::create(shape_type, {}, {0})); + output_dimensions.push_back(tail); + } + + VPU_THROW_UNLESS(output_dimensions.size() == inputShapeRank, + "output shape rank {} must be equal to input shape rank {} for DTS of StridedSlice", + output_dimensions.size(), inputShapeRank); + const auto output_shape = std::make_shared(output_dimensions, 0); return output_shape; }