Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding mask_indices routine #426

Merged
merged 13 commits into from
Jul 28, 2022
60 changes: 60 additions & 0 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import opt_einsum as oe # type: ignore [import]
from cunumeric.coverage import is_implemented
from numpy.core.numeric import ( # type: ignore [attr-defined]
normalize_axis_tuple,
)
Expand All @@ -36,6 +37,8 @@
from .utils import AxesPairLike, inner_modes, matmul_modes, tensordot_modes

if TYPE_CHECKING:
from typing import Callable

import numpy.typing as npt

from ._ufunc.ufunc import CastingKind
Expand Down Expand Up @@ -2381,6 +2384,63 @@ def indices(
return res_array


def mask_indices(
n: int, mask_func: Callable[[ndarray, int], ndarray], k: int = 0
) -> tuple[ndarray, ...]:
"""
Return the indices to access (n, n) arrays, given a masking function.

Assume `mask_func` is a function that, for a square array a of size
``(n, n)`` with a possible offset argument `k`, when called as
``mask_func(a, k)`` returns a new array with zeros in certain locations
(functions like `triu` or `tril` do precisely this). Then this function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

functions like :func:`cunumeric.triu` or :func:`cunumeric.tril`

returns the indices where the non-zero values would be located.

Parameters
----------
n : int
The returned indices will be valid to access arrays of shape (n, n).
mask_func : callable
A function whose call signature is similar to that of `triu`, `tril`.
manopapad marked this conversation as resolved.
Show resolved Hide resolved
That is, ``mask_func(x, k)`` returns a boolean array, shaped like `x`.
`k` is an optional argument to the function.
k : scalar
An optional argument which is passed through to `mask_func`. Functions
like `triu`, `tril` take a second argument that is interpreted as an
offset.

Returns
-------
indices : tuple of arrays.
The `n` arrays of indices corresponding to the locations where
``mask_func(np.ones((n, n)), k)`` is True.

See Also
--------
numpy.mask_indices

Notes
-----
WARNING: `mask_indices` expects `mask_function` to call cuNumeric functions
for good performance. In case non-cuNumeric fucntion is called by
manopapad marked this conversation as resolved.
Show resolved Hide resolved
`mask_function`, cuNumeric will have to materialize all data on the host
which might result in allocating to much of a system memory.
manopapad marked this conversation as resolved.
Show resolved Hide resolved

Availability
--------
Multiple GPUs, Multiple CPUs
"""
# this implementation is based on the Cupy
a = ones((n, n), dtype=bool)
if not is_implemented(mask_func):
manopapad marked this conversation as resolved.
Show resolved Hide resolved
runtime.warn(
"mask_func is not a cuNumeric function which can result in "
"bad performance",
category=UserWarning,
)
return mask_func(a, k).nonzero()


def diag_indices(n: int, ndim: int = 2) -> tuple[ndarray, ...]:
"""
Return the indices to access the main diagonal of an array.
Expand Down
1 change: 1 addition & 0 deletions docs/cunumeric/source/api/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Generating index arrays

diag_indices
diag_indices_from
mask_indices
indices
nonzero
where
Expand Down
46 changes: 46 additions & 0 deletions tests/integration/test_mask_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2021-2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pytest

import cunumeric as num

KS = [0, -1, 1, -2, 2]


def _test(mask_func, k):
num_f = getattr(num, mask_func)
np_f = getattr(np, mask_func)

a = num.mask_indices(100, num_f, k=k)
an = np.mask_indices(100, np_f, k=k)
assert num.array_equal(a, an)


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_mask_indices_tril(k):
_test("tril", k)


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_indices_triu(k):
_test("triu", k)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(sys.argv))