Skip to content

Commit

Permalink
[python-packages] [docs] add type hints and define 'array-like' for X…
Browse files Browse the repository at this point in the history
…, y, group in scikit-learn interface (microsoft#5757)
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Jun 5, 2023
1 parent 6a3ed13 commit 352936b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 31 deletions.
3 changes: 2 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@
List[np.ndarray]
]
_LGBM_LabelType = Union[
list,
List[float],
List[int],
np.ndarray,
pd_Series,
pd_DataFrame
Expand Down
46 changes: 27 additions & 19 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.sparse

from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_BoosterBestScoreType,
_LGBM_CategoricalFeatureConfiguration, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration,
_log_warning)
_LGBM_GroupType, _LGBM_LabelType, _log_warning)
from .callback import _EvalResultDict, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
Expand All @@ -24,6 +25,13 @@
'LGBMRegressor',
]

# Container types accepted for the feature matrix ``X`` by the scikit-learn
# estimators in this module; used to annotate ``X`` in ``fit``, ``predict``
# and ``predict_proba``.
_LGBM_ScikitMatrixLike = Union[
dt_DataTable,
List[Union[List[float], List[int]]],
np.ndarray,
pd_DataFrame,
scipy.sparse.spmatrix
]
_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Expand Down Expand Up @@ -697,11 +705,11 @@ def _process_n_jobs(self, n_jobs: Optional[int]) -> int:

def fit(
self,
X,
y,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
sample_weight=None,
init_score=None,
group=None,
group: Optional[_LGBM_GroupType] = None,
eval_set=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
Expand Down Expand Up @@ -829,19 +837,19 @@ def _get_meta_data(collection, name, i):
return self

fit.__doc__ = _lgbmmodel_doc_fit.format(
X_shape="array-like or sparse matrix of shape = [n_samples, n_features]",
y_shape="array-like of shape = [n_samples]",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]",
sample_weight_shape="array-like of shape = [n_samples] or None, optional (default=None)",
init_score_shape="array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
group_shape="array-like or None, optional (default=None)",
group_shape="numpy array, pandas Series, list of int or float, or None, optional (default=None)",
eval_sample_weight_shape="list of array, or None, optional (default=None)",
eval_init_score_shape="list of array, or None, optional (default=None)",
eval_group_shape="list of array, or None, optional (default=None)"
) + "\n\n" + _lgbmmodel_doc_custom_eval_note

def predict(
self,
X,
X: _LGBM_ScikitMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
Expand Down Expand Up @@ -889,7 +897,7 @@ def predict(

predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.",
X_shape="array-like or sparse matrix of shape = [n_samples, n_features]",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
output_name="predicted_result",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
Expand Down Expand Up @@ -993,8 +1001,8 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):

def fit( # type: ignore[override]
self,
X,
y,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
sample_weight=None,
init_score=None,
eval_set=None,
Expand Down Expand Up @@ -1039,8 +1047,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):

def fit( # type: ignore[override]
self,
X,
y,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
sample_weight=None,
init_score=None,
eval_set=None,
Expand Down Expand Up @@ -1127,7 +1135,7 @@ def fit( # type: ignore[override]

def predict(
self,
X,
X: _LGBM_ScikitMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
Expand Down Expand Up @@ -1157,7 +1165,7 @@ def predict(

def predict_proba(
self,
X,
X: _LGBM_ScikitMatrixLike,
raw_score: bool = False,
start_iteration: int = 0,
num_iteration: Optional[int] = None,
Expand Down Expand Up @@ -1189,7 +1197,7 @@ def predict_proba(

predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted probability for each class for each sample.",
X_shape="array-like or sparse matrix of shape = [n_samples, n_features]",
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
output_name="predicted_probability",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
Expand Down Expand Up @@ -1223,11 +1231,11 @@ class LGBMRanker(LGBMModel):

def fit( # type: ignore[override]
self,
X,
y,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
sample_weight=None,
init_score=None,
group=None,
group: Optional[_LGBM_GroupType] = None,
eval_set=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
Expand Down
122 changes: 111 additions & 11 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,47 @@
import joblib
import numpy as np
import pytest
import scipy.sparse
from scipy.stats import spearmanr
from sklearn.base import clone
from sklearn.datasets import load_svmlight_file, make_blobs, make_multilabel_classification
from sklearn.ensemble import StackingClassifier, StackingRegressor
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier, MultiOutputRegressor, RegressorChain
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.validation import check_is_fitted

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame
from lightgbm.compat import DATATABLE_INSTALLED, PANDAS_INSTALLED, dt_DataTable, pd_DataFrame, pd_Series

from .utils import (load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking, make_synthetic_regression,
sklearn_multiclass_custom_objective, softmax)

# Endless counter that yields 0, -1, -2, ... (start 0, step -1).
decreasing_generator = itertools.count(0, -1)
# Maps a task name to the LightGBM scikit-learn estimator class used to model it.
task_to_model_factory = {
'ranking': lgb.LGBMRanker,
'classification': lgb.LGBMClassifier,
'binary-classification': lgb.LGBMClassifier,
'multiclass-classification': lgb.LGBMClassifier,
'regression': lgb.LGBMRegressor,
}


def _create_data(task):
def _create_data(task, n_samples=100, n_features=4):
if task == 'ranking':
X, y, g = make_ranking(n_features=4)
X, y, g = make_ranking(n_features=4, n_samples=n_samples)
g = np.bincount(g)
elif task == 'classification':
X, y = load_iris(return_X_y=True)
elif task.endswith('classification'):
if task == 'binary-classification':
centers = 2
elif task == 'multiclass-classification':
centers = 3
else:
ValueError(f"Unknown classification task '{task}'")
X, y = make_blobs(n_samples=n_samples, n_features=n_features, centers=centers, random_state=42)
g = None
elif task == 'regression':
X, y = make_synthetic_regression()
X, y = make_synthetic_regression(n_samples=n_samples, n_features=n_features)
g = None
return X, y, g

Expand Down Expand Up @@ -1268,7 +1277,7 @@ def test_sklearn_integration(estimator, check):
check(estimator)


@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification', 'ranking', 'regression'])
def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
pd = pytest.importorskip("pandas")
X, y, g = _create_data(task)
Expand Down Expand Up @@ -1378,9 +1387,9 @@ def test_default_n_jobs(tmp_path):


@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification', 'ranking', 'regression'])
def test_validate_features(task):
X, y, g = _create_data(task)
X, y, g = _create_data(task, n_features=4)
features = ['x1', 'x2', 'x3', 'x4']
df = pd_DataFrame(X, columns=features)
model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1)
Expand All @@ -1397,3 +1406,94 @@ def test_validate_features(task):

# check that disabling the check doesn't raise the error
model.predict(df2, validate_features=False)


@pytest.mark.parametrize('X_type', ['dt_DataTable', 'list2d', 'numpy', 'scipy_csc', 'scipy_csr', 'pd_DataFrame'])
@pytest.mark.parametrize('y_type', ['list1d', 'numpy', 'pd_Series', 'pd_DataFrame'])
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification', 'regression'])
def test_classification_and_regression_minimally_work_with_all_all_accepted_data_types(X_type, y_type, task):
    """Fit and predict with each combination of ``X`` / ``y`` container types.

    The generated data is converted to the requested container types, a small
    model is trained on it, and the in-sample score must clear a fixed threshold.
    """
    requested_types = (X_type, y_type)
    if not PANDAS_INSTALLED and any(name.startswith("pd_") for name in requested_types):
        pytest.skip('pandas is not installed')
    if not DATATABLE_INSTALLED and any(name.startswith("dt_") for name in requested_types):
        pytest.skip('datatable is not installed')

    X, y, _ = _create_data(task, n_samples=1_000)

    # One converter per supported feature-matrix container; 'numpy' is the identity.
    x_converters = {
        'dt_DataTable': dt_DataTable,
        'list2d': lambda data: data.tolist(),
        'numpy': lambda data: data,
        'scipy_csc': scipy.sparse.csc_matrix,
        'scipy_csr': scipy.sparse.csr_matrix,
        'pd_DataFrame': pd_DataFrame,
    }
    if X_type not in x_converters:
        raise ValueError(f"Unrecognized X_type: '{X_type}'")
    X = x_converters[X_type](X)

    # One converter per supported label container; 'numpy' is the identity.
    y_converters = {
        'list1d': lambda labels: labels.tolist(),
        'numpy': lambda labels: labels,
        'pd_DataFrame': pd_DataFrame,
        'pd_Series': pd_Series,
    }
    if y_type not in y_converters:
        raise ValueError(f"Unrecognized y_type: '{y_type}'")
    y = y_converters[y_type](y)

    estimator_cls = task_to_model_factory[task]
    model = estimator_cls(n_estimators=10, verbose=-1)
    model.fit(X, y)
    preds = model.predict(X)

    if task in ('binary-classification', 'multiclass-classification'):
        assert accuracy_score(y, preds) >= 0.99
    elif task == 'regression':
        assert r2_score(y, preds) > 0.86
    else:
        raise ValueError(f"Unrecognized task: '{task}'")


@pytest.mark.parametrize('X_type', ['dt_DataTable', 'list2d', 'numpy', 'scipy_csc', 'scipy_csr', 'pd_DataFrame'])
@pytest.mark.parametrize('y_type', ['list1d', 'numpy', 'pd_DataFrame', 'pd_Series'])
@pytest.mark.parametrize('g_type', ['list1d_float', 'list1d_int', 'numpy', 'pd_Series'])
def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type, g_type):
    """Fit and predict a ranker with each combination of ``X`` / ``y`` / ``group`` container types.

    The generated ranking data is converted to the requested container types,
    a small ranker is trained on it, and the in-sample Spearman correlation
    between predictions and labels must be at least 0.99.
    """
    requested_types = (X_type, y_type, g_type)
    if not PANDAS_INSTALLED and any(name.startswith("pd_") for name in requested_types):
        pytest.skip('pandas is not installed')
    if not DATATABLE_INSTALLED and any(name.startswith("dt_") for name in requested_types):
        pytest.skip('datatable is not installed')

    X, y, g = _create_data(task='ranking', n_samples=1_000)

    # One converter per supported feature-matrix container; 'numpy' is the identity.
    x_converters = {
        'dt_DataTable': dt_DataTable,
        'list2d': lambda data: data.tolist(),
        'numpy': lambda data: data,
        'scipy_csc': scipy.sparse.csc_matrix,
        'scipy_csr': scipy.sparse.csr_matrix,
        'pd_DataFrame': pd_DataFrame,
    }
    if X_type not in x_converters:
        raise ValueError(f"Unrecognized X_type: '{X_type}'")
    X = x_converters[X_type](X)

    # One converter per supported label container; 'numpy' is the identity.
    y_converters = {
        'list1d': lambda labels: labels.tolist(),
        'numpy': lambda labels: labels,
        'pd_DataFrame': pd_DataFrame,
        'pd_Series': pd_Series,
    }
    if y_type not in y_converters:
        raise ValueError(f"Unrecognized y_type: '{y_type}'")
    y = y_converters[y_type](y)

    # One converter per supported group container; 'numpy' is the identity.
    g_converters = {
        'list1d_float': lambda groups: groups.astype("float").tolist(),
        'list1d_int': lambda groups: groups.astype("int").tolist(),
        'numpy': lambda groups: groups,
        'pd_Series': pd_Series,
    }
    if g_type not in g_converters:
        raise ValueError(f"Unrecognized g_type: '{g_type}'")
    g = g_converters[g_type](g)

    ranker_cls = task_to_model_factory['ranking']
    model = ranker_cls(n_estimators=10, verbose=-1)
    model.fit(X, y, group=g)
    preds = model.predict(X)
    assert spearmanr(preds, y).correlation >= 0.99

0 comments on commit 352936b

Please sign in to comment.