Skip to content

Commit

Permalink
[python-package] add more hints in sklearn.py (#5460)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Sep 12, 2022
1 parent 8b105ce commit c3cf335
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
22 changes: 11 additions & 11 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit,
_lgbmmodel_doc_predict)
_LGBM_ScikitCustomObjectiveFunction, _LGBM_ScikitEvalMetricType, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
Expand Down Expand Up @@ -405,8 +405,8 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Optional[Union[List[int], Tuple[int]]] = None,
**kwargs: Any
) -> LGBMModel:
"""Inner train routine.
Expand Down Expand Up @@ -454,7 +454,7 @@ def _train(
If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
In either case, the ``metric`` from the Dask model parameters (or inferred from the objective) will be evaluated and used as well.
Default: 'l2' for DaskLGBMRegressor, 'binary(multi)_logloss' for DaskLGBMClassifier, 'ndcg' for DaskLGBMRanker.
eval_at : iterable of int, optional (default=None)
eval_at : list or tuple of int, optional (default=None)
The evaluation positions of the specified ranking metric.
**kwargs
Other parameters passed to ``fit`` method of the local underlying model.
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def _lgb_dask_fit(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
Expand Down Expand Up @@ -1163,7 +1163,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
Expand Down Expand Up @@ -1334,7 +1334,7 @@ def fit(
eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
Expand Down Expand Up @@ -1489,8 +1489,8 @@ def fit(
eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Union[List[int], Tuple[int]] = (1, 2, 3, 4, 5),
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
Expand Down Expand Up @@ -1527,7 +1527,7 @@ def fit(
+ _base_doc[_base_doc.find('eval_init_score :'):])

_base_doc = (_base_doc[:_base_doc.find('feature_name :')]
+ "eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))\n"
+ "eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))\n"
+ f"{' ':8}The evaluation positions of the specified metric.\n"
+ f"{' ':4}{_base_doc[_base_doc.find('feature_name :'):]}")

Expand Down
36 changes: 21 additions & 15 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -43,6 +44,11 @@
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
]
_LGBM_ScikitEvalMetricType = Union[
str,
_LGBM_ScikitCustomEvalFunction,
List[Union[str, _LGBM_ScikitCustomEvalFunction]]
]


class _ObjectiveFunctionWrapper:
Expand Down Expand Up @@ -686,16 +692,16 @@ def fit(
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None
):
"""Docstring is set after definition, using a template."""
params = self._process_params(stage="fit")
Expand Down Expand Up @@ -979,14 +985,14 @@ def fit(
sample_weight=None,
init_score=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_init_score=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
super().fit(
Expand Down Expand Up @@ -1025,15 +1031,15 @@ def fit(
sample_weight=None,
init_score=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
Expand Down Expand Up @@ -1187,16 +1193,16 @@ def fit(
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
eval_at=(1, 2, 3, 4, 5),
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Union[List[int], Tuple[int]] = (1, 2, 3, 4, 5),
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
# check group data
Expand Down Expand Up @@ -1240,6 +1246,6 @@ def fit(
+ _base_doc[_base_doc.find('eval_init_score :'):]) # type: ignore
_base_doc = fit.__doc__
_before_feature_name, _feature_name, _after_feature_name = _base_doc.partition('feature_name :')
fit.__doc__ = f"""{_before_feature_name}eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))
fit.__doc__ = f"""{_before_feature_name}eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))
The evaluation positions of the specified metric.
{_feature_name}{_after_feature_name}"""

0 comments on commit c3cf335

Please sign in to comment.