Skip to content

Commit

Permalink
[python-package] add type hints on Booster eval methods (#5433)
Browse files Browse the repository at this point in the history
* [python-package] add type hints on Booster eval methods

* remove unnecessary changes

* fix hints
  • Loading branch information
jameslamb authored Aug 25, 2022
1 parent 39eb041 commit 581d53c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 15 deletions.
43 changes: 36 additions & 7 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
from .libpath import find_lib_path

# One metric result produced by a custom eval function:
# (eval_name, eval_result, is_higher_better).
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
# One entry in the list returned by Booster.eval() / eval_train() / eval_valid():
# (dataset_name, eval_name, eval_result, is_higher_better).
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]

ZERO_THRESHOLD = 1e-35


Expand Down Expand Up @@ -2635,6 +2638,16 @@ def _dump_text(self, filename: Union[str, Path]) -> "Dataset":
[np.ndarray, Dataset],
Tuple[np.ndarray, np.ndarray]
]
# Accepted shapes for a custom evaluation function (``feval``):
# a callable taking (predictions, eval_dataset) and returning either a single
# result tuple or a list of result tuples.
_LGBM_CustomEvalFunction = Union[
    Callable[
        [np.ndarray, Dataset],
        _LGBM_EvalFunctionResultType
    ],
    Callable[
        [np.ndarray, Dataset],
        List[_LGBM_EvalFunctionResultType]
    ]
]


class Booster:
Expand Down Expand Up @@ -3273,7 +3286,12 @@ def lower_bound(self) -> float:
ctypes.byref(ret)))
return ret.value

def eval(self, data, name, feval=None):
def eval(
self,
data: Dataset,
name: str,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for data.
Parameters
Expand Down Expand Up @@ -3304,7 +3322,7 @@ def eval(self, data, name, feval=None):
Returns
-------
result : list
List with evaluation results.
List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
if not isinstance(data, Dataset):
raise TypeError("Can only eval for Dataset instance")
Expand All @@ -3323,7 +3341,10 @@ def eval(self, data, name, feval=None):

return self.__inner_eval(name, data_idx, feval)

def eval_train(self, feval=None):
def eval_train(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for training data.
Parameters
Expand All @@ -3350,11 +3371,14 @@ def eval_train(self, feval=None):
Returns
-------
result : list
List with evaluation results.
List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
return self.__inner_eval(self._train_data_name, 0, feval)

def eval_valid(self, feval=None):
def eval_valid(
self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for validation data.
Parameters
Expand All @@ -3381,7 +3405,7 @@ def eval_valid(self, feval=None):
Returns
-------
result : list
List with evaluation results.
List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
"""
return [item for i in range(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
Expand Down Expand Up @@ -3987,7 +4011,12 @@ def add(root):
else:
return hist, bin_edges

def __inner_eval(self, data_name, data_idx, feval=None):
def __inner_eval(
self,
data_name: str,
data_idx: int,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate training or validation data."""
if data_idx >= self.__num_dataset:
raise ValueError("Data_idx should be smaller than number of dataset")
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _log_info, _log_warning
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

_EvalResultTuple = Union[
List[Tuple[str, str, float, bool]],
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
]

Expand Down
11 changes: 5 additions & 6 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@

import numpy as np

from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_EvalFunctionResultType,
_log_warning)
from .callback import record_evaluation
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
dt_DataTable, pd_DataFrame)
from .engine import train

_EvalResultType = Tuple[str, float, bool]

_LGBM_ScikitCustomObjectiveFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Expand All @@ -33,15 +32,15 @@
_LGBM_ScikitCustomEvalFunction = Union[
Callable[
[np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
Union[_EvalResultType, List[_EvalResultType]]
Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
],
]

Expand Down

0 comments on commit 581d53c

Please sign in to comment.