Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for kind="mergesort" or "radixsort" for dpnp.sort and dpnp.argsort #2159

Merged
merged 6 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ def _wrap_sort_argsort(
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and kind != "stable":
raise NotImplementedError(
"kind keyword argument can only be None or 'stable'."
if kind is not None and stable is not None:
raise ValueError(
"`kind` and `stable` parameters can't be provided at the same time."
" Use only one of them."
)

usm_a = dpnp.get_usm_ndarray(a)
Expand All @@ -77,11 +78,11 @@ def _wrap_sort_argsort(
axis = -1

axis = normalize_axis_index(axis, ndim=usm_a.ndim)
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable)
usm_res = _sorting_fn(usm_a, axis=axis, stable=stable, kind=kind)
return dpnp_array._create_from_usm_ndarray(usm_res)


def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
"""
Returns the indices that would sort an array.

Expand All @@ -94,9 +95,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
axis : {None, int}, optional
Axis along which to sort. If ``None``, the array is flattened before
sorting. The default is ``-1``, which sorts along the last axis.
kind : {None, "stable"}, optional
kind : {None, "stable", "mergesort", "radixsort"}, optional
Sorting algorithm. Default is ``None``, which is equivalent to
``"stable"``. Unlike NumPy, no other option is accepted here.
``"stable"``.
stable : {None, bool}, optional
Sort stability. If ``True``, the returned array will maintain
the relative order of ``a`` values which compare as equal.
Expand All @@ -121,8 +122,9 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=True):
Limitations
-----------
Parameters `order` is only supported with its default value.
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
Otherwise ``NotImplementedError`` exception will be raised.
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.


See Also
--------
Expand Down Expand Up @@ -203,7 +205,7 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
return call_origin(numpy.partition, x1, kth, axis, kind, order)


def sort(a, axis=-1, kind=None, order=None, *, stable=True):
def sort(a, axis=-1, kind=None, order=None, *, stable=None):
"""
Return a sorted copy of an array.

Expand All @@ -216,9 +218,9 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
axis : {None, int}, optional
Axis along which to sort. If ``None``, the array is flattened before
sorting. The default is ``-1``, which sorts along the last axis.
kind : {None, "stable"}, optional
kind : {None, "stable", "mergesort", "radixsort"}, optional
Sorting algorithm. Default is ``None``, which is equivalent to
``"stable"``. Unlike NumPy, no other option is accepted here.
``"stable"``.
stable : {None, bool}, optional
Sort stability. If ``True``, the returned array will maintain
the relative order of ``a`` values which compare as equal.
Expand All @@ -239,8 +241,8 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=True):
Limitations
-----------
Parameters `order` is only supported with its default value.
Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent.
Otherwise ``NotImplementedError`` exception will be raised.
Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported.

See Also
--------
Expand Down
71 changes: 44 additions & 27 deletions tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,34 @@


class TestArgsort:
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
def test_argsort_dtype(self, dtype):
def test_basic(self, kind, dtype):
a = numpy.random.uniform(-5, 5, 10)
np_array = numpy.array(a, dtype=dtype)
dp_array = dpnp.array(np_array)

result = dpnp.argsort(dp_array, kind="stable")
result = dpnp.argsort(dp_array, kind=kind)
expected = numpy.argsort(np_array, kind="stable")
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
@pytest.mark.parametrize("dtype", get_complex_dtypes())
def test_argsort_complex(self, dtype):
def test_complex(self, kind, dtype):
a = numpy.random.uniform(-5, 5, 10)
b = numpy.random.uniform(-5, 5, 10)
np_array = numpy.array(a + b * 1j, dtype=dtype)
dp_array = dpnp.array(np_array)

result = dpnp.argsort(dp_array)
expected = numpy.argsort(np_array)
assert_dtype_allclose(result, expected)
if kind == "radixsort":
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
else:
result = dpnp.argsort(dp_array, kind=kind)
expected = numpy.argsort(np_array)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
def test_argsort_axis(self, axis):
def test_axis(self, axis):
a = numpy.random.uniform(-10, 10, 36)
np_array = numpy.array(a).reshape(3, 4, 3)
dp_array = dpnp.array(np_array)
Expand All @@ -48,7 +53,7 @@ def test_argsort_axis(self, axis):

@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1])
def test_argsort_ndarray(self, dtype, axis):
def test_ndarray(self, dtype, axis):
if dtype and issubclass(dtype, numpy.integer):
a = numpy.random.choice(
numpy.arange(-10, 10), replace=False, size=12
Expand All @@ -62,8 +67,9 @@ def test_argsort_ndarray(self, dtype, axis):
expected = np_array.argsort(axis=axis)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("kind", [None, "stable"])
def test_sort_kind(self, kind):
# this test validates that all different options of kind in dpnp are stable
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
vtavana marked this conversation as resolved.
Show resolved Hide resolved
def test_kind(self, kind):
np_array = numpy.repeat(numpy.arange(10), 10)
dp_array = dpnp.array(np_array)

Expand All @@ -74,15 +80,15 @@ def test_sort_kind(self, kind):
# `stable` keyword is supported in numpy 2.0 and above
@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("stable", [None, False, True])
def test_sort_stable(self, stable):
def test_stable(self, stable):
np_array = numpy.repeat(numpy.arange(10), 10)
dp_array = dpnp.array(np_array)

result = dpnp.argsort(dp_array, stable="stable")
expected = numpy.argsort(np_array, stable=True)
assert_dtype_allclose(result, expected)

def test_argsort_zero_dim(self):
def test_zero_dim(self):
np_array = numpy.array(2.5)
dp_array = dpnp.array(np_array)

Expand Down Expand Up @@ -266,29 +272,34 @@ def test_v_scalar(self):


class TestSort:
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
def test_sort_dtype(self, dtype):
def test_basic(self, kind, dtype):
a = numpy.random.uniform(-5, 5, 10)
np_array = numpy.array(a, dtype=dtype)
dp_array = dpnp.array(np_array)

result = dpnp.sort(dp_array)
result = dpnp.sort(dp_array, kind=kind)
expected = numpy.sort(np_array)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
@pytest.mark.parametrize("dtype", get_complex_dtypes())
def test_sort_complex(self, dtype):
def test_complex(self, kind, dtype):
a = numpy.random.uniform(-5, 5, 10)
b = numpy.random.uniform(-5, 5, 10)
np_array = numpy.array(a + b * 1j, dtype=dtype)
dp_array = dpnp.array(np_array)

result = dpnp.sort(dp_array)
expected = numpy.sort(np_array)
assert_dtype_allclose(result, expected)
if kind == "radixsort":
assert_raises(ValueError, dpnp.argsort, dp_array, kind=kind)
else:
result = dpnp.sort(dp_array, kind=kind)
expected = numpy.sort(np_array)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
def test_sort_axis(self, axis):
def test_axis(self, axis):
a = numpy.random.uniform(-10, 10, 36)
np_array = numpy.array(a).reshape(3, 4, 3)
dp_array = dpnp.array(np_array)
Expand All @@ -299,7 +310,7 @@ def test_sort_axis(self, axis):

@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize("axis", [-2, -1, 0, 1])
def test_sort_ndarray(self, dtype, axis):
def test_ndarray(self, dtype, axis):
a = numpy.random.uniform(-10, 10, 12)
np_array = numpy.array(a, dtype=dtype).reshape(6, 2)
dp_array = dpnp.array(np_array)
Expand All @@ -308,8 +319,9 @@ def test_sort_ndarray(self, dtype, axis):
np_array.sort(axis=axis)
assert_dtype_allclose(dp_array, np_array)

@pytest.mark.parametrize("kind", [None, "stable"])
def test_sort_kind(self, kind):
# this test validates that all different options of kind in dpnp are stable
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
def test_kind(self, kind):
np_array = numpy.repeat(numpy.arange(10), 10)
dp_array = dpnp.array(np_array)

Expand All @@ -320,21 +332,21 @@ def test_sort_kind(self, kind):
# `stable` keyword is supported in numpy 2.0 and above
@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("stable", [None, False, True])
def test_sort_stable(self, stable):
def test_stable(self, stable):
np_array = numpy.repeat(numpy.arange(10), 10)
dp_array = dpnp.array(np_array)

result = dpnp.sort(dp_array, stable="stable")
expected = numpy.sort(np_array, stable=True)
assert_dtype_allclose(result, expected)

def test_sort_ndarray_axis_none(self):
def test_ndarray_axis_none(self):
a = numpy.random.uniform(-10, 10, 12)
dp_array = dpnp.array(a).reshape(6, 2)
with pytest.raises(TypeError):
dp_array.sort(axis=None)

def test_sort_zero_dim(self):
def test_zero_dim(self):
np_array = numpy.array(2.5)
dp_array = dpnp.array(np_array)

Expand All @@ -347,15 +359,20 @@ def test_sort_zero_dim(self):
expected = numpy.sort(np_array, axis=None)
assert_dtype_allclose(result, expected)

def test_sort_notimplemented(self):
def test_error(self):
dp_array = dpnp.arange(10)

with pytest.raises(NotImplementedError):
# quicksort is currently not supported
with pytest.raises(ValueError):
dpnp.sort(dp_array, kind="quicksort")

with pytest.raises(NotImplementedError):
dpnp.sort(dp_array, order=["age"])

# both kind and stable are given
with pytest.raises(ValueError):
dpnp.sort(dp_array, kind="mergesort", stable=True)


class TestSortComplex:
@pytest.mark.parametrize(
Expand Down
Loading