diff --git a/cunumeric/module.py b/cunumeric/module.py index 713467113..acce7c9e5 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -22,6 +22,7 @@ import numpy as np import opt_einsum as oe # type: ignore [import] +from cunumeric.coverage import is_implemented from numpy.core.numeric import ( # type: ignore [attr-defined] normalize_axis_tuple, ) @@ -36,6 +37,8 @@ from .utils import AxesPairLike, inner_modes, matmul_modes, tensordot_modes if TYPE_CHECKING: + from typing import Callable + import numpy.typing as npt from ._ufunc.ufunc import CastingKind @@ -2549,6 +2552,65 @@ def indices( return res_array +def mask_indices( + n: int, mask_func: Callable[[ndarray, int], ndarray], k: int = 0 +) -> tuple[ndarray, ...]: + """ + Return the indices to access (n, n) arrays, given a masking function. + + Assume `mask_func` is a function that, for a square array a of size + ``(n, n)`` with a possible offset argument `k`, when called as + ``mask_func(a, k)`` returns a new array with zeros in certain locations + (functions like :func:`cunumeric.triu` or :func:`cunumeric.tril` + do precisely this). Then this function returns the indices where + the non-zero values would be located. + + Parameters + ---------- + n : int + The returned indices will be valid to access arrays of shape (n, n). + mask_func : callable + A function whose call signature is similar to that of + :func:`cunumeric.triu`, :func:`cunumeric.tril`. + That is, ``mask_func(x, k)`` returns a boolean array, shaped like `x`. + `k` is an optional argument to the function. + k : scalar + An optional argument which is passed through to `mask_func`. Functions + like :func:`cunumeric.triu`, :func:`cunumeric,tril` + take a second argument that is interpreted as an offset. + + Returns + ------- + indices : tuple of arrays. + The `n` arrays of indices corresponding to the locations where + ``mask_func(np.ones((n, n)), k)`` is True. + + See Also + -------- + numpy.mask_indices + + Notes + ----- + WARNING: `mask_indices` expects `mask_function` to call cuNumeric functions + for good performance. In case non-cuNumeric functions are called by + `mask_function`, cuNumeric will have to materialize all data on the host + which might result in running out of system memory. + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + # this implementation is based on the Cupy + a = ones((n, n), dtype=bool) + if not is_implemented(mask_func): + runtime.warn( + "Calling non-cuNumeric functions in mask_func can result in bad " + "performance", + category=UserWarning, + ) + return mask_func(a, k).nonzero() + + def diag_indices(n: int, ndim: int = 2) -> tuple[ndarray, ...]: """ Return the indices to access the main diagonal of an array. diff --git a/docs/cunumeric/source/api/indexing.rst b/docs/cunumeric/source/api/indexing.rst index 67b17af0e..d911bd763 100644 --- a/docs/cunumeric/source/api/indexing.rst +++ b/docs/cunumeric/source/api/indexing.rst @@ -11,6 +11,7 @@ Generating index arrays diag_indices diag_indices_from + mask_indices tril_indices tril_indices_from triu_indices diff --git a/tests/integration/test_mask_indices.py b/tests/integration/test_mask_indices.py new file mode 100644 index 000000000..d577d03ab --- /dev/null +++ b/tests/integration/test_mask_indices.py @@ -0,0 +1,46 @@ +# Copyright 2021-2022 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest + +import cunumeric as num + +KS = [0, -1, 1, -2, 2] + + +def _test(mask_func, k): + num_f = getattr(num, mask_func) + np_f = getattr(np, mask_func) + + a = num.mask_indices(100, num_f, k=k) + an = np.mask_indices(100, np_f, k=k) + assert num.array_equal(a, an) + + +@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})") +def test_mask_indices_tril(k): + _test("tril", k) + + +@pytest.mark.parametrize("k", KS, ids=lambda k: f"(k={k})") +def test_indices_triu(k): + _test("triu", k) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(sys.argv))