Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for take_along_axis and put_along_axis #436

Merged
merged 23 commits into from
Jul 30, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
fe5c679
adding support for take_along_axis
ipdemes Jun 27, 2022
1cb2aed
adding support for put_along_axis
ipdemes Jun 27, 2022
1d5acf2
fixing tests
ipdemes Jun 29, 2022
e5d7cd8
adding some comments
ipdemes Jun 29, 2022
250f33f
addressing PR comments
ipdemes Jul 4, 2022
83bf3ca
addressing PR comments
ipdemes Jul 5, 2022
231cf11
Merge remote-tracking branch 'origin/branch-22.07' into take_along_axis
ipdemes Jul 6, 2022
b9c2787
fixing logic for the case when rhs and index_arrays are both futures …
ipdemes Jul 6, 2022
f9321ed
reducing sizes of arrays
ipdemes Jul 7, 2022
5f1921a
adding missing @
ipdemes Jul 11, 2022
985e2f6
updating documentation for new API
ipdemes Jul 19, 2022
61517e9
Merge remote-tracking branch 'origin/branch-22.07' into take_along_axis
ipdemes Jul 19, 2022
13aa67f
fixing documentation
ipdemes Jul 26, 2022
8318a4d
addressing PR comments
ipdemes Jul 27, 2022
02596b9
Merge branch 'branch-22.07' into take_along_axis
manopapad Jul 28, 2022
cdaebb5
addressing PR comments
ipdemes Jul 28, 2022
73711dc
Merge branch 'take_along_axis' of github.com:ipdemes/cunumeric into t…
ipdemes Jul 28, 2022
d7d6d08
addressing PR comments
ipdemes Jul 28, 2022
90ae8f4
Merge remote-tracking branch 'origin/branch-22.07' into take_along_axis
ipdemes Jul 28, 2022
0e54d56
fixing mypy errors
ipdemes Jul 29, 2022
6aa0b8e
fixing mypy errors
ipdemes Jul 29, 2022
2be0439
addressing PR comments
ipdemes Jul 30, 2022
d8623c2
removing undefined behavior from advanced_indexing test
ipdemes Jul 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,12 @@ def set_item(self, key, rhs):
rhs = rhs._copy_store(rhs.base)
rhs_store = rhs.base

# the case when rhs is a scalar and indices array contains
# a single value
if rhs.base.kind == Future:
lhs.copy(rhs, deep=True)
return

copy = self.context.create_copy()
copy.set_target_indirect_out_of_range(False)

Expand Down
135 changes: 135 additions & 0 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import opt_einsum as oe # type: ignore [import]
from numpy.core.multiarray import normalize_axis_index
from numpy.core.numeric import ( # type: ignore [attr-defined]
normalize_axis_tuple,
)
Expand Down Expand Up @@ -2498,6 +2499,140 @@ def take(
return a.take(indices=indices, axis=axis, out=out, mode=mode)


def _fill_fancy_index_for_along_axis_routines(
a_shape: tuple, axis: int, indices: ndarray
manopapad marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple:
manopapad marked this conversation as resolved.
Show resolved Hide resolved

# the logic below is base on the cupy implementation of
# the *_along_axis routines
ndim = len(a_shape)
fancy_index = []
for i, n in enumerate(a_shape):
if i == axis:
fancy_index.append(indices)
else:
ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
fancy_index.append(arange(n).reshape(ind_shape))
return tuple(fancy_index)


@add_boilerplate("a", "indices")
def take_along_axis(a: ndarray, indices: ndarray, axis: int) -> ndarray:
"""
Take values from the input array by matching 1d index and data slices.

This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to look up values in the
latter. These slices can be different lengths.

Functions returning an index along an axis, like `argsort` and
`argpartition`, produce suitable indices for this function.

Parameters
----------
arr : ndarray (Ni..., M, Nk...)
Source array
indices : ndarray (Ni..., J, Nk...)
Indices to take along each 1d slice of `arr`. This must match the
dimension of arr, but dimensions Ni and Nj only need to broadcast
against `arr`.
axis : int
The axis to take 1d slices along. If axis is None, the input array is
treated as if it had first been flattened to 1d, for consistency with
`sort` and `argsort`.

Returns
-------
out: ndarray (Ni..., J, Nk...)
The indexed result.

See Also
--------
numpy.take_along_axis

Availability
--------
Multiple GPUs, Multiple CPUs
"""
if not np.issubdtype(indices.dtype, np.integer):
raise IndexError("`indices` must be an integer array")
manopapad marked this conversation as resolved.
Show resolved Hide resolved

if axis is None:
a = a.ravel()
axis = 0
if indices.ndim != 1:
raise ValueError("indices must be 1D if axis=None")
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
else:
axis = normalize_axis_index(axis, a.ndim)

ndim = a.ndim
if ndim != indices.ndim:
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"`indices` and `a` must have the same number of dimensions"
)
return a[_fill_fancy_index_for_along_axis_routines(a.shape, axis, indices)]


add_boilerplate("a", "indices", "values")


def put_along_axis(
manopapad marked this conversation as resolved.
Show resolved Hide resolved
a: ndarray, indices: ndarray, values: ndarray, axis: int
) -> ndarray:
"""
Put values into the destination array by matching 1d index and data slices.

This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to place values into the
latter. These slices can be different lengths.

Functions returning an index along an axis, like `argsort` and
`argpartition`, produce suitable indices for this function.

Parameters
----------
arr : ndarray (Ni..., M, Nk...)
Destination array.
indices : ndarray (Ni..., J, Nk...)
Indices to change along each 1d slice of `arr`. This must match the
dimension of arr, but dimensions in Ni and Nj may be 1 to broadcast
against `arr`.
values : array_like (Ni..., J, Nk...)
values to insert at those indices. Its shape and dimension are
broadcast to match that of `indices`.
axis : int
The axis to take 1d slices along. If axis is None, the destination
array is treated as if a flattened 1d view had been created of it.

manopapad marked this conversation as resolved.
Show resolved Hide resolved
See Also
--------
numpy.put_along_axis

Availability
--------
Multiple GPUs, Multiple CPUs

"""
if not np.issubdtype(indices.dtype, np.integer):
raise IndexError("`indices` must be an integer array")
manopapad marked this conversation as resolved.
Show resolved Hide resolved

if axis is None:
a = a.ravel()
manopapad marked this conversation as resolved.
Show resolved Hide resolved
axis = 0
if indices.ndim != 1:
raise ValueError("indices must be 1D if axis=None")
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
else:
axis = normalize_axis_index(axis, a.ndim)

if a.ndim != indices.ndim:
raise ValueError(
"`indices` and `a` must have the same number of dimensions"
)
a[
_fill_fancy_index_for_along_axis_routines(a.shape, axis, indices)
] = values
manopapad marked this conversation as resolved.
Show resolved Hide resolved


@add_boilerplate("a")
def choose(
a: ndarray,
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/test_put_along_axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pytest
from test_tools.generators import mk_seq_array

import cunumeric as num
from legate.core import LEGATE_MAX_DIM


def test_3d():

x = mk_seq_array(np, (256, 256, 100))
x_num = mk_seq_array(num, (256, 256, 100))

indices = mk_seq_array(np, (256, 256, 10)) % 100
indices_num = num.array(indices)

res = np.put_along_axis(x, indices, -10, -1)
manopapad marked this conversation as resolved.
Show resolved Hide resolved
res_num = num.put_along_axis(x_num, indices_num, -10, -1)
assert np.array_equal(res_num, res)


def test_None_axis():
manopapad marked this conversation as resolved.
Show resolved Hide resolved
x = mk_seq_array(np, (256, 256, 100))
x_num = mk_seq_array(num, (256, 256, 100))

# testig the case when axis = None
indices = mk_seq_array(np, (256,))
indices_num = num.array(indices)

res = np.put_along_axis(x, indices, 99, None)
res_num = num.put_along_axis(x_num, indices_num, 99, None)
assert np.array_equal(res_num, res)


N = 50


@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1))
def test_ndim(ndim):
shape = (N,) * ndim
np_arr = mk_seq_array(np, shape)
num_arr = mk_seq_array(num, shape)
shape_idx = (1,) * ndim
np_indices = mk_seq_array(np, shape_idx) % N
num_indices = mk_seq_array(num, shape_idx) % N
for axis in range(ndim):
res_np = np.put_along_axis(np_arr, np_indices, 8, axis=axis)
manopapad marked this conversation as resolved.
Show resolved Hide resolved
res_num = num.put_along_axis(num_arr, num_indices, 8, axis=axis)
assert np.array_equal(res_num, res_np)
np_indices = mk_seq_array(np, (30,))
num_indices = mk_seq_array(num, (30,))
res_np = np.put_along_axis(np_arr, np_indices, 11, None)
manopapad marked this conversation as resolved.
Show resolved Hide resolved
res_num = num.put_along_axis(num_arr, num_indices, 11, None)
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
assert np.array_equal(res_num, res_np)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(sys.argv))
75 changes: 75 additions & 0 deletions tests/integration/test_take_along_axis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pytest
from test_tools.generators import mk_seq_array

import cunumeric as num
from legate.core import LEGATE_MAX_DIM


def test_3d():

x = mk_seq_array(np, (256, 256, 100))
x_num = mk_seq_array(num, (256, 256, 100))

indices = mk_seq_array(np, (256, 256, 10)) % 100
indices_num = num.array(indices)

res = np.take_along_axis(x, indices, -1)
res_num = num.take_along_axis(x_num, indices_num, -1)
assert np.array_equal(res_num, res)


def test_None_axis():
manopapad marked this conversation as resolved.
Show resolved Hide resolved
x = mk_seq_array(np, (256, 256, 100))
x_num = mk_seq_array(num, (256, 256, 100))

# testig the case when axis = None
indices = mk_seq_array(np, (256,))
indices_num = num.array(indices)

res = np.take_along_axis(x, indices, None)
res_num = num.take_along_axis(x_num, indices_num, None)
assert np.array_equal(res_num, res)


N = 50


@pytest.mark.parametrize("ndim", range(1, LEGATE_MAX_DIM + 1))
def test_ndim(ndim):
shape = (N,) * ndim
np_arr = mk_seq_array(np, shape)
num_arr = mk_seq_array(num, shape)
shape_idx = (1,) * ndim
np_indices = mk_seq_array(np, shape_idx) % N
num_indices = mk_seq_array(num, shape_idx) % N
for axis in range(ndim):
res_np = np.take_along_axis(np_arr, np_indices, axis=axis)
res_num = num.take_along_axis(num_arr, num_indices, axis=axis)
assert np.array_equal(res_num, res_np)
np_indices = mk_seq_array(np, (30,))
num_indices = mk_seq_array(num, (30,))
res_np = np.take_along_axis(np_arr, np_indices, None)
res_num = num.take_along_axis(num_arr, num_indices, None)
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
assert np.array_equal(res_num, res_np)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(sys.argv))