Skip to content

Commit

Permalink
FIX: accuracy and zero_loss support for multilabel with Array API (
Browse files Browse the repository at this point in the history
…#29336)

Co-authored-by: Omar Salman <[email protected]>
Co-authored-by: Omar Salman <[email protected]>
  • Loading branch information
3 people authored and jeremiedbb committed Jul 2, 2024
1 parent 99d8a32 commit 851c0d6
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
6 changes: 3 additions & 3 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ Changelog
instead of implicitly converting those inputs as regular NumPy arrays.
:pr:`29119` by :user:`Olivier Grisel`.

- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
for Array API dispatch with multilabel inputs.
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>`.
- |Fix| Fix a regression in :func:`metrics.accuracy_score` and in :func:`metrics.zero_one_loss`
causing an error for Array API dispatch with multilabel inputs.
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>` and :pr:`29336` by :user:`Edoardo Abati <EdAbati>`.

:mod:`sklearn.model_selection`
..............................
Expand Down
19 changes: 16 additions & 3 deletions sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
)
from ..utils._array_api import (
_average,
_count_nonzero,
_is_numpy_namespace,
_union1d,
get_namespace,
get_namespace_and_device,
Expand Down Expand Up @@ -97,6 +99,7 @@ def _check_targets(y_true, y_pred):
y_pred : array or indicator matrix
"""
xp, _ = get_namespace(y_true, y_pred)
check_consistent_length(y_true, y_pred)
type_true = type_of_target(y_true, input_name="y_true")
type_pred = type_of_target(y_pred, input_name="y_pred")
Expand Down Expand Up @@ -142,8 +145,13 @@ def _check_targets(y_true, y_pred):
y_type = "multiclass"

if y_type.startswith("multilabel"):
y_true = csr_matrix(y_true)
y_pred = csr_matrix(y_pred)
if _is_numpy_namespace(xp):
# XXX: do we really want to sparse-encode multilabel indicators when
# they are passed as dense arrays? This is not possible for array
# API inputs in general, hence we only do it for NumPy inputs. But even
# for NumPy the usefulness is questionable.
y_true = csr_matrix(y_true)
y_pred = csr_matrix(y_pred)
y_type = "multilabel-indicator"

return y_type, y_true, y_pred
Expand Down Expand Up @@ -223,7 +231,12 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
if y_type.startswith("multilabel"):
differing_labels = count_nonzero(y_true - y_pred, axis=1)
if _is_numpy_namespace(xp):
differing_labels = count_nonzero(y_true - y_pred, axis=1)
else:
differing_labels = _count_nonzero(
y_true - y_pred, xp=xp, device=device, axis=1
)
score = xp.asarray(differing_labels == 0, device=device)
else:
score = y_true == y_pred
Expand Down
17 changes: 17 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,3 +841,20 @@ def indexing_dtype(xp):
# TODO: once sufficiently adopted, we might want to instead rely on the
# newer inspection API: https://github.com/data-apis/array-api/issues/640
return xp.asarray(0).dtype


def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
"""A variant of `sklearn.utils.sparsefuncs.count_nonzero` for the Array API.
It only supports 2D arrays.
"""
assert X.ndim == 2

weights = xp.ones_like(X, device=device)
if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, device=device)
sample_weight = xp.reshape(sample_weight, (sample_weight.shape[0], 1))
weights = xp.astype(weights, sample_weight.dtype) * sample_weight

zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)
37 changes: 36 additions & 1 deletion sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_atol_for_type,
_average,
_convert_to_numpy,
_count_nonzero,
_estimator_with_converted_arrays,
_is_numpy_namespace,
_nanmax,
Expand All @@ -30,7 +31,7 @@
_array_api_for_tests,
skip_if_array_api_compat_not_configured,
)
from sklearn.utils.fixes import _IS_32BIT
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS


@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
Expand Down Expand Up @@ -530,3 +531,37 @@ def test_get_namespace_and_device():
assert namespace is xp_torch
assert is_array_api
assert device == some_torch_tensor.device


@pytest.mark.parametrize(
    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
def test_count_nonzero(
    array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
):
    """Check `_count_nonzero` against the sparse `count_nonzero` reference.

    The same 2D integer data is counted once through the sparse helper (on a
    CSR container) and once through the Array API helper. Both results must
    match, and the Array API result must report the same device as its input.
    """
    from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero

    xp = _array_api_for_tests(array_namespace, device)

    data = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])

    # Unweighted unless the parametrization asks for int or float weights.
    weights = None
    if sample_weight_type == "float":
        weights = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
    elif sample_weight_type == "int":
        weights = numpy.asarray([1, 2, 2, 3, 1])

    # Reference values come from the sparse implementation.
    reference = sparse_count_nonzero(
        csr_container(data), axis=axis, sample_weight=weights
    )

    data_xp = xp.asarray(data, device=device)
    with config_context(array_api_dispatch=True):
        counted = _count_nonzero(
            data_xp, xp=xp, device=device, axis=axis, sample_weight=weights
        )

    assert_allclose(_convert_to_numpy(counted, xp=xp), reference)
    # Namespaces without a `device` attribute compare as None == None.
    assert getattr(data_xp, "device", None) == getattr(counted, "device", None)

0 comments on commit 851c0d6

Please sign in to comment.