Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] add type hints for custom objective and metric functions in scikit-learn interface #4547

Merged
merged 8 commits into from
Nov 15, 2021
22 changes: 11 additions & 11 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
Expand All @@ -21,8 +21,8 @@
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
Expand Down Expand Up @@ -400,7 +400,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> LGBMModel:
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def _lgb_dask_fit(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1348,7 +1348,7 @@ def fit(
eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
Expand Down Expand Up @@ -1446,7 +1446,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[Callable, str]] = None,
objective: Optional[str] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down Expand Up @@ -1516,7 +1516,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
Expand Down
38 changes: 33 additions & 5 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -11,14 +11,42 @@
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
pd_DataFrame)
pd_DataFrame, pd_Series)
from .engine import train

_ArrayLike = Union[List, np.ndarray, pd_Series]
_EvalResultType = Tuple[str, float, bool]

_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[_ArrayLike, _ArrayLike]
],
]
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
],
]


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class.

This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -107,7 +135,7 @@ def __call__(self, preds, dataset):
class _EvalFunctionWrapper:
"""Proxy class for evaluation function."""

def __init__(self, func):
def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class.

This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
Expand Down Expand Up @@ -358,7 +386,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[Union[str, Callable]] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
Expand Down