Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Make output of feature contribution predictions for sparse matrices match those from sklearn estimators (fixes #3881) #4378

Merged
merged 39 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a397d23
test_classifier working
jameslamb Jun 11, 2021
b90040f
adding tests
jameslamb Jun 13, 2021
c9032a0
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jun 13, 2021
2a09151
docs
jameslamb Jun 13, 2021
7437274
tests
jameslamb Jun 13, 2021
4e40e97
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jun 14, 2021
058188b
revert unnecessary changes in tests
jameslamb Jun 14, 2021
8e7df9a
test output type
jameslamb Jun 14, 2021
13b9c3b
linting
jameslamb Jun 15, 2021
f46da71
linting
jameslamb Jun 15, 2021
e6072bf
use from_delayed() instead
jameslamb Jun 16, 2021
8378cee
docstring pycodestyle is happy with
jameslamb Jun 16, 2021
f86220d
isort
jameslamb Jun 16, 2021
8090966
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jun 16, 2021
a6557a8
put pytest skips back
jameslamb Jun 16, 2021
45ebf7d
respect sparse return type
jameslamb Jun 22, 2021
585fef4
fix doc
jameslamb Jun 22, 2021
c5a0483
remove unnecessary dask_array_concatenate()
jameslamb Jun 23, 2021
c9c2cfa
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jun 23, 2021
3ecfed5
merge main
jameslamb Jun 28, 2021
3361950
Apply suggestions from code review
jameslamb Jul 4, 2021
1a7c462
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 4, 2021
7766121
Apply suggestions from code review
jameslamb Jul 4, 2021
86e680e
Merge branch 'fix/dask-multiclass-sparse-predict' of github.com:micro…
jameslamb Jul 4, 2021
810a0b8
update predict_proba() docstring
jameslamb Jul 4, 2021
f79ab3c
remove unnecessary np.array()
jameslamb Jul 4, 2021
6a7da56
Update python-package/lightgbm/dask.py
jameslamb Jul 4, 2021
9a3f6f4
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 4, 2021
9b4b980
Merge branch 'fix/dask-multiclass-sparse-predict' of github.com:micro…
jameslamb Jul 4, 2021
95ae45f
fix assertion
jameslamb Jul 4, 2021
43c3a9c
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 4, 2021
a7d6d37
fix test use of len()
jameslamb Jul 4, 2021
ff49a58
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 5, 2021
831927b
restore np.array() in tests
jameslamb Jul 5, 2021
61f31dd
use np.asarray() instead
jameslamb Jul 5, 2021
466917e
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 5, 2021
57edeaa
use toarray()
jameslamb Jul 5, 2021
c4c8da8
Merge branch 'master' into fix/dask-multiclass-sparse-predict
jameslamb Jul 6, 2021
e4fed6f
remove empty functions in compat
jameslamb Jul 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class _LGBMRegressorBase: # type: ignore
try:
from dask import delayed
from dask.array import Array as dask_Array
from dask.array import from_delayed as dask_array_from_delayed
from dask.bag import from_delayed as dask_bag_from_delayed
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, default_client, wait
Expand All @@ -146,6 +148,14 @@ class dask_Array: # type: ignore

pass

def dask_array_from_delayed(*args, **kwargs): # type: ignore
"""Mock function for dask.array.from_delayed()."""
pass

def dask_bag_from_delayed(*args, **kwargs): # type: ignore
"""Mock function for dask.bag.from_delayed()."""
pass
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved

class dask_DataFrame: # type: ignore
"""Dummy class for dask.dataframe.DataFrame."""

Expand Down
80 changes: 75 additions & 5 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

Expand All @@ -18,7 +19,8 @@

from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait)
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)

Expand Down Expand Up @@ -842,7 +844,7 @@ def _predict(
pred_contrib: bool = False,
dtype: _PredictionDtype = np.float32,
**kwargs: Any
) -> dask_Array:
) -> Union[dask_Array, List[dask_Array]]:
"""Inner predict routine.

Parameters
Expand Down Expand Up @@ -870,7 +872,7 @@ def _predict(
The predicted values.
X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
If ``pred_contrib=True``, the feature contributions for each sample.
"""
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
Expand All @@ -886,6 +888,74 @@ def _predict(
**kwargs
).values
elif isinstance(data, dask_Array):
# for multi-class classification with sparse matrices, pred_contrib predictions
# are returned as a list of sparse matrices (one per class)
num_classes = model._n_classes or -1

if (
num_classes > 2
and pred_contrib
and isinstance(data._meta, ss.spmatrix)
):

predict_function = partial(
_predict_part,
model=model,
raw_score=False,
pred_proba=pred_proba,
pred_leaf=False,
pred_contrib=True,
**kwargs
)

delayed_chunks = data.to_delayed()
bag = dask_bag_from_delayed(delayed_chunks[:, 0])

@delayed
def _extract(items: List[Any], i: int) -> Any:
return items[i]

preds = bag.map_partitions(predict_function)

# pred_contrib output will have one column per feature,
# plus one more for the base value
num_cols = model.n_features_ + 1

nrows_per_chunk = data.chunks[0]
out = [[] for _ in range(num_classes)]

# need to tell Dask the expected type and shape of individual preds
pred_meta = data._meta

for j, partition in enumerate(preds.to_delayed()):
for i in range(num_classes):
part = dask_array_from_delayed(
value=_extract(partition, i),
shape=(nrows_per_chunk[j], num_cols),
meta=pred_meta
)
out[i].append(part)

# by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
# the code below is used instead to ensure that the sparse type is preserved during concatentation
if isinstance(pred_meta, ss.csr_matrix):
concat_fn = partial(ss.vstack, format='csr')
elif isinstance(pred_meta, ss.csc_matrix):
concat_fn = partial(ss.vstack, format='csc')
else:
concat_fn = ss.vstack

# At this point, `out` is a list of lists of delayeds (each of which points to a matrix).
# Concatenate them to return a list of Dask Arrays.
for i in range(num_classes):
out[i] = dask_array_from_delayed(
value=delayed(concat_fn)(out[i]),
shape=(data.shape[0], num_cols),
meta=pred_meta
)

return out

return data.map_blocks(
_predict_part,
model=model,
Expand Down Expand Up @@ -1140,7 +1210,7 @@ def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
output_name="predicted_result",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]"
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]"
)

def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
Expand All @@ -1158,7 +1228,7 @@ def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
output_name="predicted_probability",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]"
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]"
)

def to_local(self) -> LGBMClassifier:
Expand Down
66 changes: 51 additions & 15 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait
from pkg_resources import parse_version
from scipy.sparse import csr_matrix
from scipy.sparse import csc_matrix, csr_matrix
from scipy.stats import spearmanr
from sklearn import __version__ as sk_version
from sklearn.datasets import make_blobs, make_regression
Expand Down Expand Up @@ -198,6 +198,12 @@ def _create_data(objective, n_samples=1_000, output='array', chunk_size=500, **k
dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csr_matrix)
dy = da.from_array(y, chunks=chunk_size)
dw = da.from_array(weights, chunk_size)
X = csr_matrix(X)
elif output == 'scipy_csc_matrix':
dX = da.from_array(X, chunks=(chunk_size, X.shape[1])).map_blocks(csc_matrix)
dy = da.from_array(y, chunks=chunk_size)
dw = da.from_array(weights, chunk_size)
X = csc_matrix(X)
else:
raise ValueError(f"Unknown output type '{output}'")

Expand Down Expand Up @@ -344,7 +350,7 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster):
assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('output', data_output + ['scipy_csc_matrix'])
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier_pred_contrib(output, task, cluster):
with Client(cluster) as client:
Expand All @@ -365,14 +371,52 @@ def test_classifier_pred_contrib(output, task, cluster):
**params
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True)

local_classifier = lgb.LGBMClassifier(**params)
local_classifier.fit(X, y, sample_weight=w)
local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True)

if output == 'scipy_csr_matrix':
preds_with_contrib = np.array(preds_with_contrib.todense())
# shape depends on whether it is binary or multiclass classification
num_features = dask_classifier.n_features_
num_classes = dask_classifier.n_classes_
if num_classes == 2:
expected_num_cols = num_features + 1
else:
expected_num_cols = (num_features + 1) * num_classes

# in the special case of multi-class classification using scipy sparse matrices,
# the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class)
#
# since that case is so different than all other cases, check the relevant things here
# and then return early
if output.startswith('scipy') and task == 'multiclass-classification':
if output == 'scipy_csr_matrix':
expected_type = csr_matrix
elif output == 'scipy_csc_matrix':
expected_type = csc_matrix
else:
raise ValueError(f"Unrecognized output type: {output}")
assert isinstance(preds_with_contrib, list)
assert all(isinstance(arr, da.Array) for arr in preds_with_contrib)
assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib)
assert len(preds_with_contrib) == num_classes
assert len(preds_with_contrib) == len(local_preds_with_contrib)
for i in range(num_classes):
computed_preds = preds_with_contrib[i].compute()
assert isinstance(computed_preds, expected_type)
assert computed_preds.shape[1] == num_classes
assert computed_preds.shape == local_preds_with_contrib[i].shape
assert len(np.unique(computed_preds[:, -1])) == 1
# raw scores will probably be different, but at least check that all predicted classes are the same
pred_classes = np.argmax(computed_preds.toarray(), axis=1)
local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1)
np.testing.assert_array_equal(pred_classes, local_pred_classes)
return

preds_with_contrib = preds_with_contrib.compute()
if output.startswith('scipy'):
preds_with_contrib = preds_with_contrib.toarray()

# be sure LightGBM actually used at least one categorical column,
# and that it was correctly treated as a categorical feature
Expand All @@ -386,14 +430,6 @@ def test_classifier_pred_contrib(output, task, cluster):
assert node_uses_cat_col.sum() > 0
assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='

# shape depends on whether it is binary or multiclass classification
num_features = dask_classifier.n_features_
num_classes = dask_classifier.n_classes_
if num_classes == 2:
expected_num_cols = num_features + 1
else:
expected_num_cols = (num_features + 1) * num_classes

# * shape depends on whether it is binary or multiclass classification
# * matrix for binary classification is of the form [feature_contrib, base_value],
# for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.]
Expand All @@ -403,7 +439,7 @@ def test_classifier_pred_contrib(output, task, cluster):
assert preds_with_contrib.shape == local_preds_with_contrib.shape

if num_classes == 2:
assert len(np.unique(preds_with_contrib[:, num_features]) == 1)
assert len(np.unique(preds_with_contrib[:, num_features])) == 1
else:
for i in range(num_classes):
base_value_col = num_features * (i + 1) + i
Expand Down Expand Up @@ -585,7 +621,7 @@ def test_regressor_pred_contrib(output, cluster):
local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)

if output == "scipy_csr_matrix":
preds_with_contrib = np.array(preds_with_contrib.todense())
preds_with_contrib = preds_with_contrib.toarray()

# contrib outputs for distributed training are different than from local training, so we can just test
# that the output has the right shape and base values are in the right position
Expand Down