Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding mask_indices routine #426

Merged
merged 13 commits into from
Jul 28, 2022
62 changes: 62 additions & 0 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import opt_einsum as oe # type: ignore [import]
from cunumeric.coverage import is_implemented
from numpy.core.numeric import ( # type: ignore [attr-defined]
normalize_axis_tuple,
)
Expand All @@ -36,6 +37,8 @@
from .utils import AxesPairLike, inner_modes, matmul_modes, tensordot_modes

if TYPE_CHECKING:
from typing import Callable

import numpy.typing as npt

from ._ufunc.ufunc import CastingKind
Expand Down Expand Up @@ -2549,6 +2552,65 @@ def indices(
return res_array


def mask_indices(
n: int, mask_func: Callable[[ndarray, int], ndarray], k: int = 0
) -> tuple[ndarray, ...]:
"""
Return the indices to access (n, n) arrays, given a masking function.

Assume `mask_func` is a function that, for a square array a of size
``(n, n)`` with a possible offset argument `k`, when called as
``mask_func(a, k)`` returns a new array with zeros in certain locations
(functions like :func:`cunumeric.triu` or :func:`cunumeric.tril`
do precisely this). Then this function returns the indices where
the non-zero values would be located.

Parameters
----------
n : int
The returned indices will be valid to access arrays of shape (n, n).
mask_func : callable
A function whose call signature is similar to that of
:func:`cunumeric.triu`, :func:`cunumeric.tril`.
That is, ``mask_func(x, k)`` returns a boolean array, shaped like `x`.
`k` is an optional argument to the function.
k : scalar
An optional argument which is passed through to `mask_func`. Functions
like :func:`cunumeric.triu`, :func:`cunumeric,tril`
take a second argument that is interpreted as an offset.

Returns
-------
indices : tuple of arrays.
The `n` arrays of indices corresponding to the locations where
``mask_func(np.ones((n, n)), k)`` is True.

See Also
--------
numpy.mask_indices

Notes
-----
WARNING: `mask_indices` expects `mask_function` to call cuNumeric functions
for good performance. In case non-cuNumeric functions are called by
`mask_function`, cuNumeric will have to materialize all data on the host
which might result in running out of system memory.

Availability
--------
Multiple GPUs, Multiple CPUs
"""
# this implementation is based on the Cupy
a = ones((n, n), dtype=bool)
if not is_implemented(mask_func):
manopapad marked this conversation as resolved.
Show resolved Hide resolved
runtime.warn(
"Calling non-cuNumeric functions in mask_func can result in bad "
"performance",
category=UserWarning,
)
return mask_func(a, k).nonzero()


def diag_indices(n: int, ndim: int = 2) -> tuple[ndarray, ...]:
"""
Return the indices to access the main diagonal of an array.
Expand Down
1 change: 1 addition & 0 deletions docs/cunumeric/source/api/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Generating index arrays

diag_indices
diag_indices_from
mask_indices
tril_indices
tril_indices_from
triu_indices
Expand Down
46 changes: 46 additions & 0 deletions tests/integration/test_mask_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2021-2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pytest

import cunumeric as num

KS = [0, -1, 1, -2, 2]


def _test(mask_func, k):
num_f = getattr(num, mask_func)
np_f = getattr(np, mask_func)

a = num.mask_indices(100, num_f, k=k)
an = np.mask_indices(100, np_f, k=k)
assert num.array_equal(a, an)


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_mask_indices_tril(k):
_test("tril", k)


@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})")
def test_indices_triu(k):
_test("triu", k)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(sys.argv))