Skip to content

Commit

Permalink
[core] Shape and PartialShape supports negative idx (openvinotoolkit#…
Browse files Browse the repository at this point in the history
…22049)

* Core Shape and PartialShape supports negative idx

* Remove redundant `&` in shape impl

* Remove not required include
  • Loading branch information
praasz authored Jan 10, 2024
1 parent b7ea17e commit 8798d7b
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 24 deletions.
11 changes: 11 additions & 0 deletions src/core/dev_api/openvino/core/shape_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,16 @@ OPENVINO_API Shape reduce_keep_dims(const Shape& input, const AxisSet& axes);
* @return Result shape from inputs with applied broadcast specification.
*/
Shape get_broadcast_shape(const Shape& first, const Shape& second, const ov::op::AutoBroadcastSpec& broadcast_spec);

/**
* @brief Normalize shape index to the rank
*
* If input index is out of range [-rank, rank) throws exception.
*
* @param idx Shape dimension index.
* @param rank Shape rank.
* @return Normalized shape dimension index.
*/
OPENVINO_API std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank);
} // namespace util
} // namespace ov
13 changes: 7 additions & 6 deletions src/core/include/openvino/core/partial_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,15 @@ class OPENVINO_API PartialShape {
/// `false`.
bool all_non_negative() const;

/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \brief Index operator for PartialShape, with bound checking.
/// \param i The index of the dimension being selected in range [-rank, rank).
/// \return A reference to the `i`th Dimension of this shape.
const Dimension& operator[](size_t i) const;
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
Dimension& operator[](std::ptrdiff_t i);
/// \brief Index operator for PartialShape, with bound checking.
/// \param i The index of the dimension being selected in range [-rank, rank).
/// \return A reference to the `i`th Dimension of this shape.
Dimension& operator[](size_t i);
const Dimension& operator[](std::ptrdiff_t i) const;

/// \brief Returns a vector of the dimensions. This has no meaning if dynamic.
explicit operator std::vector<Dimension>() const {
return m_dimensions;
Expand Down
36 changes: 36 additions & 0 deletions src/core/include/openvino/core/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ class Shape : public std::vector<size_t> {
OPENVINO_API Shape& operator=(const Shape& v);
OPENVINO_API Shape& operator=(Shape&& v) noexcept;
OPENVINO_API std::string to_string() const;

/**
* @brief Gets dimension at index.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::reference operator[](std::ptrdiff_t i);

/**
* @brief Gets dimension at index.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A const reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::const_reference operator[](std::ptrdiff_t i) const;

/**
* @brief Gets dimension at index, with bounds checking.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::reference at(std::ptrdiff_t i);

/**
* @brief Gets dimension at index, with bounds checking.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A const reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::const_reference at(std::ptrdiff_t i) const;
};

/**
Expand Down
20 changes: 6 additions & 14 deletions src/core/src/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
#include <vector>

#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/shape_util.hpp"
#include "openvino/util/common_util.hpp"

namespace {
static constexpr char dim_out_range_access_txt[] = "Accessing out-of-range dimension in Dimension[]";
}
#include "validation_util.hpp"

ov::PartialShape::PartialShape() : PartialShape(std::initializer_list<Dimension>{}) {}

Expand Down Expand Up @@ -374,17 +372,11 @@ bool ov::PartialShape::all_non_negative() const {
return true;
}

const ov::Dimension& ov::PartialShape::operator[](size_t i) const {
if (i >= m_dimensions.size()) {
OPENVINO_THROW(dim_out_range_access_txt);
}
return m_dimensions[i];
const ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) const {
return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())];
}

ov::Dimension& ov::PartialShape::operator[](size_t i) {
if (i >= m_dimensions.size()) {
OPENVINO_THROW(dim_out_range_access_txt);
}
ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) {
m_shape_type = ShapeType::SHAPE_IS_UPDATED; // We can't guarantee that the shape remains static or dynamic.
return m_dimensions[i];
return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())];
}
27 changes: 23 additions & 4 deletions src/core/src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

#include "openvino/core/shape.hpp"

#include "openvino/core/shape_util.hpp"
#include "openvino/util/common_util.hpp"

using namespace std;
#include "validation_util.hpp"

std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) {
s << "[";
Expand All @@ -16,9 +16,9 @@ std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) {
}

namespace {
size_t stringToSizeT(const string& valStr) {
size_t stringToSizeT(const std::string& valStr) {
size_t ret{0};
istringstream ss(valStr);
std::istringstream ss(valStr);
if (!ss.eof()) {
ss >> ret;
}
Expand Down Expand Up @@ -68,3 +68,22 @@ std::string ov::Shape::to_string() const {
shape_str_stream << *this;
return shape_str_stream.str();
}

namespace ov {

typename Shape::reference Shape::operator[](std::ptrdiff_t i) {
return std::vector<size_t>::operator[](util::normalize(i, size()));
}

typename Shape::const_reference Shape::operator[](std::ptrdiff_t i) const {
return std::vector<size_t>::operator[](util::normalize(i, size()));
}

typename Shape::reference Shape::at(std::ptrdiff_t i) {
return std::vector<size_t>::operator[](util::normalize_shape_index(i, size()));
}

typename Shape::const_reference Shape::at(std::ptrdiff_t i) const {
return std::vector<size_t>::operator[](util::normalize_shape_index(i, size()));
}
} // namespace ov
10 changes: 10 additions & 0 deletions src/core/src/shape_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "openvino/core/partial_shape.hpp"
#include "openvino/core/shape_util.hpp"
#include "validation_util.hpp"

namespace ngraph {
template <>
Expand Down Expand Up @@ -126,5 +127,14 @@ Shape get_broadcast_shape(const Shape& first, const Shape& second, const op::Aut
"Argument shapes are inconsistent");
return out_shape.to_shape();
}

std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank) {
idx = normalize(idx, static_cast<int64_t>(rank));
if (static_cast<decltype(rank)>(idx) >= rank) {
OPENVINO_THROW("Accessing out-of-range dimension");
} else {
return idx;
}
}
} // namespace util
} // namespace ov
46 changes: 46 additions & 0 deletions src/core/tests/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,3 +1315,49 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
NodeValidationFailure);
OPENVINO_SUPPRESS_DEPRECATED_END
}

TEST(partial_shape, const_subscribe_operator) {
const auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7};

EXPECT_EQ(shape[2], ov::Dimension(5));
EXPECT_EQ(shape[0], ov::Dimension::dynamic());
EXPECT_EQ(shape[1], ov::Dimension(2, 10));
EXPECT_EQ(shape[4], ov::Dimension(7));

EXPECT_EQ(shape[-3], ov::Dimension(5));
EXPECT_EQ(shape[-5], ov::Dimension::dynamic());
EXPECT_EQ(shape[-4], ov::Dimension(2, 10));
EXPECT_EQ(shape[-1], ov::Dimension(7));
}

TEST(partial_shape, subscribe_operator) {
auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7};

EXPECT_EQ(shape[2], ov::Dimension(5));
EXPECT_EQ(shape[0], ov::Dimension::dynamic());
EXPECT_EQ(shape[1], ov::Dimension(2, 10));
EXPECT_EQ(shape[4], ov::Dimension(7));

EXPECT_EQ(shape[-3], ov::Dimension(5));
EXPECT_EQ(shape[-5], ov::Dimension::dynamic());
EXPECT_EQ(shape[-4], ov::Dimension(2, 10));
EXPECT_EQ(shape[-1], ov::Dimension(7));
}

TEST(partial_shape, const_subscribe_operator_throw_out_of_range) {
const auto shape = ov::PartialShape::dynamic(7);

EXPECT_THROW(shape[7], ov::Exception);
EXPECT_THROW(shape[1000], ov::Exception);
EXPECT_THROW(shape[-8], ov::Exception);
EXPECT_THROW(shape[-80000], ov::Exception);
}

TEST(partial_shape, subscribe_operator_throw_out_of_range) {
auto shape = ov::PartialShape::dynamic(7);

EXPECT_THROW(shape[7], ov::Exception);
EXPECT_THROW(shape[1000], ov::Exception);
EXPECT_THROW(shape[-8], ov::Exception);
EXPECT_THROW(shape[-80000], ov::Exception);
}
37 changes: 37 additions & 0 deletions src/core/tests/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,40 @@ TEST(shape, test_shape_strides) {
ASSERT_EQ((Strides{7, 1}), row_major_strides(Shape{2, 7}));
ASSERT_EQ((Strides{84, 12, 1}), row_major_strides(Shape{5, 7, 12}));
}

TEST(shape, at) {
const auto shape = ov::Shape{100, 200, 5, 6, 7};

EXPECT_EQ(shape.at(2), 5);
EXPECT_EQ(shape.at(0), 100);
EXPECT_EQ(shape.at(1), 200);
EXPECT_EQ(shape.at(4), 7);

EXPECT_EQ(shape.at(-3), 5);
EXPECT_EQ(shape.at(-5), 100);
EXPECT_EQ(shape.at(-4), 200);
EXPECT_EQ(shape.at(-1), 7);
}

TEST(shape, subscribe_operator) {
const auto shape = ov::Shape{100, 200, 5, 6, 7};

EXPECT_EQ(shape[2], 5);
EXPECT_EQ(shape[0], 100);
EXPECT_EQ(shape[1], 200);
EXPECT_EQ(shape[4], 7);

EXPECT_EQ(shape[-3], 5);
EXPECT_EQ(shape[-5], 100);
EXPECT_EQ(shape[-4], 200);
EXPECT_EQ(shape[-1], 7);
}

TEST(shape, at_throw_exception) {
auto shape = ov::Shape{1, 2, 3, 4, 5, 6, 7};

EXPECT_THROW(shape.at(7), ov::Exception);
EXPECT_THROW(shape.at(1000), ov::Exception);
EXPECT_THROW(shape.at(-8), ov::Exception);
EXPECT_THROW(shape.at(-80000), ov::Exception);
}

0 comments on commit 8798d7b

Please sign in to comment.