Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array api] add stable & descending params to jnp.sort & jnp.argsort #19201

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
Expand Down
51 changes: 32 additions & 19 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3901,23 +3901,28 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):


@util._wraps(np.sort)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def sort(
a: ArrayLike,
axis: int | None = -1,
kind: str = "quicksort",
order: None = None,
kind: None = None,
order: None = None, *,
stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("sort", a)
if kind != 'quicksort':
if kind is not None:
warnings.warn("'kind' argument to sort is ignored.")
if order is not None:
raise ValueError("'order' argument to sort is not supported.")

if axis is None:
return lax.sort(ravel(a), dimension=0)
arr = ravel(a)
axis = 0
else:
return lax.sort(asarray(a), dimension=_canonicalize_axis(axis, ndim(a)))
arr = asarray(a)
dimension = _canonicalize_axis(axis, arr.ndim)
result = lax.sort(arr, dimension=dimension, is_stable=stable)
return lax.rev(result, dimensions=[dimension]) if descending else result


@util._wraps(np.sort_complex)
Expand Down Expand Up @@ -3953,29 +3958,37 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A


@util._wraps(np.argsort, lax_description=_ARGSORT_DOC)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def argsort(
a: ArrayLike,
axis: int | None = -1,
kind: str = "stable",
kind: None = None,
order: None = None,
*, stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("argsort", a)
arr = asarray(a)
if kind != 'stable':
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
"are supported.")
if kind is not None:
warnings.warn("'kind' argument to argsort is ignored.")
if order is not None:
raise ValueError("'order' argument to argsort is not supported.")

if axis is None:
return argsort(arr.ravel(), 0)
arr = ravel(arr)
axis = 0
else:
axis_num = _canonicalize_axis(axis, arr.ndim)
use_64bit_index = not core.is_constant_dim(arr.shape[axis_num]) or arr.shape[axis_num] >= (1 << 31)
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, axis_num)
_, perm = lax.sort_key_val(arr, iota, dimension=axis_num)
return perm
arr = asarray(a)
dimension = _canonicalize_axis(axis, arr.ndim)
use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31)
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, dimension)
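# For a stable descending sort, we reverse the array and indices before sorting:
# the stable ascending sort then keeps duplicates in reversed order, so that
# reversing the final indices restores their original relative order.
# A non-stable descending sort makes no guarantee about the order of duplicates,
# so we can avoid these extra operations.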
if descending and stable:
arr = lax.rev(arr, dimensions=[dimension])
iota = lax.rev(iota, dimensions=[dimension])
_, indices = lax.sort_key_val(arr, iota, dimension=dimension, is_stable=stable)
return lax.rev(indices, dimensions=[dimension]) if descending else indices


@util._wraps(np.partition, lax_description="""
Expand Down
12 changes: 2 additions & 10 deletions jax/experimental/array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,10 @@
def argsort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns the indices that sort an array x along a specified axis."""
del stable # unused
if descending:
return jax.numpy.argsort(-x, axis=axis)
else:
return jax.numpy.argsort(x, axis=axis)
return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable)


def sort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns a sorted copy of an input array x."""
del stable # unused
result = jax.numpy.sort(x, axis=axis)
if descending:
return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis])
return result
return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable)
3 changes: 0 additions & 3 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,5 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_linalg.py::test_matrix_power
array_api_tests/test_linalg.py::test_solve

# JAX's NaN sorting doesn't match specification
array_api_tests/test_sorting_functions.py::test_argsort

# fft test suite is buggy as of 83f0bcdc
array_api_tests/test_fft.py
12 changes: 9 additions & 3 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ def argmin(
def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
def argsort(
a: ArrayLike,
axis: Optional[int] = -1,
kind: str = "stable",
axis: Optional[int] = ...,
kind: None = ...,
order: None = ...,
*,
stable: bool = ...,
descending: bool = ...,
) -> Array: ...
def argwhere(
a: ArrayLike,
Expand Down Expand Up @@ -701,8 +704,11 @@ sometrue = any
def sort(
a: ArrayLike,
axis: Optional[int] = ...,
kind: str = ...,
kind: None = ...,
order: None = ...,
*,
stable: bool = ...,
descending: bool = ...,
) -> Array: ...
def sort_complex(a: ArrayLike) -> Array: ...
def split(
Expand Down
76 changes: 62 additions & 14 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes

NO_VALUE = object()

python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]

Expand Down Expand Up @@ -3771,21 +3772,41 @@ def testArangeTypes(self):
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (None, *range(len(shape)))
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
stable=[True, False],
dtype=all_dtypes,
)
def testSort(self, dtype, shape, axis):
rng = jtu.rand_some_equal(self.rng())
def testSort(self, dtype, shape, axis, stable):
rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jnp.sort
np_fun = np.sort
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
kwds = {} if axis is NO_VALUE else {'axis': axis}

def np_fun(arr):
# Note: numpy sort fails on NaN and Inf values with bfloat16
dtype = arr.dtype
if arr.dtype == jnp.bfloat16:
arr = arr.astype('float32')
# TODO(jakevdp): switch to stable=stable when supported by numpy.
result = np.sort(arr, kind='stable' if stable else None, **kwds)
with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'):
return result.astype(dtype)
jnp_fun = partial(jnp.sort, stable=stable, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testSortStableDescending(self):
# TODO(jakevdp): test directly against np.sort when descending is supported.
x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf])
x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan])
argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5])
argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6])

self.assertArraysEqual(jnp.sort(x), x_sorted)
self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0]))
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)

@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in one_dim_array_shapes
Expand Down Expand Up @@ -3819,21 +3840,48 @@ def testLexsort(self, dtype, shape, input_type, axis):
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (None, *range(len(shape)))
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
dtype=all_dtypes,
)
def testArgsort(self, dtype, shape, axis):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jnp.argsort
np_fun = jtu.with_jax_dtype_defaults(np.argsort)
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
kwds = {} if axis is NO_VALUE else {'axis': axis}

@jtu.with_jax_dtype_defaults
def np_fun(arr):
# Note: numpy sort fails on NaN and Inf values with bfloat16
if arr.dtype == jnp.bfloat16:
arr = arr.astype('float32')
# TODO(jakevdp): switch to stable=True when supported by numpy.
return np.argsort(arr, kind='stable', **kwds)
jnp_fun = partial(jnp.argsort, stable=True, **kwds)

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_nonscalar_array_shapes
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
],
descending=[True, False],
dtype=all_dtypes,
)
def testArgsortUnstable(self, dtype, shape, axis, descending):
# We cannot directly compare unstable argsorts, so instead check that indexed values match.
rng = jtu.rand_some_equal(self.rng())
x = rng(shape, dtype)
kwds = {} if axis is NO_VALUE else {'axis': axis}
expected = jnp.sort(x, descending=descending, stable=False, **kwds)
indices = jnp.argsort(x, descending=descending, stable=False, **kwds)
if axis is None:
actual = jnp.ravel(x)[indices]
else:
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
self.assertArraysEqual(actual, expected)

@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
for shape in nonzerodim_shapes
Expand Down