From 8b6251667617de4509188a944bc441aef1ed55b7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 4 Jan 2024 14:21:25 -0800 Subject: [PATCH] [array api] add stable & descending params to jnp.sort & jnp.argsort --- CHANGELOG.md | 2 + jax/_src/numpy/lax_numpy.py | 51 ++++++++----- .../array_api/_sorting_functions.py | 12 +-- jax/experimental/array_api/skips.txt | 3 - jax/numpy/__init__.pyi | 12 ++- tests/lax_numpy_test.py | 76 +++++++++++++++---- 6 files changed, 107 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a533241f8ee..98308a6f128a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list devices to create `Sharding`s during lowering. This is a temporary state until we can create `Sharding`s without physical devices. + * {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable` + and `descending` arguments. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index cf6058bb48a6..f261a003362d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3901,23 +3901,28 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False): @util._wraps(np.sort) -@partial(jit, static_argnames=('axis', 'kind', 'order')) +@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def sort( a: ArrayLike, axis: int | None = -1, - kind: str = "quicksort", - order: None = None, + kind: None = None, + order: None = None, *, + stable: bool = True, + descending: bool = False, ) -> Array: util.check_arraylike("sort", a) - if kind != 'quicksort': + if kind is not None: warnings.warn("'kind' argument to sort is ignored.") if order is not None: raise ValueError("'order' argument to sort is not supported.") - if axis is None: - return lax.sort(ravel(a), dimension=0) + arr = ravel(a) + axis = 0 else: - return lax.sort(asarray(a), dimension=_canonicalize_axis(axis, ndim(a))) + arr = asarray(a) + dimension = _canonicalize_axis(axis, arr.ndim) + result = lax.sort(arr, dimension=dimension, is_stable=stable) + return lax.rev(result, dimensions=[dimension]) if descending else result @util._wraps(np.sort_complex) @@ -3953,29 +3958,37 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A @util._wraps(np.argsort, lax_description=_ARGSORT_DOC) -@partial(jit, static_argnames=('axis', 'kind', 'order')) +@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending')) def argsort( a: ArrayLike, axis: int | None = -1, - kind: str = "stable", + kind: None = None, order: None = None, + *, stable: bool = True, + descending: bool = False, ) -> Array: util.check_arraylike("argsort", a) arr = asarray(a) - if kind != 'stable': - warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts " - "are supported.") + if kind is not None: + warnings.warn("'kind' argument to argsort is ignored.") if order is not None: raise ValueError("'order' argument to argsort is not supported.") - if axis is None: - return argsort(arr.ravel(), 0) + arr = ravel(arr) + axis = 0 else: - axis_num = _canonicalize_axis(axis, arr.ndim) - use_64bit_index = not core.is_constant_dim(arr.shape[axis_num]) or arr.shape[axis_num] >= (1 << 31) - iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, axis_num) - _, perm = lax.sort_key_val(arr, iota, dimension=axis_num) - return perm + arr = asarray(a) + dimension = _canonicalize_axis(axis, arr.ndim) + use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31) + iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, dimension) + # For stable descending sort, we reverse the array and indices to ensure that + # duplicates remain in their original order when the final indices are reversed. + # For non-stable descending sort, we can avoid these extra operations. + if descending and stable: + arr = lax.rev(arr, dimensions=[dimension]) + iota = lax.rev(iota, dimensions=[dimension]) + _, indices = lax.sort_key_val(arr, iota, dimension=dimension, is_stable=stable) + return lax.rev(indices, dimensions=[dimension]) if descending else indices @util._wraps(np.partition, lax_description=""" diff --git a/jax/experimental/array_api/_sorting_functions.py b/jax/experimental/array_api/_sorting_functions.py index 139593f203cf..4c64480d39a6 100644 --- a/jax/experimental/array_api/_sorting_functions.py +++ b/jax/experimental/array_api/_sorting_functions.py @@ -19,18 +19,10 @@ def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: """Returns the indices that sort an array x along a specified axis.""" - del stable # unused - if descending: - return jax.numpy.argsort(-x, axis=axis) - else: - return jax.numpy.argsort(x, axis=axis) + return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable) def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: """Returns a sorted copy of an input array x.""" - del stable # unused - result = jax.numpy.sort(x, axis=axis) - if descending: - return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis]) - return result + return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 3142e8b2320d..e2bcd093eb47 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -13,8 +13,5 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_linalg.py::test_matrix_power array_api_tests/test_linalg.py::test_solve -# JAX's NaN sorting doesn't match specification -array_api_tests/test_sorting_functions.py::test_argsort - # fft test suite is buggy as of 83f0bcdc array_api_tests/test_fft.py diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 4438209396db..b133ca143010 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -74,9 +74,12 @@ def argmin( def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ... def argsort( a: ArrayLike, - axis: Optional[int] = -1, - kind: str = "stable", + axis: Optional[int] = ..., + kind: None = ..., order: None = ..., + *, + stable: bool = ..., + descending: bool = ..., ) -> Array: ... def argwhere( a: ArrayLike, @@ -701,8 +704,11 @@ sometrue = any def sort( a: ArrayLike, axis: Optional[int] = ..., - kind: str = ..., + kind: None = ..., order: None = ..., + *, + stable: bool = ..., + descending: bool = ..., ) -> Array: ... def sort_complex(a: ArrayLike) -> Array: ... def split( diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e85e4ccb9708..622397b27656 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -79,6 +79,7 @@ number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes all_dtypes = number_dtypes + bool_dtypes +NO_VALUE = object() python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] @@ -3771,21 +3772,41 @@ def testArangeTypes(self): @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in nonzerodim_shapes - for axis in (None, *range(len(shape))) + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) ], + stable=[True, False], dtype=all_dtypes, ) - def testSort(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) + def testSort(self, dtype, shape, axis, stable): + rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng()) args_maker = lambda: [rng(shape, dtype)] - jnp_fun = jnp.sort - np_fun = np.sort - if axis is not None: - jnp_fun = partial(jnp_fun, axis=axis) - np_fun = partial(np_fun, axis=axis) + kwds = {} if axis is NO_VALUE else {'axis': axis} + + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + dtype = arr.dtype + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=stable when supported by numpy. + result = np.sort(arr, kind='stable' if stable else None, **kwds) + with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'): + return result.astype(dtype) + jnp_fun = partial(jnp.sort, stable=stable, **kwds) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testSortStableDescending(self): + # TODO(jakevdp): test directly against np.sort when descending is supported. + x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf]) + x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan]) + argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5]) + argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6]) + + self.assertArraysEqual(jnp.sort(x), x_sorted) + self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0])) + self.assertArraysEqual(jnp.argsort(x), argsorted_stable) + self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) + @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in one_dim_array_shapes @@ -3819,21 +3840,48 @@ def testLexsort(self, dtype, shape, input_type, axis): @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in nonzerodim_shapes - for axis in (None, *range(len(shape))) + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) ], dtype=all_dtypes, ) def testArgsort(self, dtype, shape, axis): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] - jnp_fun = jnp.argsort - np_fun = jtu.with_jax_dtype_defaults(np.argsort) - if axis is not None: - jnp_fun = partial(jnp_fun, axis=axis) - np_fun = partial(np_fun, axis=axis) + kwds = {} if axis is NO_VALUE else {'axis': axis} + + @jtu.with_jax_dtype_defaults + def np_fun(arr): + # Note: numpy sort fails on NaN and Inf values with bfloat16 + if arr.dtype == jnp.bfloat16: + arr = arr.astype('float32') + # TODO(jakevdp): switch to stable=True when supported by numpy. + return np.argsort(arr, kind='stable', **kwds) + jnp_fun = partial(jnp.argsort, stable=True, **kwds) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in nonempty_nonscalar_array_shapes + for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) + ], + descending=[True, False], + dtype=all_dtypes, + ) + def testArgsortUnstable(self, dtype, shape, axis, descending): + # We cannot directly compare unstable argsorts, so instead check that indexed values match. + rng = jtu.rand_some_equal(self.rng()) + x = rng(shape, dtype) + kwds = {} if axis is NO_VALUE else {'axis': axis} + expected = jnp.sort(x, descending=descending, stable=False, **kwds) + indices = jnp.argsort(x, descending=descending, stable=False, **kwds) + if axis is None: + actual = jnp.ravel(x)[indices] + else: + actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis) + self.assertArraysEqual(actual, expected) + @jtu.sample_product( [{'shape': shape, 'axis': axis, 'kth': kth} for shape in nonzerodim_shapes