Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] fix Dask docstrings and mimic sklearn wrapper importing way #3855

Merged
merged 9 commits into from
Jan 26, 2021
3 changes: 1 addition & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
import scipy.sparse

from .compat import PANDAS_INSTALLED, DataFrame, Series, is_dtype_sparse, DataTable
from .compat import PANDAS_INSTALLED, DataFrame, Series, concat, is_dtype_sparse, DataTable
from .libpath import find_lib_path


Expand Down Expand Up @@ -2081,7 +2081,6 @@ def add_features_from(self, other):
if not PANDAS_INSTALLED:
raise LightGBMError("Cannot add features to DataFrame type of raw data "
"without pandas installed")
from pandas import concat
if isinstance(other.data, np.ndarray):
self.data = concat((self.data, DataFrame(other.data)),
axis=1, ignore_index=True)
Expand Down
25 changes: 21 additions & 4 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""pandas"""
try:
from pandas import Series, DataFrame
from pandas import Series, DataFrame, concat
from pandas.api.types import is_sparse as is_dtype_sparse
PANDAS_INSTALLED = True
except ImportError:
Expand All @@ -19,6 +19,7 @@ class DataFrame:

pass

concat = None
is_dtype_sparse = None

"""matplotlib"""
Expand Down Expand Up @@ -108,9 +109,25 @@ def _check_sample_weight(sample_weight, X, dtype=None):

"""dask"""
try:
from dask import array
from dask import dataframe
from dask.distributed import Client
from dask import delayed
from dask.array import Array as dask_Array
from dask.dataframe import _Frame as dask_Frame
from dask.distributed import Client, default_client, get_worker, wait
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False
delayed = None
Client = object
default_client = None
get_worker = None
wait = None

class dask_Array:
"""Dummy class for dask.array.Array."""

pass

class dask_Frame:
"""Dummy class for ddask.dataframe._Frame."""

pass
107 changes: 71 additions & 36 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import scipy.sparse as ss

from dask import array as da
from dask import dataframe as dd
from dask import delayed
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED
from .compat import (PANDAS_INSTALLED, DataFrame, Series, concat,
SKLEARN_INSTALLED,
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker


Expand All @@ -46,7 +42,7 @@ def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Itera

Returns
-------
result : int
port : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
Expand Down Expand Up @@ -81,7 +77,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
local_listen_port : int
First port to try when searching for open ports.

Expand Down Expand Up @@ -109,8 +105,8 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
def _concat(seq):
if isinstance(seq[0], np.ndarray):
return np.concatenate(seq, axis=0)
elif isinstance(seq[0], (pd.DataFrame, pd.Series)):
return pd.concat(seq, axis=0)
elif isinstance(seq[0], (DataFrame, Series)):
return concat(seq, axis=0)
Comment on lines +108 to +109
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with importing from compat, but could the names be changed on import to something like pd_DataFrame?

from .compat import DataFrame as pd_DataFrame
from .compat import Series as pd_Series

Since both pandas and dask have a DataFrame class, I think just calling this DataFrame makes the code difficult to read. I know that I personally will read this in the future and think "wait does that mean pandas or Dask DataFrame".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm totally agree! But I think we can make import aliases in compat.py and then import like from .compat import pd_DataFrame. Otherwise in case of identical names it will be confusing to have

from dask import DataFrame
from pandas import DataFrame

in compat.py.

I'm going to rename only Dask imports in this PR to not overcomplicate review. pandas will be done in a follow-up PR. Do you agree?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes sounds good, thank you

elif isinstance(seq[0], ss.spmatrix):
return ss.vstack(seq, format='csr')
else:
Expand Down Expand Up @@ -152,9 +148,9 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
try:
model = model_factory(**params)
if is_ranker:
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
model.fit(data, label, sample_weight=weight, group=group, **kwargs)
else:
model.fit(data, y=label, sample_weight=weight, **kwargs)
model.fit(data, label, sample_weight=weight, **kwargs)

finally:
_safe_call(_LIB.LGBM_NetworkFree())
Expand All @@ -178,13 +174,16 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group

Parameters
----------
client: dask.Client - client
X : dask array of shape = [n_samples, n_features]
client : dask.distributed.Client
Dask client.
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
y : dask array of shape = [n_samples]
label : dask array of shape = [n_samples]
The target values (class labels in classification, real numbers in regression).
params : dict
Parameters passed to constructor of the local underlying model.
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Class of the local underlying model.
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
group : array-like or None, optional (default=None)
Expand All @@ -193,6 +192,13 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
**kwargs
Other parameters passed to ``fit`` method of the local underlying model.

Returns
-------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Returns fitted underlying model.
"""
params = deepcopy(params)

Expand Down Expand Up @@ -298,7 +304,7 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group


def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
data = part.values if isinstance(part, pd.DataFrame) else part
data = part.values if isinstance(part, DataFrame) else part

if data.shape[0] == 0:
result = np.array([])
Expand All @@ -319,11 +325,11 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
**kwargs
)

if isinstance(part, pd.DataFrame):
if isinstance(part, DataFrame):
if pred_proba or pred_contrib:
result = pd.DataFrame(result, index=part.index)
result = DataFrame(result, index=part.index)
else:
result = pd.Series(result, index=part.index, name='predictions')
result = Series(result, index=part.index, name='predictions')

return result

Expand All @@ -335,20 +341,34 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
Parameters
----------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Fitted underlying model.
data : dask array of shape = [n_samples, n_features]
Input feature matrix.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
pred_proba : bool, optional (default=False)
Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Whether to predict feature contributions.
dtype : np.dtype
dtype : np.dtype, optional (default=np.float32)
Dtype of the output.
kwargs : dict
**kwargs
Other parameters passed to ``predict`` or ``predict_proba`` method.

Returns
-------
predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes]
The predicted values.
X_leaves : dask arrayof shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
If ``pred_contrib=True``, the feature contributions for each sample.
"""
if isinstance(data, dd._Frame):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if isinstance(data, dask_Frame):
return data.map_partitions(
_predict_part,
model=model,
Expand All @@ -358,7 +378,7 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
pred_contrib=pred_contrib,
**kwargs
).values
elif isinstance(data, da.Array):
elif isinstance(data, dask_Array):
if pred_proba:
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
else:
Expand All @@ -378,12 +398,9 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr


class _DaskLGBMModel:
def __init__(self):
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this check is being moved out of the constructor, then can you please put it in the _predict() function as well?

If someone tries to load a saved DaskLGBMClassifier from a pickle file (for example) and then use its .predict() method, I think we also want them to get an informative error about dask not being available. They won't get an ImportError on pickle.load() because of the magic of .compat.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! Will do.

if this check is being moved out of the constructor

I wish I could leave it in __init__(), but I realized that parent's constructor in _LGBMModel class is never called 🙁 .

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah! didn't think about that when reviewing the MRO change in #3822


def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
if client is None:
client = default_client()

Expand Down Expand Up @@ -422,7 +439,7 @@ def _copy_extra_params(source, dest):
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
model_factory=LGBMClassifier,
Expand All @@ -433,7 +450,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
**kwargs
)

fit.__doc__ = LGBMClassifier.fit.__doc__
_base_doc = LGBMClassifier.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
Expand Down Expand Up @@ -463,14 +485,15 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMClassifier
Local underlying model.
"""
return self._to_local(LGBMClassifier)


class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRegressor."""
"""Distributed version of lightgbm.LGBMRegressor."""

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
model_factory=LGBMRegressor,
Expand All @@ -481,7 +504,12 @@ def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
**kwargs
)

fit.__doc__ = LGBMRegressor.fit.__doc__
_base_doc = LGBMRegressor.fit.__doc__
_before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :')
fit.__doc__ = (_before_init_score
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
Expand All @@ -499,14 +527,15 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMRegressor
Local underlying model.
"""
return self._to_local(LGBMRegressor)


class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Docstring is inherited from the lightgbm.LGBMRanker."""
"""Distributed version of lightgbm.LGBMRanker."""

def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
Expand All @@ -521,7 +550,12 @@ def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client
**kwargs
)

fit.__doc__ = LGBMRanker.fit.__doc__
_base_doc = LGBMRanker.fit.__doc__
_before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :')
fit.__doc__ = (_before_eval_set
+ 'client : dask.distributed.Client or None, optional (default=None)\n'
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
Expand All @@ -535,5 +569,6 @@ def to_local(self):
Returns
-------
model : lightgbm.LGBMRanker
Local underlying model.
"""
return self._to_local(LGBMRanker)