Skip to content

Commit

Permalink
Leverage on dpctl.tensor implementation in dpnp.take_along_axis (#1969
Browse files Browse the repository at this point in the history
)

* Implement dpnp.take_along_axis through dpctl.tensor

* Added more tests to cover new logic

* Increase test coverage

* Fix type in docstring of dpnp.put()
  • Loading branch information
antonwolfy authored Aug 12, 2024
1 parent 4a23239 commit 5fda819
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 47 deletions.
34 changes: 26 additions & 8 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
v : {scalar, array_like}
Values to be put into `a`. Must be broadcastable to the result shape
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
axis {None, int}, optional
axis : {None, int}, optional
The axis along which the values will be placed. If `a` is 1-D array,
this argument is optional.
Default: ``None``.
Expand Down Expand Up @@ -1502,7 +1502,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
return dpnp.get_result_array(result, out)


def take_along_axis(a, indices, axis):
def take_along_axis(a, indices, axis, mode="wrap"):
"""
Take values from the input array by matching 1d index and data slices.
Expand All @@ -1523,15 +1523,24 @@ def take_along_axis(a, indices, axis):
Indices to take along each 1d slice of `a`. This must match the
dimension of the input array, but dimensions ``Ni`` and ``Nj``
only need to broadcast against `a`.
axis : int
axis : {None, int}
The axis to take 1d slices along. If axis is ``None``, the input
array is treated as if it had first been flattened to 1d,
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
mode : {"wrap", "clip"}, optional
Specifies how out-of-bounds indices will be handled. Possible values
are:
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
negative indices.
- ``"clip"``: clips indices to (``0 <= i < n``).
Default: ``"wrap"``.
Returns
-------
out : dpnp.ndarray
The indexed result.
The indexed result of the same data type as `a`.
See Also
--------
Expand Down Expand Up @@ -1591,12 +1600,21 @@ def take_along_axis(a, indices, axis):
"""

dpnp.check_supported_arrays_type(a, indices)

if axis is None:
a = a.ravel()
dpnp.check_supported_arrays_type(indices)
if indices.ndim != 1:
raise ValueError(
"when axis=None, `indices` must have a single dimension."
)

return a[_build_along_axis_index(a, indices, axis)]
a = dpnp.ravel(a)
axis = 0

usm_a = dpnp.get_usm_ndarray(a)
usm_ind = dpnp.get_usm_ndarray(indices)

usm_res = dpt.take_along_axis(usm_a, usm_ind, axis=axis, mode=mode)
return dpnp_array._create_from_usm_ndarray(usm_res)


def tril_indices(
Expand Down
99 changes: 60 additions & 39 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,13 @@ def test_values(self, arr_dt, idx_dt, ndim, values):
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
assert_array_equal(np_a, dp_a)

@pytest.mark.parametrize("xp", [numpy, dpnp])
@pytest.mark.parametrize("dt", [bool, numpy.float32])
def test_invalid_indices_dtype(self, xp, dt):
a = xp.ones((10, 10))
ind = xp.ones(10, dtype=dt)
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)

@pytest.mark.parametrize("arr_dt", get_all_dtypes())
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
def test_broadcast(self, arr_dt, idx_dt):
Expand Down Expand Up @@ -673,66 +680,80 @@ def test_argequivalent(self, func, argfunc, kwargs):
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
@pytest.mark.parametrize("ndim", list(range(1, 4)))
def test_multi_dimensions(self, arr_dt, idx_dt, ndim):
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
(1,) * (ndim - 1) + (4,)
)

dp_a = dpnp.array(np_a, dtype=arr_dt)
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
ia, iind = dpnp.array(a), dpnp.array(ind)

for axis in range(ndim):
expected = numpy.take_along_axis(np_a, np_ai, axis)
result = dpnp.take_along_axis(dp_a, dp_ai, axis)
result = dpnp.take_along_axis(ia, iind, axis)
expected = numpy.take_along_axis(a, ind, axis)
assert_array_equal(expected, result)

@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_invalid(self, xp):
def test_not_enough_indices(self, xp):
a = xp.ones((10, 10))
ai = xp.ones((10, 2), dtype=xp.intp)

# not enough indices
assert_raises(ValueError, xp.take_along_axis, a, xp.array(1), axis=1)

# bool arrays not allowed
assert_raises(
IndexError, xp.take_along_axis, a, ai.astype(bool), axis=1
)
@pytest.mark.parametrize("xp", [numpy, dpnp])
@pytest.mark.parametrize("dt", [bool, numpy.float32])
def test_invalid_indices_dtype(self, xp, dt):
a = xp.ones((10, 10))
ind = xp.ones((10, 2), dtype=dt)
assert_raises(IndexError, xp.take_along_axis, a, ind, axis=1)

# float arrays not allowed
assert_raises(
IndexError, xp.take_along_axis, a, ai.astype(numpy.float32), axis=1
)
@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_invalid_axis(self, xp):
a = xp.ones((10, 10))
ind = xp.ones((10, 2), dtype=xp.intp)
assert_raises(AxisError, xp.take_along_axis, a, ind, axis=10)

# invalid axis
assert_raises(AxisError, xp.take_along_axis, a, ai, axis=10)
@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_indices_ndim_axis_none(self, xp):
a = xp.ones((10, 10))
ind = xp.ones((10, 2), dtype=xp.intp)
assert_raises(ValueError, xp.take_along_axis, a, ind, axis=None)

@pytest.mark.parametrize("arr_dt", get_all_dtypes())
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
def test_empty(self, arr_dt, idx_dt):
np_a = numpy.ones((3, 4, 5), dtype=arr_dt)
np_ai = numpy.ones((3, 0, 5), dtype=idx_dt)

dp_a = dpnp.array(np_a, dtype=arr_dt)
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
def test_empty(self, a_dt, idx_dt):
a = numpy.ones((3, 4, 5), dtype=a_dt)
ind = numpy.ones((3, 0, 5), dtype=idx_dt)
ia, iind = dpnp.array(a), dpnp.array(ind)

expected = numpy.take_along_axis(np_a, np_ai, axis=1)
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
result = dpnp.take_along_axis(ia, iind, axis=1)
expected = numpy.take_along_axis(a, ind, axis=1)
assert_array_equal(expected, result)

@pytest.mark.parametrize("arr_dt", get_all_dtypes())
@pytest.mark.parametrize("a_dt", get_all_dtypes(no_none=True))
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
def test_broadcast(self, arr_dt, idx_dt):
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
np_ai = numpy.ones((1, 2, 5), dtype=idx_dt)

dp_a = dpnp.array(np_a, dtype=arr_dt)
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
def test_broadcast(self, a_dt, idx_dt):
a = numpy.ones((3, 4, 1), dtype=a_dt)
ind = numpy.ones((1, 2, 5), dtype=idx_dt)
ia, iind = dpnp.array(a), dpnp.array(ind)

expected = numpy.take_along_axis(np_a, np_ai, axis=1)
result = dpnp.take_along_axis(dp_a, dp_ai, axis=1)
result = dpnp.take_along_axis(ia, iind, axis=1)
expected = numpy.take_along_axis(a, ind, axis=1)
assert_array_equal(expected, result)

def test_mode_wrap(self):
a = numpy.array([-2, -1, 0, 1, 2])
ind = numpy.array([-2, 2, -5, 4])
ia, iind = dpnp.array(a), dpnp.array(ind)

result = dpnp.take_along_axis(ia, iind, axis=0, mode="wrap")
expected = numpy.take_along_axis(a, ind, axis=0)
assert_array_equal(result, expected)

def test_mode_clip(self):
a = dpnp.array([-2, -1, 0, 1, 2])
ind = dpnp.array([-2, 2, -5, 4])

# numpy does not support keyword `mode`
result = dpnp.take_along_axis(a, ind, axis=0, mode="clip")
assert (result == dpnp.array([-2, 0, -2, 2])).all()


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
def test_choose():
Expand Down

0 comments on commit 5fda819

Please sign in to comment.