From dd0a3415c2d788edcbb712ee08ea75d192275f00 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 21 Dec 2023 12:42:12 -0800 Subject: [PATCH] array api: add unique_* interfaces --- docs/jax.numpy.rst | 4 ++ jax/_src/numpy/setops.py | 49 +++++++++++++++++++- jax/experimental/array_api/_set_functions.py | 31 ++----------- jax/experimental/array_api/skips.txt | 4 -- jax/numpy/__init__.py | 4 ++ jax/numpy/__init__.pyi | 17 ++++++- tests/lax_numpy_test.py | 45 ++++++++++++++++++ 7 files changed, 120 insertions(+), 34 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index fd8102bac2e1..eed2a55f8086 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -403,6 +403,10 @@ namespace; they are listed below. uint8 union1d unique + unique_all + unique_counts + unique_inverse + unique_values unpackbits unravel_index unsignedinteger diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index b600b4d21c9b..c30851c61633 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -18,7 +18,7 @@ import math import operator from textwrap import dedent as _dedent -from typing import cast +from typing import cast, NamedTuple import numpy as np @@ -338,3 +338,50 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()") return _unique(arr, axis_int, return_index, return_inverse, return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value) + + +class _UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array + + +class _UniqueCountsResult(NamedTuple): + values: Array + counts: Array + + +class _UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array + + +@_wraps(getattr(np, "unique_all", None)) +def unique_all(x: ArrayLike, /) -> _UniqueAllResult: + check_arraylike("unique_all", x) + values, indices, inverse_indices, counts = unique( + x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False) + inverse_indices = inverse_indices.reshape(np.shape(x)) + return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) + + +@_wraps(getattr(np, "unique_counts", None)) +def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: + check_arraylike("unique_counts", x) + values, counts = unique(x, return_counts=True, equal_nan=False) + return _UniqueCountsResult(values=values, counts=counts) + + +@_wraps(getattr(np, "unique_inverse", None)) +def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: + check_arraylike("unique_inverse", x) + values, inverse_indices = unique(x, return_inverse=True, equal_nan=False) + inverse_indices = inverse_indices.reshape(np.shape(x)) + return _UniqueInverseResult(values=values, inverse_indices=inverse_indices) + + +@_wraps(getattr(np, "unique_values", None)) +def unique_values(x: ArrayLike, /) -> Array: + check_arraylike("unique_values", x) + return unique(x, equal_nan=False) diff --git a/jax/experimental/array_api/_set_functions.py b/jax/experimental/array_api/_set_functions.py index 95043790c37c..fd6b57aa84f4 100644 --- a/jax/experimental/array_api/_set_functions.py +++ b/jax/experimental/array_api/_set_functions.py @@ -12,47 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple import jax -class UniqueAllResult(NamedTuple): - values: jax.Array - indices: jax.Array - inverse_indices: jax.Array - counts: jax.Array - - -class UniqueCountsResult(NamedTuple): - values: jax.Array - counts: jax.Array - - -class UniqueInverseResult(NamedTuple): - values: jax.Array - inverse_indices: jax.Array - - def unique_all(x, /): """Returns the unique elements of an input array x, the first occurring indices for each unique element in x, the indices from the set of unique elements that reconstruct x, and the corresponding counts for each unique element in x.""" - values, indices, inverse_indices, counts = jax.numpy.unique( - x, return_index=True, return_inverse=True, return_counts=True) - # jnp.unique() flattens inverse indices - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) + return jax.numpy.unique_all(x) def unique_counts(x, /): """Returns the unique elements of an input array x and the corresponding counts for each unique element in x.""" - values, counts = jax.numpy.unique(x, return_counts=True) - return UniqueCountsResult(values=values, counts=counts) + return jax.numpy.unique_counts(x) def unique_inverse(x, /): """Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.""" - values, inverse_indices = jax.numpy.unique(x, return_inverse=True) - inverse_indices = inverse_indices.reshape(x.shape) - return UniqueInverseResult(values=values, inverse_indices=inverse_indices) + return jax.numpy.unique_inverse(x) def unique_values(x, /): diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 3a809d0317c1..3142e8b2320d 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -14,10 +14,6 @@ array_api_tests/test_linalg.py::test_matrix_power array_api_tests/test_linalg.py::test_solve # JAX's NaN sorting doesn't match specification -array_api_tests/test_set_functions.py::test_unique_all -array_api_tests/test_set_functions.py::test_unique_counts -array_api_tests/test_set_functions.py::test_unique_inverse -array_api_tests/test_set_functions.py::test_unique_values array_api_tests/test_sorting_functions.py::test_argsort # fft test suite is buggy as of 83f0bcdc diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 0de765d80760..f04707d0bee2 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -318,6 +318,10 @@ setxor1d as setxor1d, union1d as union1d, unique as unique, + unique_all as unique_all, + unique_counts as unique_counts, + unique_inverse as unique_inverse, + unique_values as unique_values, ) from jax._src.numpy.ufuncs import ( diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d8ee1efad55d..6c7a50b2db4d 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import Any, Callable, Literal, Optional, Sequence, TypeVar, Union, overload +from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -792,11 +792,26 @@ def union1d( size: Optional[int] = ..., fill_value: Optional[ArrayLike] = ..., ) -> Array: ... +class _UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array +class _UniqueCountsResult(NamedTuple): + values: Array + counts: Array +class _UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array def unique(ar: ArrayLike, return_index: bool = ..., return_inverse: bool = ..., return_counts: bool = ..., axis: Optional[int] = ..., *, equal_nan: bool = ..., size: Optional[int] = ..., fill_value: Optional[ArrayLike] = ... ): ... +def unique_all(x: ArrayLike, /) -> _UniqueAllResult: ... +def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: ... +def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: ... +def unique_values(x: ArrayLike, /) -> Array: ... def unpackbits( a: ArrayLike, axis: Optional[int] = ..., diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index fcf8876218e6..9e83be82cf26 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1745,6 +1745,51 @@ def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_co jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueAll(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + def np_fun(x): + values, indices, inverse_indices, counts = np.unique( + x, return_index=True, return_inverse=True, return_counts=True) + return values, indices, inverse_indices.reshape(np.shape(x)), counts + else: + np_fun = np.unique_all + self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueCounts(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = lambda x: np.unique(x, return_counts=True) + else: + np_fun = np.unique_counts + self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueInverse(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + def np_fun(x): + values, inverse_indices = np.unique(x, return_inverse=True) + return values, inverse_indices.reshape(np.shape(x)) + else: + np_fun = np.unique_inverse + self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) + + @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) + def testUniqueValues(self, shape, dtype): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if jtu.numpy_version() < (2, 0, 0): + np_fun = np.unique + else: + np_fun = np.unique_values + self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) + @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in nonempty_array_shapes