[python-package] add more hints in sklearn.py #5460

Merged: 8 commits, Sep 12, 2022
34 changes: 20 additions & 14 deletions python-package/lightgbm/sklearn.py
@@ -2,7 +2,8 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

@@ -43,6 +44,11 @@
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
]
_LGBM_ScikitEvalMetricType = Union[
str,
_LGBM_ScikitCustomEvalFunction,
List[Union[str, _LGBM_ScikitCustomEvalFunction]]
]


class _ObjectiveFunctionWrapper:
@@ -686,16 +692,16 @@ def fit(
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None
):
"""Docstring is set after definition, using a template."""
params = self._process_params(stage="fit")
@@ -979,14 +985,14 @@ def fit(
sample_weight=None,
init_score=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_init_score=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
super().fit(
@@ -1025,15 +1031,15 @@ def fit(
sample_weight=None,
init_score=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_class_weight=None,
eval_init_score=None,
eval_metric=None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
@@ -1187,16 +1193,16 @@ def fit(
init_score=None,
group=None,
eval_set=None,
eval_names=None,
eval_names: Optional[List[str]] = None,
eval_sample_weight=None,
eval_init_score=None,
eval_group=None,
eval_metric=None,
eval_at=(1, 2, 3, 4, 5),
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
Collaborator:
I think this should be Sequence[int], given that, for example, range(5) is an iterable but is not a valid value for this parameter.
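(A quick side check, not part of the PR, on how range relates to these abstract types at runtime:)

```python
from collections.abc import Sequence

# range supports len(), indexing, and ordered iteration, so it actually
# satisfies the Sequence protocol; a Sequence[int] hint would therefore
# still admit range objects, which motivates an even stricter hint below.
print(isinstance(range(5), Sequence))  # True
```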

jameslamb (Collaborator, Author), Sep 10, 2022:
🤩 Excellent point. I was just following the docstring and didn't consider this. Thanks very much for noting it!!!

I tried the following from the root of the repo, just to see what would happen:

Sample code:
from pathlib import Path
import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_svmlight_file

rank_example_dir = Path('examples/lambdarank')
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
X_test, y_test = load_svmlight_file(str(rank_example_dir / 'rank.test'))
q_train = np.loadtxt(str(rank_example_dir / 'rank.train.query'))
q_test = np.loadtxt(str(rank_example_dir / 'rank.test.query'))
gbm = lgb.LGBMRanker(n_estimators=10)
gbm.fit(
    X_train,
    y_train,
    group=q_train,
    eval_set=[(X_test, y_test)],
    eval_group=[q_test],
    eval_at=range(3),
    callbacks=[
        lgb.early_stopping(10),
        lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))
    ]
)

And you're right...passing a range for eval_at causes a failure when serializing the parameters to string to pass them through the C API functions.

  File ".../site-packages/lightgbm/basic.py", line 326, in param_dict_to_str
    raise TypeError(f'Unknown type of parameter:{key}, got:{type(val).__name__}')
TypeError: Unknown type of parameter:eval_at, got:range
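(A caller holding a range can sidestep this error by materializing it into a concrete list before passing it; a minimal illustration, with the lightgbm fit call itself omitted:)

```python
# eval_at=range(3) fails parameter serialization; a plain list works.
# Note that range(3) would also include position 0, while eval_at positions
# start at 1 (the default is (1, 2, 3, 4, 5)), so range(1, 4) is the
# likelier intent.
eval_at = list(range(1, 4))
print(eval_at)  # [1, 2, 3]
```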

Given that, I think the hint here should be even stricter than typing.Sequence. Since this keyword argument is passed directly through to params and there's no other code in LightGBM manipulating its value, I think it can only accept values that are valid for lightgbm.basic.param_dict_to_str().

For eval_at, I think that means only a list of ints or tuple of ints is valid. param_dict_to_str() supports list, tuple, and set, but set isn't appropriate for eval_at because sets don't have a defined ordering.

if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):

I just pushed 81c234f which:

  • sets the hint for eval_at to Union[List[int], Tuple[int, ...]]
  • replaces use of the word "iterable" in the relevant docstrings with "list or tuple of int"
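(The acceptance rule can be sketched by mirroring the isinstance check quoted above; the helper name is hypothetical and the is_numpy_1d_array branch is omitted for brevity:)

```python
def _is_serializable_collection(val):
    # Mirrors isinstance(val, (list, tuple, set)) from param_dict_to_str;
    # anything else falls through to the TypeError shown earlier.
    return isinstance(val, (list, tuple, set))

# list and tuple forms of eval_at pass the check; range does not.
print(_is_serializable_collection([1, 2, 3]))    # True
print(_is_serializable_collection((1, 2, 3)))    # True
print(_is_serializable_collection(range(1, 4)))  # False
```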

Collaborator Author:
@jmoralez I won't merge this until you have a chance to respond, since what I did here is slightly different than what you suggested.

feature_name='auto',
categorical_feature='auto',
callbacks=None,
init_model=None
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None
):
"""Docstring is inherited from the LGBMModel."""
# check group data