Skip to content

Commit

Permalink
New Gather op reference implementation. (#3633)
Browse files Browse the repository at this point in the history
* New Gather op reference implementation.

* Unify span implementation for gather and gather_nd.

Create span.hpp for common implementation of span.

* Move span to utils directory.

* Address review comments.

* update span

* Address PR comments.

Co-authored-by: Patryk Elszkowski <[email protected]>
  • Loading branch information
pelszkow and Patryk Elszkowski authored Dec 23, 2020
1 parent 96b0325 commit bd9bbe0
Show file tree
Hide file tree
Showing 5 changed files with 460 additions and 211 deletions.
8 changes: 7 additions & 1 deletion ngraph/core/reference/include/ngraph/coordinate_range.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ namespace ngraph
Strides(source_shape.size(), 1));
}

/// \brief Class allows to iterate over Tensor with reverted axies part by part.
/// \brief Class allows to iterate over Tensor with reverted axes part by part.
///
/// To create ReverseRange use _reverse_ function.
///
Expand Down Expand Up @@ -213,8 +213,14 @@ namespace ngraph
return ReverseRange(source_shape, reversed_axis);
}

/// \brief Creates a ReverseRange for the given shape with no axis reversed.
///
/// Forwards to _reverse_ passing `source_shape` and an empty list of reversed axes.
///
/// \param source_shape Shape handed over to _reverse_.
///
/// \return The ReverseRange produced by _reverse_ for `source_shape`.
inline ReverseRange index(const Shape& source_shape)
{
return reverse(source_shape, {});
}

} // namespace impl
using impl::Direction;
using impl::index;
using impl::reverse;
using impl::slice;
} // namespace coordinates
Expand Down
222 changes: 91 additions & 131 deletions ngraph/core/reference/include/ngraph/runtime/reference/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,156 +18,116 @@

#include <numeric>

#include "ngraph/coordinate_range.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "utils/span.hpp"

namespace ngraph
{
namespace runtime
{
namespace reference
{
// Implement gather by calling gather_nd on sub-problems
// # prepare constant shapes for tensors used for sub problems
// indices'.shape = indices.shape[-1] + [1]
// params'.shape = params.shape[axis:]
// out'.shape = params'.shape
// out'.shape[0] = indices.shape[-1]
// # call sub-problems
// foreach (params_index, out_index) in outer "axis" dimensions
// # params_prime is shared by inner loop
// params' = param[params_index] # rank(params') == rank(params) - axis
// foreach indices_index in outer N-1 dimensions
// indices' = indices[indices_index] # rank(indices') == 2
// out_index = out_index + indices_index
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices', out')
namespace
{
/// \brief Builds a Shape holding a copy of every element of the container.
///
/// \param c Any container usable with begin()/end().
///
/// \return Shape constructed from the range [begin(c), end(c)).
template <typename Container>
Shape to_shape(const Container& c)
{
    Shape shape(begin(c), end(c));
    return shape;
}

/// \brief Concatenates three containers of size_t into one vector.
///
/// All three arguments must be of the same container type and that type must hold
/// (possibly cv-qualified) size_t elements, which is verified at compile time.
///
/// \param c1 Elements placed first in the result.
/// \param c2 Elements placed after those of c1.
/// \param c3 Elements placed last.
///
/// \return Vector with the elements of c1, c2 and c3, in that order.
template <typename Container>
std::vector<size_t>
    join(const Container& c1, const Container& c2, const Container& c3)
{
    using element_type = typename std::remove_cv<typename Container::value_type>::type;
    static_assert(std::is_same<element_type, size_t>::value,
                  "Expect same type in container");

    std::vector<size_t> joined;
    // One allocation up front, the three appends below never reallocate.
    joined.reserve(c1.size() + c2.size() + c3.size());
    joined.insert(joined.end(), begin(c1), end(c1));
    joined.insert(joined.end(), begin(c2), end(c2));
    joined.insert(joined.end(), begin(c3), end(c3));
    return joined;
}

const auto only_one = [] { return coordinates::index(Shape{1}); };
} // namespace
template <typename T, typename U>
void gather(const T* params,
const U* indices,
T* out,
void gather(const T* const params,
const U* const indices,
T* const out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape,
size_t axis)
{
// prepare shape of params_prime (remove first "axis" dimensions)
const Shape params_prime_shape(params_shape.begin() + axis, params_shape.end());
// prepare shape of indices_prime
const size_t indices_ndim = indices_shape.size();
Shape indices_prime_shape;
// prepare shape of out_prime (same as params_prime except for first dim)
Shape out_prime_shape(params_prime_shape);
if (indices_ndim > 0)
{
out_prime_shape[0] = indices_shape[indices_ndim - 1];
indices_prime_shape.emplace_back(indices_shape[indices_ndim - 1]);
}
else
{
out_prime_shape[0] = 1;
}
indices_prime_shape.emplace_back(1);
using std::next;
assert(std::memset(out, 0, shape_size(out_shape) * sizeof(T)));

// Create a CoordinateTransform for "out" that visits the outer "axis" dimensions
const size_t out_ndim = out_shape.size();
const Coordinate out_outer_start_corner(out_ndim, 0);
Coordinate out_outer_end_corner(out_shape);
for (size_t i = axis; i < out_ndim; i++)
{
out_outer_end_corner[i] = 1;
}
Strides out_outer_strides(out_ndim, 1);
AxisVector out_outer_axis_order(out_ndim);
std::iota(out_outer_axis_order.begin(), out_outer_axis_order.end(), 0);
CoordinateTransform out_outer_transform(out_shape,
out_outer_start_corner,
out_outer_end_corner,
out_outer_strides,
out_outer_axis_order);

// Create a CoordinateTransform for "params" that visits the outer "axis" dimensions
const size_t params_ndim = params_shape.size();
const Coordinate params_outer_start_corner(params_ndim, 0);
Coordinate params_outer_end_corner(params_shape);
for (size_t i = axis; i < params_ndim; i++)
{
params_outer_end_corner[i] = 1;
}
const Strides params_outer_strides(params_ndim, 1);
AxisVector params_outer_axis_order(params_ndim);
std::iota(params_outer_axis_order.begin(), params_outer_axis_order.end(), 0);
const CoordinateTransform params_outer_transform(params_shape,
params_outer_start_corner,
params_outer_end_corner,
params_outer_strides,
params_outer_axis_order);

// Create a CoordinateTransform for "indices" that visits only the first element
// along inner most axis
const Coordinate indices_outer_start_corner(indices_ndim, 0);
Coordinate indices_outer_end_corner(indices_shape);
if (indices_ndim > 0)
{
indices_outer_end_corner[indices_ndim - 1] = 1;
}
const Strides indices_outer_strides(indices_ndim, 1);
AxisVector indices_outer_axis_order(indices_ndim);
std::iota(indices_outer_axis_order.begin(), indices_outer_axis_order.end(), 0);
const CoordinateTransform indices_outer_transform(indices_shape,
indices_outer_start_corner,
indices_outer_end_corner,
indices_outer_strides,
indices_outer_axis_order);

// Create an inner CoordinateTransfrom for "out"
const size_t out_inner_ndim = out_ndim - axis;
const Shape out_inner_shape(out_shape.begin() + axis, out_shape.end());
const Coordinate out_inner_start_corner(out_inner_ndim, 0);
Coordinate out_inner_end_corner(out_inner_shape);
if (indices_ndim > 0)
{
out_inner_end_corner[indices_ndim - 1] = 1;
}
for (size_t i = indices_ndim; i < out_inner_ndim; i++)
{
out_inner_end_corner[i] = 1;
}
const Strides out_inner_strides(out_inner_ndim, 1);
AxisVector out_inner_axis_order(out_inner_ndim);
std::iota(out_inner_axis_order.begin(), out_inner_axis_order.end(), 0);
const CoordinateTransform out_inner_transform(out_inner_shape,
out_inner_start_corner,
out_inner_end_corner,
out_inner_strides,
out_inner_axis_order);

auto out_outer_coord_iter = out_outer_transform.begin();
for (const Coordinate& params_outer_coord : params_outer_transform)
const auto params_axes_part = span(params_shape).subspan(0, axis);

NGRAPH_CHECK(params_shape.size() >= axis, "Not enough axes in param_shape.");

const auto remainder_part_shape = span(params_shape).subspan(axis + 1);

const auto found_out_shape =
join(params_axes_part, span(indices_shape), remainder_part_shape);

NGRAPH_CHECK(found_out_shape == out_shape,
"Output shape mismatch with calculations");

const auto batch_shape = span(params_shape).subspan(axis);

const auto batch_size = shape_size(batch_shape);

const auto copy_size = shape_size(remainder_part_shape);

const size_t copy_round_in_batch =
indices_shape.size() > 1
? shape_size(span(indices_shape.data(), indices_shape.size() - 1))
: 1;
const size_t round_batch_offset = indices_shape.empty() ? 1 : indices_shape.back();

auto dst = out;

auto gather_range = params_axes_part.empty()
? only_one()
: coordinates::index(to_shape(params_axes_part));
for (auto i : gather_range)
{
if (out_outer_coord_iter == out_outer_transform.end())
break;
const T* params_prime =
&params[params_outer_transform.index(params_outer_coord)];
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];

auto out_inner_coord_iter = out_inner_transform.begin();
for (const Coordinate& indices_outer_coord : indices_outer_transform)
auto batch_index = i.begin_index;
for (size_t batch = 0; batch != i.element_number;
batch_index += i.step, ++batch)
{
if (out_inner_coord_iter == out_inner_transform.end())
break;
const U* indices_prime =
&indices[indices_outer_transform.index(indices_outer_coord)];
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
gather_nd<T, U>(params_prime,
indices_prime,
out_prime,
params_prime_shape,
indices_prime_shape,
out_prime_shape);
++out_inner_coord_iter;
const auto batch_offset = batch_index * batch_size;
assert(batch_offset < shape_size(params_shape));
for (size_t round = 0; round != copy_round_in_batch; ++round)
{
const U* input_indices = indices + round * round_batch_offset;
const auto indices_no =
indices_shape.empty() ? 1 : indices_shape.back();

assert(!batch_shape.empty());
for (size_t ii = 0; ii != indices_no; ++ii)
{
const auto positive_input_index =
input_indices[ii] < 0 ? batch_shape.front() + input_indices[ii]
: input_indices[ii];

const auto src_offset =
batch_offset + copy_size * positive_input_index;

const auto src_begin = next(params, src_offset);
const auto src_end = next(src_begin, copy_size);

std::copy(src_begin, src_end, dst);
dst += copy_size;
}
}
}
++out_outer_coord_iter;
}
}
} // namespace reference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,59 +21,16 @@
#include <numeric>

#include "ngraph/coordinate_transform.hpp"
#include "utils/span.hpp"

namespace ngraph
{
namespace runtime
{
namespace reference
{
namespace
namespace details
{
/// SFINAE helper: names the type bool only when `check` is true.
template <bool check>
using Required = typename std::enable_if<check, bool>::type;

/// \brief Trait telling whether `It` is a random access iterator.
///
/// The category is read through std::iterator_traits, so raw pointers are recognized
/// as well as class type iterators.
template <typename It>
struct IsRandomAccessIt
{
    static constexpr bool value =
        std::is_base_of<std::random_access_iterator_tag,
                        typename std::iterator_traits<It>::iterator_category>::value;
};

/// \brief Non-owning view of the range [begin, end) described by a pair of random
///        access iterators (raw pointers included).
///
/// The view only stores the two iterators; the viewed elements have to outlive it.
template <typename Iterator, Required<IsRandomAccessIt<Iterator>::value> = true>
class Span
{
public:
    using value_type = typename std::iterator_traits<Iterator>::value_type;
    using difference_type = typename std::iterator_traits<Iterator>::difference_type;

    Span(Iterator begin, Iterator end)
        : m_begin{begin}
        , m_end{end}
    {
    }

    Iterator begin() const { return m_begin; }
    Iterator end() const { return m_end; }

    /// \return Copy of the element at position `idx`; no bounds check is performed.
    value_type operator[](size_t idx) const
    {
        // Qualified call: ADL would not find std::next for raw pointers.
        return *std::next(m_begin, static_cast<difference_type>(idx));
    }

    /// \return Number of elements in the view.
    difference_type size() const { return std::distance(m_begin, m_end); }

private:
    Iterator m_begin;
    Iterator m_end;
};

/// \brief Creates Span for the range [begin, end) deducing the iterator type.
template <typename Iterator>
Span<Iterator> span(Iterator begin, Iterator end)
{
    return Span<Iterator>{begin, end};
}

template <typename Iterator>
std::vector<size_t> get_indices_offsets(const Iterator beg,
const Iterator end,
Expand All @@ -90,7 +47,7 @@ namespace ngraph

return offsets;
}
} // namespace
} // namespace details

///
/// Implementation find maximum length of *slice* of input *params* which might be
Expand Down Expand Up @@ -143,14 +100,14 @@ namespace ngraph
"params_shape should have enough rank to be index by indices"};
}

const auto slice_shape =
span(next(begin(params_shape), first_slice_index_in_params), end(params_shape));
const auto slice_shape = span(params_shape).subspan(first_slice_index_in_params);
const auto slice_size = shape_size(slice_shape);

const auto dims_begin = next(rbegin(params_shape), slice_shape.size());
const auto dims_end = next(dims_begin, indices_shape.back() - 1);

const auto indices_offsets = get_indices_offsets(dims_begin, dims_end, slice_size);
const auto indices_offsets =
details::get_indices_offsets(dims_begin, dims_end, slice_size);

const auto batch_offset = indices_offsets.front() * params_shape[batch_dims];

Expand Down
Loading

0 comments on commit bd9bbe0

Please sign in to comment.