diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 068a667b2603..52726622f076 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -125,6 +125,8 @@ class _LGBMRegressorBase: # type: ignore try: from dask import delayed from dask.array import Array as dask_Array + from dask.array import from_delayed as dask_array_from_delayed + from dask.bag import from_delayed as dask_bag_from_delayed from dask.dataframe import DataFrame as dask_DataFrame from dask.dataframe import Series as dask_Series from dask.distributed import Client, default_client, wait @@ -132,6 +134,8 @@ class _LGBMRegressorBase: # type: ignore except ImportError: DASK_INSTALLED = False + dask_array_from_delayed = None + dask_bag_from_delayed = None delayed = None default_client = None wait = None diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 39919d96ad58..107cf218d861 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -10,6 +10,7 @@ from collections import defaultdict from copy import deepcopy from enum import Enum, auto +from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from urllib.parse import urlparse @@ -18,7 +19,8 @@ from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, - dask_Array, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait) + dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, + default_client, delayed, pd_DataFrame, pd_Series, wait) from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict) @@ -842,7 +844,7 @@ def _predict( pred_contrib: bool = False, 
dtype: _PredictionDtype = np.float32, **kwargs: Any -) -> dask_Array: +) -> Union[dask_Array, List[dask_Array]]: """Inner predict routine. Parameters @@ -870,7 +872,7 @@ def _predict( The predicted values. X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] If ``pred_leaf=True``, the predicted leaf of every tree for each sample. - X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] + X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1] If ``pred_contrib=True``, the feature contributions for each sample. """ if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): @@ -886,6 +888,74 @@ def _predict( **kwargs ).values elif isinstance(data, dask_Array): + # for multi-class classification with sparse matrices, pred_contrib predictions + # are returned as a list of sparse matrices (one per class) + num_classes = model._n_classes or -1 + + if ( + num_classes > 2 + and pred_contrib + and isinstance(data._meta, ss.spmatrix) + ): + + predict_function = partial( + _predict_part, + model=model, + raw_score=False, + pred_proba=pred_proba, + pred_leaf=False, + pred_contrib=True, + **kwargs + ) + + delayed_chunks = data.to_delayed() + bag = dask_bag_from_delayed(delayed_chunks[:, 0]) + + @delayed + def _extract(items: List[Any], i: int) -> Any: + return items[i] + + preds = bag.map_partitions(predict_function) + + # pred_contrib output will have one column per feature, + # plus one more for the base value + num_cols = model.n_features_ + 1 + + nrows_per_chunk = data.chunks[0] + out = [[] for _ in range(num_classes)] + + # need to tell Dask the expected type and shape of individual preds + pred_meta = data._meta + + for j, partition in enumerate(preds.to_delayed()): + for i in 
range(num_classes): + part = dask_array_from_delayed( + value=_extract(partition, i), + shape=(nrows_per_chunk[j], num_cols), + meta=pred_meta + ) + out[i].append(part) + + # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix + # the code below is used instead to ensure that the sparse type is preserved during concatenation + if isinstance(pred_meta, ss.csr_matrix): + concat_fn = partial(ss.vstack, format='csr') + elif isinstance(pred_meta, ss.csc_matrix): + concat_fn = partial(ss.vstack, format='csc') + else: + concat_fn = ss.vstack + + # At this point, `out` is a list of lists of delayeds (each of which points to a matrix). + # Concatenate them to return a list of Dask Arrays. + for i in range(num_classes): + out[i] = dask_array_from_delayed( + value=delayed(concat_fn)(out[i]), + shape=(data.shape[0], num_cols), + meta=pred_meta + ) + + return out + return data.map_blocks( _predict_part, model=model, @@ -1140,7 +1210,7 @@ def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: output_name="predicted_result", predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", - X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]" + X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" ) def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: @@ -1158,7 +1228,7 @@ def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: output_name="predicted_probability", predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, 
n_trees] or shape = [n_samples, n_trees * n_classes]", - X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]" + X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" ) def to_local(self) -> LGBMClassifier: diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 98f738ddcb30..2c0d4089c990 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -28,7 +28,7 @@ from dask.array.utils import assert_eq from dask.distributed import Client, LocalCluster, default_client, wait from pkg_resources import parse_version -from scipy.sparse import csr_matrix +from scipy.sparse import csc_matrix, csr_matrix from scipy.stats import spearmanr from sklearn import __version__ as sk_version from sklearn.datasets import make_blobs, make_regression @@ -198,6 +198,12 @@ def _create_data(objective, n_samples=1_000, output='array', chunk_size=500, **k dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csr_matrix) dy = da.from_array(y, chunks=chunk_size) dw = da.from_array(weights, chunk_size) + X = csr_matrix(X) + elif output == 'scipy_csc_matrix': + dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csc_matrix) + dy = da.from_array(y, chunks=chunk_size) + dw = da.from_array(weights, chunk_size) + X = csc_matrix(X) else: raise ValueError(f"Unknown output type '{output}'") @@ -344,7 +350,7 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster): assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' -@pytest.mark.parametrize('output', data_output) +@pytest.mark.parametrize('output', data_output + ['scipy_csc_matrix']) @pytest.mark.parametrize('task', ['binary-classification', 
'multiclass-classification']) def test_classifier_pred_contrib(output, task, cluster): with Client(cluster) as client: @@ -365,14 +371,52 @@ def test_classifier_pred_contrib(output, task, cluster): **params ) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) - preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute() + preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True) local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True) - if output == 'scipy_csr_matrix': - preds_with_contrib = np.array(preds_with_contrib.todense()) + # shape depends on whether it is binary or multiclass classification + num_features = dask_classifier.n_features_ + num_classes = dask_classifier.n_classes_ + if num_classes == 2: + expected_num_cols = num_features + 1 + else: + expected_num_cols = (num_features + 1) * num_classes + + # in the special case of multi-class classification using scipy sparse matrices, + # the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class) + # + # since that case is so different than all other cases, check the relevant things here + # and then return early + if output.startswith('scipy') and task == 'multiclass-classification': + if output == 'scipy_csr_matrix': + expected_type = csr_matrix + elif output == 'scipy_csc_matrix': + expected_type = csc_matrix + else: + raise ValueError(f"Unrecognized output type: {output}") + assert isinstance(preds_with_contrib, list) + assert all(isinstance(arr, da.Array) for arr in preds_with_contrib) + assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib) + assert len(preds_with_contrib) == num_classes + assert len(preds_with_contrib) == len(local_preds_with_contrib) + for i in range(num_classes): + computed_preds = preds_with_contrib[i].compute() + assert isinstance(computed_preds, expected_type) + assert 
computed_preds.shape[1] == num_features + 1 # one column per feature, plus the base value + assert computed_preds.shape == local_preds_with_contrib[i].shape + assert len(np.unique(computed_preds[:, -1].toarray())) == 1 # densify so np.unique sees the values + # raw scores will probably be different, but at least check that all predicted classes are the same + pred_classes = np.argmax(computed_preds.toarray(), axis=1) + local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1) + np.testing.assert_array_equal(pred_classes, local_pred_classes) + return + + preds_with_contrib = preds_with_contrib.compute() + if output.startswith('scipy'): + preds_with_contrib = preds_with_contrib.toarray() # be sure LightGBM actually used at least one categorical column, # and that it was correctly treated as a categorical feature @@ -386,14 +430,6 @@ def test_classifier_pred_contrib(output, task, cluster): assert node_uses_cat_col.sum() > 0 assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' - # shape depends on whether it is binary or multiclass classification - num_features = dask_classifier.n_features_ - num_classes = dask_classifier.n_classes_ - if num_classes == 2: - expected_num_cols = num_features + 1 - else: - expected_num_cols = (num_features + 1) * num_classes - # * shape depends on whether it is binary or multiclass classification # * matrix for binary classification is of the form [feature_contrib, base_value], # for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.] 
@@ -403,7 +439,7 @@ def test_classifier_pred_contrib(output, task, cluster): assert preds_with_contrib.shape == local_preds_with_contrib.shape if num_classes == 2: - assert len(np.unique(preds_with_contrib[:, num_features]) == 1) + assert len(np.unique(preds_with_contrib[:, num_features])) == 1 else: for i in range(num_classes): base_value_col = num_features * (i + 1) + i @@ -585,7 +621,7 @@ def test_regressor_pred_contrib(output, cluster): local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True) if output == "scipy_csr_matrix": - preds_with_contrib = np.array(preds_with_contrib.todense()) + preds_with_contrib = preds_with_contrib.toarray() # contrib outputs for distributed training are different than from local training, so we can just test # that the output has the right shape and base values are in the right position