Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] add type hints for custom objective and metric functions in scikit-learn interface #4547

Merged
merged 8 commits into from
Nov 15, 2021
23 changes: 12 additions & 11 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
Expand All @@ -21,8 +21,9 @@
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit,
_lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
Expand Down Expand Up @@ -400,7 +401,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> LGBMModel:
Expand Down Expand Up @@ -1029,7 +1030,7 @@ def _lgb_dask_fit(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down Expand Up @@ -1096,7 +1097,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1165,7 +1166,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
Expand Down Expand Up @@ -1281,7 +1282,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1348,7 +1349,7 @@ def fit(
eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
Expand Down Expand Up @@ -1446,7 +1447,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1516,7 +1517,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down
39 changes: 34 additions & 5 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,52 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.sparse as ss

from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .callback import log_evaluation, record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
pd_DataFrame)
pd_DataFrame, pd_Series)
from .engine import train

_ArrayLike = Union[List, np.ndarray, pd_Series]
_EvalResultType = Tuple[str, float, bool]

_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[_ArrayLike, np.ndarray],
Tuple[np.ndarray, np.ndarray]
],
Callable[
[_ArrayLike, np.ndarray, np.ndarray],
shiyu1994 marked this conversation as resolved.
Show resolved Hide resolved
Tuple[np.ndarray, np.ndarray]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[_ArrayLike, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[_ArrayLike, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[_ArrayLike, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
]


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class.

This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -107,7 +136,7 @@ def __call__(self, preds, dataset):
class _EvalFunctionWrapper:
"""Proxy class for evaluation function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class.

This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -358,7 +387,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, Callable]] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down