Skip to content

Commit

Permalink
[array api] add stable & descending params to jnp.sort & jnp.argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 4, 2024
1 parent ebc7af9 commit cecf432
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 46 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
Expand Down
58 changes: 39 additions & 19 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3901,23 +3901,29 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):


@util._wraps(np.sort)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def sort(
a: ArrayLike,
axis: int | None = -1,
kind: str = "quicksort",
order: None = None,
kind: None = None,
order: None = None, *,
stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("sort", a)
if kind != 'quicksort':
if kind is not None:
warnings.warn("'kind' argument to sort is ignored.")
if order is not None:
raise ValueError("'order' argument to sort is not supported.")

if axis is None:
return lax.sort(ravel(a), dimension=0)
else:
return lax.sort(asarray(a), dimension=_canonicalize_axis(axis, ndim(a)))
a = ravel(a)
axis = 0
dimension = _canonicalize_axis(axis, ndim(a))
result = lax.sort(asarray(a), dimension=dimension, is_stable=stable)
if descending:
result = lax.rev(result, dimensions=[dimension])
return result


@util._wraps(np.sort_complex)
Expand Down Expand Up @@ -3953,29 +3959,43 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A


@util._wraps(np.argsort, lax_description=_ARGSORT_DOC)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def argsort(
a: ArrayLike,
axis: int | None = -1,
kind: str = "stable",
kind: None = None,
order: None = None,
*, stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("argsort", a)
arr = asarray(a)
if kind != 'stable':
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
"are supported.")
if kind is not None:
warnings.warn("'kind' argument to argsort is ignored.")
if order is not None:
raise ValueError("'order' argument to argsort is not supported.")

if axis is None:
return argsort(arr.ravel(), 0)
arr = ravel(arr)
axis = 0
else:
arr = asarray(a)
dimension = _canonicalize_axis(axis, arr.ndim)

if descending:
rev = partial(lax.rev, dimensions=[dimension])
if stable:
return arr.shape[dimension] - 1 - rev(_argsort(rev(arr), axis=dimension, stable=stable))
else:
return rev(_argsort(arr, axis=dimension, stable=stable))
else:
axis_num = _canonicalize_axis(axis, arr.ndim)
use_64bit_index = not core.is_constant_dim(arr.shape[axis_num]) or arr.shape[axis_num] >= (1 << 31)
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, axis_num)
_, perm = lax.sort_key_val(arr, iota, dimension=axis_num)
return perm
return _argsort(arr, axis=dimension, stable=stable)


def _argsort(a: Array, *, axis: int, stable: bool):
use_64bit_index = not core.is_constant_dim(a.shape[axis]) or a.shape[axis] >= (1 << 31)
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, a.shape, axis)
_, indices = lax.sort_key_val(a, iota, dimension=axis, is_stable=stable)
return indices


@util._wraps(np.partition, lax_description="""
Expand Down
12 changes: 2 additions & 10 deletions jax/experimental/array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,10 @@
def argsort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns the indices that sort an array x along a specified axis."""
del stable # unused
if descending:
return jax.numpy.argsort(-x, axis=axis)
else:
return jax.numpy.argsort(x, axis=axis)
return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable)


def sort(x: Array, /, *, axis: int = -1, descending: bool = False,
stable: bool = True) -> Array:
"""Returns a sorted copy of an input array x."""
del stable # unused
result = jax.numpy.sort(x, axis=axis)
if descending:
return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis])
return result
return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable)
3 changes: 0 additions & 3 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,5 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_linalg.py::test_matrix_power
array_api_tests/test_linalg.py::test_solve

# JAX's NaN sorting doesn't match specification
array_api_tests/test_sorting_functions.py::test_argsort

# fft test suite is buggy as of 83f0bcdc
array_api_tests/test_fft.py
10 changes: 8 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
def argsort(
a: ArrayLike,
axis: Optional[int] = -1,
kind: str = "stable",
kind: None = ...,
order: None = ...,
*,
stable: bool = ...,
descending: bool = ...,
) -> Array: ...
def argwhere(
a: ArrayLike,
Expand Down Expand Up @@ -701,8 +704,11 @@ sometrue = any
def sort(
a: ArrayLike,
axis: Optional[int] = ...,
kind: str = ...,
kind: None = ...,
order: None = ...,
*,
stable: bool = ...,
descending: bool = ...,
) -> Array: ...
def sort_complex(a: ArrayLike) -> Array: ...
def split(
Expand Down
36 changes: 24 additions & 12 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
all_dtypes = number_dtypes + bool_dtypes

NO_VALUE = object()

python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]

Expand Down Expand Up @@ -3771,21 +3772,32 @@ def testArangeTypes(self):
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (None, *range(len(shape)))
for axis in (NO_VALUE, None, *range(len(shape)))
],
dtype=all_dtypes,
)
def testSort(self, dtype, shape, axis):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jnp.sort
np_fun = np.sort
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
kwds = {} if axis is NO_VALUE else {'axis': axis}
# TODO(jakevdp): switch to stable=True when supported by numpy.
np_fun = partial(np.sort, kind='stable', **kwds)
jnp_fun = partial(jnp.sort, stable=True, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testSortStableDescending(self):
# TODO(jakevdp): test directly against np.sort when descending is supported.
x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf])
x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan])
argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5])
argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6])

self.assertArraysEqual(jnp.sort(x), x_sorted)
self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0]))
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)

@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in one_dim_array_shapes
Expand Down Expand Up @@ -3819,18 +3831,18 @@ def testLexsort(self, dtype, shape, input_type, axis):
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonzerodim_shapes
for axis in (None, *range(len(shape)))
for axis in (NO_VALUE, None, *range(len(shape)))
],
dtype=all_dtypes,
)
def testArgsort(self, dtype, shape, axis):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
jnp_fun = jnp.argsort
np_fun = jtu.with_jax_dtype_defaults(np.argsort)
if axis is not None:
jnp_fun = partial(jnp_fun, axis=axis)
np_fun = partial(np_fun, axis=axis)
kwds = {} if axis is NO_VALUE else {'axis': axis}
# TODO(jakevdp): switch to stable=True when supported by numpy.
np_fun = jtu.with_jax_dtype_defaults(
partial(np.argsort, kind='stable', **kwds))
jnp_fun = partial(jnp.argsort, stable=True, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

Expand Down

0 comments on commit cecf432

Please sign in to comment.