Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Co-authored-by: Ilya Lavrenov <[email protected]>
  • Loading branch information
allnes and ilya-lavrenov authored Jan 12, 2023
1 parent 0b94374 commit c9a60a4
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ static void pad_input_data(const uint8_t* data_ptr,
const ngraph::Shape& input_shape,
const ngraph::Shape& padded_input_shape,
const std::vector<size_t>& pads_begin) {
ngraph::CoordinateTransform input_transform(input_shape);
ngraph::CoordinateTransform padded_transform(padded_input_shape);
ArmPlugin::opset::CoordinateTransform input_transform(input_shape);
ArmPlugin::opset::CoordinateTransform padded_transform(padded_input_shape);

for (const ngraph::Coordinate& input_coord : input_transform) {
auto padded_coord = input_coord;
Expand Down
180 changes: 180 additions & 0 deletions modules/arm_plugin/src/opset/interpolate_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "interpolate_arm.hpp"
#include "ngraph/coordinate_index.hpp"

using namespace ngraph;
using namespace ArmPlugin;
Expand Down Expand Up @@ -50,3 +51,182 @@ std::shared_ptr<ngraph::Node> ArmPlugin::opset::ArmInterpolate::clone_with_new_i
throw ngraph_error("Unsupported number of arguments for ArmInterpolate operation");
}
}

namespace {
Strides default_strides(size_t n_axes) {
return Strides(n_axes, 1);
}
CoordinateDiff default_padding(size_t n_axes) {
return CoordinateDiff(n_axes, 0);
}
AxisVector default_axis_order(size_t n_axes) {
AxisVector result(n_axes);
std::iota(result.begin(), result.end(), 0);
return result;
}

Coordinate default_source_start_corner(size_t n_axes) {
return Coordinate(n_axes, 0);
}
Coordinate default_source_end_corner(const Shape& source_shape) {
return source_shape;
}
} // namespace

ArmPlugin::opset::CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner,
const Coordinate& source_end_corner,
const Strides& source_strides,
const AxisVector& source_axis_order,
const CoordinateDiff& target_padding_below,
const CoordinateDiff& target_padding_above,
const Strides& target_dilation_strides)
: CoordinateTransformBasic(source_shape),
m_source_start_corner(source_start_corner),
m_source_end_corner(source_end_corner),
m_source_strides(source_strides),
m_source_axis_order(source_axis_order),
m_target_padding_below(target_padding_below),
m_target_padding_above(target_padding_above),
m_target_dilation_strides(target_dilation_strides) {
m_n_axes = source_shape.size();

if (m_n_axes != source_start_corner.size()) {
throw std::domain_error("Source start corner does not have the same number of axes as the source space shape");
}
if (m_n_axes != source_end_corner.size()) {
throw std::domain_error("Source end corner does not have the same number of axes as the source space shape");
}
if (m_n_axes != source_strides.size()) {
throw std::domain_error("Source strides do not have the same number of axes as the source space shape");
}
if (m_n_axes != source_axis_order.size()) {
// Note: this check is NOT redundant with the is_permutation check below, though you might
// think it is. If the lengths don't match then is_permutation won't catch that; it'll
// either stop short or walk off the end of source_axis_order.
throw std::domain_error("Source axis order does not have the same number of axes as the source space shape");
}
if (m_n_axes != target_padding_below.size()) {
throw std::domain_error("Padding-below shape does not have the same number of axes as the source space shape");
}
if (m_n_axes != target_padding_above.size()) {
throw std::domain_error("Padding-above shape does not have the same number of axes as the source space shape");
}
if (m_n_axes != target_dilation_strides.size()) {
throw std::domain_error("Target dilation strides do not have the same number of axes as the source shape");
}

AxisVector all_axes(m_n_axes);
for (size_t i = 0; i < all_axes.size(); i++) {
all_axes[i] = i;
}

if (!std::is_permutation(all_axes.begin(), all_axes.end(), source_axis_order.begin())) {
throw std::domain_error("Source axis order is not a permutation of {0,...,n-1} where n is the number of axes "
"in the source space shape");
}

for (size_t i = 0; i < m_n_axes; i++) {
if (target_dilation_strides[i] == 0) {
std::stringstream ss;

ss << "The target dilation stride is 0 at axis " << i;
throw std::domain_error(ss.str());
}
}

std::vector<std::ptrdiff_t> padded_upper_bounds;

for (size_t i = 0; i < m_n_axes; i++) {
NGRAPH_SUPPRESS_DEPRECATED_START
std::ptrdiff_t padded_upper_bound = subtract_or_zero(source_shape[i], size_t(1)) * target_dilation_strides[i] +
1 + target_padding_below[i] + target_padding_above[i];
NGRAPH_SUPPRESS_DEPRECATED_END

if (padded_upper_bound < 0) {
std::stringstream ss;

ss << "The end corner is out of bounds at axis " << i;
throw std::domain_error(ss.str());
}

padded_upper_bounds.push_back(padded_upper_bound);
}

for (size_t i = 0; i < m_n_axes; i++) {
if (static_cast<int64_t>(source_start_corner[i]) >= padded_upper_bounds[i] &&
source_start_corner[i] != source_shape[i]) {
std::stringstream ss;

ss << "The start corner is out of bounds at axis " << i;
throw std::domain_error(ss.str());
}

if (static_cast<int64_t>(source_end_corner[i]) > padded_upper_bounds[i]) {
std::stringstream ss;

ss << "The end corner is out of bounds at axis " << i;
throw std::domain_error(ss.str());
}
}

for (size_t i = 0; i < m_n_axes; i++) {
if (source_strides[i] == 0) {
std::stringstream ss;

ss << "The source stride is 0 at axis " << i;
throw std::domain_error(ss.str());
}
}

for (size_t axis = 0; axis < m_n_axes; axis++) {
m_target_shape.push_back(
ceil_div(source_end_corner[source_axis_order[axis]] - source_start_corner[source_axis_order[axis]],
source_strides[source_axis_order[axis]]));
}
}

ArmPlugin::opset::CoordinateTransform::CoordinateTransform(const Shape& source_shape)
: CoordinateTransform(source_shape,
default_source_start_corner(source_shape.size()),
default_source_end_corner(source_shape),
default_strides(source_shape.size()),
default_axis_order(source_shape.size()),
default_padding(source_shape.size()),
default_padding(source_shape.size()),
default_strides(source_shape.size())) {}

// Compute the index of a target-space coordinate in thebuffer.
size_t ArmPlugin::opset::CoordinateTransform::index(const Coordinate& c) const {
return coordinate_index(to_source_coordinate(c), m_source_shape);
}

// Convert a target-space coordinate to a source-space coordinate.
Coordinate ArmPlugin::opset::CoordinateTransform::to_source_coordinate(const Coordinate& c_target) const {
if (c_target.size() != m_n_axes) {
throw std::domain_error("Target coordinate rank does not match the coordinate transform rank");
}

Coordinate c_source(c_target.size());

for (size_t target_axis = 0; target_axis < m_n_axes; target_axis++) {
size_t source_axis = m_source_axis_order[target_axis];

size_t target_pos = c_target[target_axis];
size_t pos_destrided = target_pos * m_source_strides[source_axis];
size_t pos_deshifted = pos_destrided + m_source_start_corner[source_axis];
size_t pos_depadded = pos_deshifted - m_target_padding_below[target_axis];
size_t pos_dedilated = pos_depadded / m_target_dilation_strides[target_axis];
c_source[source_axis] = pos_dedilated;
}

return c_source;
}

CoordinateIterator ArmPlugin::opset::CoordinateTransform::begin() const noexcept {
return CoordinateIterator(m_target_shape);
}

const CoordinateIterator& ArmPlugin::opset::CoordinateTransform::end() const noexcept {
return CoordinateIterator::end();
}
43 changes: 43 additions & 0 deletions modules/arm_plugin/src/opset/interpolate_arm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "ngraph_opset.hpp"
#include "utils.hpp"
#include "ngraph/coordinate_transform.hpp"

namespace ArmPlugin {
namespace opset {
Expand All @@ -29,5 +30,47 @@ class ArmInterpolate : public Interpolate {
private:
Interpolate::InterpolateAttrs m_attrs;
};

class CoordinateTransform : protected ngraph::CoordinateTransformBasic {
public:
using Iterator = ngraph::CoordinateIterator;

CoordinateTransform(const ov::Shape& source_shape,
const ov::Coordinate& source_start_corner,
const ov::Coordinate& source_end_corner,
const ov::Strides& source_strides,
const ov::AxisVector& source_axis_order,
const ov::CoordinateDiff& target_padding_below,
const ov::CoordinateDiff& target_padding_above,
const ov::Strides& source_dilation_strides);

CoordinateTransform(const ov::Shape& source_shape);

/// \brief The tensor element index calculation by given coordinate.
/// \param c tensor element coordinate
size_t index(const ov::Coordinate& c) const;

/// \brief Convert a target-space coordinate to a source-space coordinate.
/// \param c tensor element coordinate
ov::Coordinate to_source_coordinate(const ov::Coordinate& c) const;

/// \brief Returns an iterator to the first coordinate of the tensor.
ngraph::CoordinateIterator begin() const noexcept;

/// \brief Returns an iterator to the coordinate following the last element of the tensor.
const ngraph::CoordinateIterator& end() const noexcept;

private:
ov::Coordinate m_source_start_corner;
ov::Coordinate m_source_end_corner;
ov::Strides m_source_strides;
ov::AxisVector m_source_axis_order;
ov::CoordinateDiff m_target_padding_below;
ov::CoordinateDiff m_target_padding_above;
ov::Strides m_target_dilation_strides;

ov::Shape m_target_shape;
size_t m_n_axes;
};
} // namespace opset
} // namespace ArmPlugin

0 comments on commit c9a60a4

Please sign in to comment.