Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Shape and PartialShape supports negative idx #22049

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/core/dev_api/openvino/core/shape_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,16 @@ OPENVINO_API Shape reduce_keep_dims(const Shape& input, const AxisSet& axes);
* @return Result shape from inputs with applied broadcast specification.
*/
Shape get_broadcast_shape(const Shape& first, const Shape& second, const ov::op::AutoBroadcastSpec& broadcast_spec);

/**
* @brief Normalize shape index to the rank
*
* If input index is out of range [-rank, rank) throws exception.
*
* @param idx Shape dimension index.
* @param rank Shape rank.
* @return Normalized shape dimension index.
*/
OPENVINO_API std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank);
} // namespace util
} // namespace ov
13 changes: 7 additions & 6 deletions src/core/include/openvino/core/partial_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,15 @@ class OPENVINO_API PartialShape {
/// `false`.
bool all_non_negative() const;

/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \brief Index operator for PartialShape, with bound checking.
/// \param i The index of the dimension being selected in range [-rank, rank).
/// \return A reference to the `i`th Dimension of this shape.
const Dimension& operator[](size_t i) const;
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
Dimension& operator[](std::ptrdiff_t i);
/// \brief Index operator for PartialShape, with bound checking.
/// \param i The index of the dimension being selected in range [-rank, rank).
/// \return A reference to the `i`th Dimension of this shape.
Dimension& operator[](size_t i);
const Dimension& operator[](std::ptrdiff_t i) const;

/// \brief Returns a vector of the dimensions. This has no meaning if dynamic.
explicit operator std::vector<Dimension>() const {
return m_dimensions;
Expand Down
36 changes: 36 additions & 0 deletions src/core/include/openvino/core/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ class Shape : public std::vector<size_t> {
OPENVINO_API Shape& operator=(const Shape& v);
OPENVINO_API Shape& operator=(Shape&& v) noexcept;
OPENVINO_API std::string to_string() const;

/**
* @brief Gets dimension at index.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::reference operator[](std::ptrdiff_t i);

/**
* @brief Gets dimension at index.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A const reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::const_reference operator[](std::ptrdiff_t i) const;

/**
* @brief Gets dimension at index, with bounds checking.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::reference at(std::ptrdiff_t i);

/**
* @brief Gets dimension at index, with bounds checking.
*
* @param i Index to shape dimension [-rank, rank).
*
* @return A const reference to i-th dimension of this shape.
*/
OPENVINO_API typename Shape::const_reference at(std::ptrdiff_t i) const;
};

/**
Expand Down
21 changes: 7 additions & 14 deletions src/core/src/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
#include <iostream>
#include <vector>

#include "compare.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/shape_util.hpp"
#include "openvino/util/common_util.hpp"

namespace {
static constexpr char dim_out_range_access_txt[] = "Accessing out-of-range dimension in Dimension[]";
}
#include "validation_util.hpp"

ov::PartialShape::PartialShape() : PartialShape(std::initializer_list<Dimension>{}) {}

Expand Down Expand Up @@ -374,17 +373,11 @@ bool ov::PartialShape::all_non_negative() const {
return true;
}

const ov::Dimension& ov::PartialShape::operator[](size_t i) const {
if (i >= m_dimensions.size()) {
OPENVINO_THROW(dim_out_range_access_txt);
}
return m_dimensions[i];
const ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) const {
return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())];
}

ov::Dimension& ov::PartialShape::operator[](size_t i) {
if (i >= m_dimensions.size()) {
OPENVINO_THROW(dim_out_range_access_txt);
}
ov::Dimension& ov::PartialShape::operator[](std::ptrdiff_t i) {
m_shape_type = ShapeType::SHAPE_IS_UPDATED; // We can't guarantee that the shape remains static or dynamic.
return m_dimensions[i];
return m_dimensions[util::normalize_shape_index(i, m_dimensions.size())];
}
27 changes: 23 additions & 4 deletions src/core/src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

#include "openvino/core/shape.hpp"

#include "openvino/core/shape_util.hpp"
#include "openvino/util/common_util.hpp"

using namespace std;
#include "validation_util.hpp"

std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) {
s << "[";
Expand All @@ -16,9 +16,9 @@ std::ostream& ov::operator<<(std::ostream& s, const Shape& shape) {
}

namespace {
size_t stringToSizeT(const string& valStr) {
size_t stringToSizeT(const std::string& valStr) {
size_t ret{0};
istringstream ss(valStr);
std::istringstream ss(valStr);
if (!ss.eof()) {
ss >> ret;
}
Expand Down Expand Up @@ -68,3 +68,22 @@ std::string ov::Shape::to_string() const {
shape_str_stream << *this;
return shape_str_stream.str();
}

namespace ov {

typename Shape::reference& Shape::operator[](std::ptrdiff_t i) {
return std::vector<size_t>::operator[](util::normalize(i, size()));
}

typename Shape::const_reference& Shape::operator[](std::ptrdiff_t i) const {
return std::vector<size_t>::operator[](util::normalize(i, size()));
}

typename Shape::reference& Shape::at(std::ptrdiff_t i) {
return std::vector<size_t>::operator[](util::normalize_shape_index(i, size()));
}

typename Shape::const_reference& Shape::at(std::ptrdiff_t i) const {
return std::vector<size_t>::operator[](util::normalize_shape_index(i, size()));
}
} // namespace ov
10 changes: 10 additions & 0 deletions src/core/src/shape_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "openvino/core/partial_shape.hpp"
#include "openvino/core/shape_util.hpp"
#include "validation_util.hpp"

namespace ngraph {
template <>
Expand Down Expand Up @@ -126,5 +127,14 @@ Shape get_broadcast_shape(const Shape& first, const Shape& second, const op::Aut
"Argument shapes are inconsistent");
return out_shape.to_shape();
}

std::ptrdiff_t normalize_shape_index(std::ptrdiff_t idx, size_t rank) {
idx = normalize(idx, static_cast<int64_t>(rank));
if (static_cast<decltype(rank)>(idx) >= rank) {
OPENVINO_THROW("Accessing out-of-range dimension");
} else {
return idx;
}
}
} // namespace util
} // namespace ov
46 changes: 46 additions & 0 deletions src/core/tests/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,3 +1315,49 @@ TEST(partial_shape, infer_windowed_reduction_rank_static_dynamic_rank_static_dyn
NodeValidationFailure);
OPENVINO_SUPPRESS_DEPRECATED_END
}

TEST(partial_shape, const_subscribe_operator) {
const auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7};

EXPECT_EQ(shape[2], ov::Dimension(5));
EXPECT_EQ(shape[0], ov::Dimension::dynamic());
EXPECT_EQ(shape[1], ov::Dimension(2, 10));
EXPECT_EQ(shape[4], ov::Dimension(7));

EXPECT_EQ(shape[-3], ov::Dimension(5));
EXPECT_EQ(shape[-5], ov::Dimension::dynamic());
EXPECT_EQ(shape[-4], ov::Dimension(2, 10));
EXPECT_EQ(shape[-1], ov::Dimension(7));
}

TEST(partial_shape, subscribe_operator) {
auto shape = ov::PartialShape{-1, {2, 10}, 5, 6, 7};

EXPECT_EQ(shape[2], ov::Dimension(5));
EXPECT_EQ(shape[0], ov::Dimension::dynamic());
EXPECT_EQ(shape[1], ov::Dimension(2, 10));
EXPECT_EQ(shape[4], ov::Dimension(7));

EXPECT_EQ(shape[-3], ov::Dimension(5));
EXPECT_EQ(shape[-5], ov::Dimension::dynamic());
EXPECT_EQ(shape[-4], ov::Dimension(2, 10));
EXPECT_EQ(shape[-1], ov::Dimension(7));
}

TEST(partial_shape, const_subscribe_operator_throw_out_of_range) {
const auto shape = ov::PartialShape::dynamic(7);

EXPECT_THROW(shape[7], ov::Exception);
EXPECT_THROW(shape[1000], ov::Exception);
EXPECT_THROW(shape[-8], ov::Exception);
EXPECT_THROW(shape[-80000], ov::Exception);
}

TEST(partial_shape, subscribe_operator_throw_out_of_range) {
auto shape = ov::PartialShape::dynamic(7);

EXPECT_THROW(shape[7], ov::Exception);
EXPECT_THROW(shape[1000], ov::Exception);
EXPECT_THROW(shape[-8], ov::Exception);
EXPECT_THROW(shape[-80000], ov::Exception);
}
37 changes: 37 additions & 0 deletions src/core/tests/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,40 @@ TEST(shape, test_shape_strides) {
ASSERT_EQ((Strides{7, 1}), row_major_strides(Shape{2, 7}));
ASSERT_EQ((Strides{84, 12, 1}), row_major_strides(Shape{5, 7, 12}));
}

TEST(shape, at) {
const auto shape = ov::Shape{100, 200, 5, 6, 7};

EXPECT_EQ(shape.at(2), 5);
EXPECT_EQ(shape.at(0), 100);
EXPECT_EQ(shape.at(1), 200);
EXPECT_EQ(shape.at(4), 7);

EXPECT_EQ(shape.at(-3), 5);
EXPECT_EQ(shape.at(-5), 100);
EXPECT_EQ(shape.at(-4), 200);
EXPECT_EQ(shape.at(-1), 7);
}

TEST(shape, subscribe_operator) {
const auto shape = ov::Shape{100, 200, 5, 6, 7};

EXPECT_EQ(shape[2], 5);
EXPECT_EQ(shape[0], 100);
EXPECT_EQ(shape[1], 200);
EXPECT_EQ(shape[4], 7);

EXPECT_EQ(shape[-3], 5);
EXPECT_EQ(shape[-5], 100);
EXPECT_EQ(shape[-4], 200);
EXPECT_EQ(shape[-1], 7);
}

TEST(shape, at_throw_exception) {
auto shape = ov::Shape{1, 2, 3, 4, 5, 6, 7};

EXPECT_THROW(shape.at(7), ov::Exception);
EXPECT_THROW(shape.at(1000), ov::Exception);
EXPECT_THROW(shape.at(-8), ov::Exception);
EXPECT_THROW(shape.at(-80000), ov::Exception);
}
Loading