Skip to content

Commit

Permalink
[dask] Support pred_contrib in Dask predict() methods (fixes #3713) (#…
Browse files Browse the repository at this point in the history
…3774)

* adding pred_contrib support

* add tests

* linting

* remove raw_score

* add pred kwargs

* faster tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* changes to tests

* Update tests/python_package_test/test_dask.py

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
jameslamb and StrikerRUS authored Jan 22, 2021
1 parent 3c7e7e0 commit d9a96c9
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 13 deletions.
61 changes: 48 additions & 13 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,47 +280,82 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
return results[0]


def _predict_part(part, model, proba, **kwargs):
def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
data = part.values if isinstance(part, pd.DataFrame) else part

if data.shape[0] == 0:
result = np.array([])
elif proba:
result = model.predict_proba(data, **kwargs)
elif pred_proba:
result = model.predict_proba(
data,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
**kwargs
)
else:
result = model.predict(data, **kwargs)
result = model.predict(
data,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
**kwargs
)

if isinstance(part, pd.DataFrame):
if proba:
if pred_proba or pred_contrib:
result = pd.DataFrame(result, index=part.index)
else:
result = pd.Series(result, index=part.index, name='predictions')

return result


def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
dtype=np.float32, **kwargs):
"""Inner predict routine.
Parameters
----------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
proba : bool
Should method return results of predict_proba (proba == True) or predict (proba == False).
pred_proba : bool, optional (default=False)
Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Whether to predict feature contributions.
dtype : np.dtype
Dtype of the output.
kwargs : other parameters passed to predict or predict_proba method
kwargs : dict
Other parameters passed to ``predict`` or ``predict_proba`` method.
"""
if isinstance(data, dd._Frame):
return data.map_partitions(_predict_part, model=model, proba=proba, **kwargs).values
return data.map_partitions(
_predict_part,
model=model,
raw_score=raw_score,
pred_proba=pred_proba,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
**kwargs
).values
elif isinstance(data, da.Array):
if proba:
if pred_proba:
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
else:
kwargs['drop_axis'] = 1
return data.map_blocks(_predict_part, model=model, proba=proba, dtype=dtype, **kwargs)
return data.map_blocks(
_predict_part,
model=model,
raw_score=raw_score,
pred_proba=pred_proba,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
dtype=dtype,
**kwargs
)
else:
raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data)))

Expand Down Expand Up @@ -370,7 +405,7 @@ def predict(self, X, **kwargs):

def predict_proba(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(self.to_local(), X, proba=True, **kwargs)
return _predict(self.to_local(), X, pred_proba=True, **kwargs)
predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__

def to_local(self):
Expand Down
80 changes: 80 additions & 0 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,55 @@ def test_classifier(output, centers, client, listen_port):
client.close()


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_pred_contrib(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port,
tree_learner='data',
n_estimators=10,
num_leaves=10
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute()

local_classifier = lightgbm.LGBMClassifier(
n_estimators=10,
num_leaves=10
)
local_classifier.fit(X, y, sample_weight=w)
local_preds_with_contrib = local_classifier.predict(X, pred_contrib=True)

if output == 'scipy_csr_matrix':
preds_with_contrib = np.array(preds_with_contrib.todense())

# shape depends on whether it is binary or multiclass classification
num_features = dask_classifier.n_features_
num_classes = dask_classifier.n_classes_
if num_classes == 2:
expected_num_cols = num_features + 1
else:
expected_num_cols = (num_features + 1) * num_classes

# * shape depends on whether it is binary or multiclass classification
# * matrix for binary classification is of the form [feature_contrib, base_value],
# for multi-class it's [feat_contrib_class1, base_value_class1, feat_contrib_class2, base_value_class2, etc.]
# * contrib outputs for distributed training are different than from local training, so we can just test
# that the output has the right shape and base values are in the right position
assert preds_with_contrib.shape[1] == expected_num_cols
assert preds_with_contrib.shape == local_preds_with_contrib.shape

if num_classes == 2:
assert len(np.unique(preds_with_contrib[:, num_features]) == 1)
else:
for i in range(num_classes):
base_value_col = num_features * (i + 1) + i
assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1)


def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')

Expand Down Expand Up @@ -315,6 +364,37 @@ def test_regressor(output, client, listen_port):
client.close()


@pytest.mark.parametrize('output', data_output)
def test_regressor_pred_contrib(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)

dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
local_listen_port=listen_port,
tree_learner='data',
n_estimators=10,
num_leaves=10
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute()

local_regressor = lightgbm.LGBMRegressor(
n_estimators=10,
num_leaves=10
)
local_regressor.fit(X, y, sample_weight=w)
local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)

if output == "scipy_csr_matrix":
preds_with_contrib = np.array(preds_with_contrib.todense())

# contrib outputs for distributed training are different than from local training, so we can just test
# that the output has the right shape and base values are in the right position
num_features = dX.shape[1]
assert preds_with_contrib.shape[1] == num_features + 1
assert preds_with_contrib.shape == local_preds_with_contrib.shape


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, listen_port, alpha):
Expand Down

0 comments on commit d9a96c9

Please sign in to comment.