Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get rid of falling back on numpy in dpnp.put #1838

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 65 additions & 53 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def nonzero(a):
[2, 1]])

A common use for ``nonzero`` is to find the indices of an array, where
a condition is ``True.`` Given an array `a`, the condition `a` > 3 is
a condition is ``True``. Given an array `a`, the condition `a` > 3 is
a boolean array and since ``False`` is interpreted as ``0``,
``np.nonzero(a > 3)`` yields the indices of the `a` where the condition is
true.
Expand Down Expand Up @@ -736,25 +736,33 @@ def place(x, mask, vals, /):
return call_origin(numpy.place, x, mask, vals, dpnp_inplace=True)


# pylint: disable=redefined-outer-name
def put(a, indices, vals, /, *, axis=None, mode="wrap"):
def put(a, ind, v, /, *, axis=None, mode="wrap"):
"""
Puts values of an array into another array along a given axis.

For full documentation refer to :obj:`numpy.put`.

Limitations
-----------
Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameter `indices` is supported as 1-D array of integer data type.
Parameter `vals` must be broadcastable to the shape of `indices`
and has the same data type as `a` if it is as :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameter `mode` is supported with ``wrap``, the default, and ``clip``
values.
Parameter `axis` is supported as integer only.
Otherwise the function will be executed sequentially on CPU.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
The array the values will be put into.
ind : {array_like}
Target indices, interpreted as integers.
v : {scalar, array_like}
Values to be put into `a`. Must be broadcastable to the result shape
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
axis {None, int}, optional
The axis along which the values will be placed. If `a` is 1-D array,
this argument is optional.
Default: ``None``.
mode : {'wrap', 'clip'}, optional
Specifies how out-of-bounds indices will behave.

- 'wrap': clamps indices to (``-n <= i < n``), then wraps negative
indices.
- 'clip': clips indices to (``0 <= i < n``).

Default: ``'wrap'``.

See Also
--------
Expand All @@ -774,49 +782,53 @@ def put(a, indices, vals, /, *, axis=None, mode="wrap"):
Examples
--------
>>> import dpnp as np
>>> x = np.arange(5)
>>> indices = np.array([0, 1])
>>> np.put(x, indices, [-44, -55])
>>> x
array([-44, -55, 2, 3, 4])
>>> a = np.arange(5)
>>> np.put(a, [0, 2], [-44, -55])
>>> a
array([-44, 1, -55, 3, 4])

>>> x = np.arange(5)
>>> indices = np.array([22])
>>> np.put(x, indices, -5, mode='clip')
>>> x
>>> a = np.arange(5)
>>> np.put(a, 22, -5, mode='clip')
>>> a
array([ 0, 1, 2, 3, -5])

"""

if dpnp.is_supported_array_type(a) and dpnp.is_supported_array_type(
indices
):
if indices.ndim != 1 or not dpnp.issubdtype(
indices.dtype, dpnp.integer
):
pass
elif mode not in ("clip", "wrap"):
pass
elif axis is not None and not isinstance(axis, int):
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
# TODO: remove when #1382(dpctl) is solved
elif dpnp.is_supported_array_type(vals) and a.dtype != vals.dtype:
pass
else:
if axis is None and a.ndim > 1:
a = dpnp.reshape(a, -1)
dpt_array = dpnp.get_usm_ndarray(a)
dpt_indices = dpnp.get_usm_ndarray(indices)
dpt_vals = (
dpnp.get_usm_ndarray(vals)
if isinstance(vals, dpnp_array)
else vals
)
return dpt.put(
dpt_array, dpt_indices, dpt_vals, axis=axis, mode=mode
)
dpnp.check_supported_arrays_type(a)

if not dpnp.is_supported_array_type(ind):
ind = dpnp.asarray(
ind, dtype=dpnp.intp, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
elif not dpnp.issubdtype(ind.dtype, dpnp.integer):
ind = dpnp.astype(ind, dtype=dpnp.intp, casting="safe")
ind = dpnp.ravel(ind)

if not dpnp.is_supported_array_type(v):
v = dpnp.asarray(
v, dtype=a.dtype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
)
if v.size == 0:
return

if not (axis is None or isinstance(axis, int)):
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")

in_a = a
if axis is None and a.ndim > 1:
a = dpnp.ravel(in_a)

if mode not in ("wrap", "clip"):
raise ValueError(
f"clipmode must be one of 'clip' or 'wrap' (got '{mode}')"
)

return call_origin(numpy.put, a, indices, vals, mode, dpnp_inplace=True)
usm_a = dpnp.get_usm_ndarray(a)
usm_ind = dpnp.get_usm_ndarray(ind)
usm_v = dpnp.get_usm_ndarray(v)
dpt.put(usm_a, usm_ind, usm_v, axis=axis, mode=mode)
if in_a is not a:
in_a[:] = a.reshape(in_a.shape, copy=False)


# pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -1194,7 +1206,7 @@ def triu_indices(n, k=0, m=None):
-------
inds : tuple, shape(2) of ndarrays, shape(`n`)
The indices for the triangle. The returned tuple contains two arrays,
each with the indices along one dimension of the array. Can be used
each with the indices along one dimension of the array. Can be used
to slice a ndarray of shape(`n`, `n`).
"""

Expand Down
8 changes: 0 additions & 8 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ def get_integer_dtypes():
return [dpnp.int32, dpnp.int64]


def get_integer_dtypes():
"""
Build a list of integer types supported by DPNP.
"""

return [dpnp.int32, dpnp.int64]


def get_complex_dtypes(device=None):
"""
Build a list of complex types supported by DPNP based on device capabilities.
Expand Down
Loading
Loading