diff --git a/src/core/dev_api/openvino/core/shape_util.hpp b/src/core/dev_api/openvino/core/shape_util.hpp index 89688526d4f286..dda99a8ddc883c 100644 --- a/src/core/dev_api/openvino/core/shape_util.hpp +++ b/src/core/dev_api/openvino/core/shape_util.hpp @@ -60,5 +60,16 @@ OPENVINO_API Shape reduce_keep_dims(const Shape& input, const AxisSet& axes); * @return Result shape from inputs with applied broadcast specification. */ Shape get_broadcast_shape(const Shape& first, const Shape& second, const ov::op::AutoBroadcastSpec& broadcast_spec); + +/** + * @brief Normalize shape index to the rank + * + * If input index is out of range [-rank, rank) throws exception. + * + * @param idx Shape dimension index. + * @param rank Shape rank. + * @return Normalized shape dimension index. + */ +OPENVINO_API std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank); } // namespace util } // namespace ov diff --git a/src/core/include/openvino/core/partial_shape.hpp b/src/core/include/openvino/core/partial_shape.hpp index 4f88efcd1bf088..041cb7f6789343 100644 --- a/src/core/include/openvino/core/partial_shape.hpp +++ b/src/core/include/openvino/core/partial_shape.hpp @@ -168,14 +168,15 @@ class OPENVINO_API PartialShape { /// `false`. bool all_non_negative() const; - /// \brief Index operator for PartialShape. - /// \param i The index of the dimension being selected. + /// \brief Index operator for PartialShape, with bound checking. + /// \param i The index of the dimension being selected in range [-rank, rank). /// \return A reference to the `i`th Dimension of this shape. - const Dimension& operator[](size_t i) const; - /// \brief Index operator for PartialShape. - /// \param i The index of the dimension being selected. + Dimension& operator[](std::ptrdiff_t i); + /// \brief Index operator for PartialShape, with bound checking. + /// \param i The index of the dimension being selected in range [-rank, rank). /// \return A reference to the `i`th Dimension of this shape. - Dimension& operator[](size_t i); + const Dimension& operator[](std::ptrdiff_t i) const; + /// \brief Returns a vector of the dimensions. This has no meaning if dynamic. explicit operator std::vector() const { return m_dimensions; diff --git a/src/core/include/openvino/core/shape.hpp b/src/core/include/openvino/core/shape.hpp index a04a864a8394fb..4283e3a4a54ee1 100644 --- a/src/core/include/openvino/core/shape.hpp +++ b/src/core/include/openvino/core/shape.hpp @@ -40,6 +40,42 @@ class Shape : public std::vector { OPENVINO_API Shape& operator=(const Shape& v); OPENVINO_API Shape& operator=(Shape&& v) noexcept; OPENVINO_API std::string to_string() const; + + /** + * @brief Gets dimension at index. + * + * @param i Index to shape dimension [-rank, rank). + * + * @return A reference to i-th dimension of this shape. + */ + OPENVINO_API typename Shape::reference operator[](std::ptrdiff_t i); + + /** + * @brief Gets dimension at index. + * + * @param i Index to shape dimension [-rank, rank). + * + * @return A const reference to i-th dimension of this shape. + */ + OPENVINO_API typename Shape::const_reference operator[](std::ptrdiff_t i) const; + + /** + * @brief Gets dimension at index, with bounds checking. + * + * @param i Index to shape dimension [-rank, rank). + * + * @return A reference to i-th dimension of this shape. + */ + OPENVINO_API typename Shape::reference at(std::ptrdiff_t i); + + /** + * @brief Gets dimension at index, with bounds checking. + * + * @param i Index to shape dimension [-rank, rank). + * + * @return A const reference to i-th dimension of this shape. + */ + OPENVINO_API typename Shape::const_reference at(std::ptrdiff_t i) const; }; /** diff --git a/src/core/src/partial_shape.cpp b/src/core/src/partial_shape.cpp index 38fc53cb88a846..993a957e447fec 100644 --- a/src/core/src/partial_shape.cpp +++ b/src/core/src/partial_shape.cpp @@ -9,11 +9,9 @@ #include #include "openvino/core/dimension_tracker.hpp" +#include "openvino/core/shape_util.hpp" #include "openvino/util/common_util.hpp" - -namespace { -static constexpr char dim_out_range_access_txt[] = "Accessing out-of-range dimension in Dimension[]"; -} +#include "validation_util.hpp" ov::PartialShape::PartialShape() : PartialShape(std::initializer_list{}) {} @@ -374,17 +372,11 @@ bool ov::PartialShape::all_non_negative() const { return true; } -const ov::Dimension& ov::PartialShape::operator[](size_t i) const { - if (i >= m_dimensions.size()) { - OPENVINO_THROW(dim_out_range_access_txt); - } - return m_dimensions[i]; +const ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) const { + return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())]; } -ov::Dimension& ov::PartialShape::operator[](size_t i) { - if (i >= m_dimensions.size()) { - OPENVINO_THROW(dim_out_range_access_txt); - } +ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) { m_shape_type = ShapeType::SHAPE_IS_UPDATED; // We can't guarantee that the shape remains static or dynamic. - return m_dimensions[i]; + return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())]; } diff --git a/src/core/src/shape.cpp b/src/core/src/shape.cpp index 7a91ff20c9c6de..dbc57212952710 100644 --- a/src/core/src/shape.cpp +++ b/src/core/src/shape.cpp @@ -4,9 +4,9 @@ #include "openvino/core/shape.hpp" +#include "openvino/core/shape_util.hpp" #include "openvino/util/common_util.hpp" - -using namespace std; +#include "validation_util.hpp" std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) { s << "["; @@ -16,9 +16,9 @@ std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) { } namespace { -size_t stringToSizeT(const string& valStr) { +size_t stringToSizeT(const std::string& valStr) { size_t ret{0}; - istringstream ss(valStr); + std::istringstream ss(valStr); if (!ss.eof()) { ss >> ret; } @@ -68,3 +68,22 @@ std::string ov::Shape::to_string() const { shape_str_stream << *this; return shape_str_stream.str(); } + +namespace ov { + +typename Shape::reference Shape::operator[](std::ptrdiff_t i) { + return std::vector::operator[](util::normalize(i, size())); +} + +typename Shape::const_reference Shape::operator[](std::ptrdiff_t i) const { + return std::vector::operator[](util::normalize(i, size())); +} + +typename Shape::reference Shape::at(std::ptrdiff_t i) { + return std::vector::operator[](util::normalize_shape_index(i, size())); +} + +typename Shape::const_reference Shape::at(std::ptrdiff_t i) const { + return std::vector::operator[](util::normalize_shape_index(i, size())); +} +} // namespace ov diff --git a/src/core/src/shape_util.cpp b/src/core/src/shape_util.cpp index 810686c9c7f88c..d84d8153a92059 100644 --- a/src/core/src/shape_util.cpp +++ b/src/core/src/shape_util.cpp @@ -8,6 +8,7 @@ #include "openvino/core/partial_shape.hpp" #include "openvino/core/shape_util.hpp" +#include "validation_util.hpp" namespace ngraph { template <> @@ -126,5 +127,14 @@ Shape get_broadcast_shape(const Shape& first, const Shape& second, const op::Aut "Argument shapes are inconsistent"); return out_shape.to_shape(); } + +std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank) { + idx = normalize(idx, static_cast(rank)); + if (static_cast(idx) >= rank) { + OPENVINO_THROW("Accessing out-of-range dimension"); + } else { + return idx; + } +} } // namespace util } // namespace ov diff --git a/src/core/tests/partial_shape.cpp b/src/core/tests/partial_shape.cpp index 0d23114e219cb7..d3817a7d5c935d 100644 --- a/src/core/tests/partial_shape.cpp +++ b/src/core/tests/partial_shape.cpp @@ -1315,3 +1315,49 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn NodeValidationFailure); OPENVINO_SUPPRESS_DEPRECATED_END } + +TEST(partial_shape, const_subscribe_operator) { + const auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7}; + + EXPECT_EQ(shape[2], ov::Dimension(5)); + EXPECT_EQ(shape[0], ov::Dimension::dynamic()); + EXPECT_EQ(shape[1], ov::Dimension(2, 10)); + EXPECT_EQ(shape[4], ov::Dimension(7)); + + EXPECT_EQ(shape[-3], ov::Dimension(5)); + EXPECT_EQ(shape[-5], ov::Dimension::dynamic()); + EXPECT_EQ(shape[-4], ov::Dimension(2, 10)); + EXPECT_EQ(shape[-1], ov::Dimension(7)); +} + +TEST(partial_shape, subscribe_operator) { + auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7}; + + EXPECT_EQ(shape[2], ov::Dimension(5)); + EXPECT_EQ(shape[0], ov::Dimension::dynamic()); + EXPECT_EQ(shape[1], ov::Dimension(2, 10)); + EXPECT_EQ(shape[4], ov::Dimension(7)); + + EXPECT_EQ(shape[-3], ov::Dimension(5)); + EXPECT_EQ(shape[-5], ov::Dimension::dynamic()); + EXPECT_EQ(shape[-4], ov::Dimension(2, 10)); + EXPECT_EQ(shape[-1], ov::Dimension(7)); +} + +TEST(partial_shape, const_subscribe_operator_throw_out_of_range) { + const auto shape = ov::PartialShape::dynamic(7); + + EXPECT_THROW(shape[7], ov::Exception); + EXPECT_THROW(shape[1000], ov::Exception); + EXPECT_THROW(shape[-8], ov::Exception); + EXPECT_THROW(shape[-80000], ov::Exception); +} + +TEST(partial_shape, subscribe_operator_throw_out_of_range) { + auto shape = ov::PartialShape::dynamic(7); + + EXPECT_THROW(shape[7], ov::Exception); + EXPECT_THROW(shape[1000], ov::Exception); + EXPECT_THROW(shape[-8], ov::Exception); + EXPECT_THROW(shape[-80000], ov::Exception); +} diff --git a/src/core/tests/shape.cpp b/src/core/tests/shape.cpp index fc80a7ab30c490..0f543a927d62b2 100644 --- a/src/core/tests/shape.cpp +++ b/src/core/tests/shape.cpp @@ -22,3 +22,40 @@ TEST(shape, test_shape_strides) { ASSERT_EQ((Strides{7, 1}), row_major_strides(Shape{2, 7})); ASSERT_EQ((Strides{84, 12, 1}), row_major_strides(Shape{5, 7, 12})); } + +TEST(shape, at) { + const auto shape = ov::Shape{100, 200, 5, 6, 7}; + + EXPECT_EQ(shape.at(2), 5); + EXPECT_EQ(shape.at(0), 100); + EXPECT_EQ(shape.at(1), 200); + EXPECT_EQ(shape.at(4), 7); + + EXPECT_EQ(shape.at(-3), 5); + EXPECT_EQ(shape.at(-5), 100); + EXPECT_EQ(shape.at(-4), 200); + EXPECT_EQ(shape.at(-1), 7); +} + +TEST(shape, subscribe_operator) { + const auto shape = ov::Shape{100, 200, 5, 6, 7}; + + EXPECT_EQ(shape[2], 5); + EXPECT_EQ(shape[0], 100); + EXPECT_EQ(shape[1], 200); + EXPECT_EQ(shape[4], 7); + + EXPECT_EQ(shape[-3], 5); + EXPECT_EQ(shape[-5], 100); + EXPECT_EQ(shape[-4], 200); + EXPECT_EQ(shape[-1], 7); +} + +TEST(shape, at_throw_exception) { + auto shape = ov::Shape{1, 2, 3, 4, 5, 6, 7}; + + EXPECT_THROW(shape.at(7), ov::Exception); + EXPECT_THROW(shape.at(1000), ov::Exception); + EXPECT_THROW(shape.at(-8), ov::Exception); + EXPECT_THROW(shape.at(-80000), ov::Exception); +}