diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 4648b75e9060..56e9cbba190a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -20,6 +20,9 @@
 from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
 from .libpath import find_lib_path
 
+_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
+_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
+
 ZERO_THRESHOLD = 1e-35
 
 
@@ -2617,6 +2620,16 @@ def _dump_text(self, filename: Union[str, Path]) -> "Dataset":
     [np.ndarray, Dataset],
     Tuple[np.ndarray, np.ndarray]
 ]
+_LGBM_CustomEvalFunction = Union[
+    Callable[
+        [np.ndarray, Dataset],
+        _LGBM_EvalFunctionResultType
+    ],
+    Callable[
+        [np.ndarray, Dataset],
+        List[_LGBM_EvalFunctionResultType]
+    ]
+]
 
 
 class Booster:
@@ -3255,7 +3268,12 @@ def lower_bound(self) -> float:
             ctypes.byref(ret)))
         return ret.value
 
-    def eval(self, data, name, feval=None):
+    def eval(
+        self,
+        data: Dataset,
+        name: str,
+        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
+    ) -> List[_LGBM_BoosterEvalMethodResultType]:
         """Evaluate for data.
 
         Parameters
@@ -3286,7 +3304,7 @@ def eval(self, data, name, feval=None):
         Returns
         -------
         result : list
-            List with evaluation results.
+            List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
         """
         if not isinstance(data, Dataset):
             raise TypeError("Can only eval for Dataset instance")
@@ -3305,7 +3323,10 @@
 
         return self.__inner_eval(name, data_idx, feval)
 
-    def eval_train(self, feval=None):
+    def eval_train(
+        self,
+        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
+    ) -> List[_LGBM_BoosterEvalMethodResultType]:
         """Evaluate for training data.
 
         Parameters
@@ -3332,11 +3353,14 @@ def eval_train(self, feval=None):
         Returns
         -------
         result : list
-            List with evaluation results.
+            List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
         """
         return self.__inner_eval(self._train_data_name, 0, feval)
 
-    def eval_valid(self, feval=None):
+    def eval_valid(
+        self,
+        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
+    ) -> List[_LGBM_BoosterEvalMethodResultType]:
         """Evaluate for validation data.
 
         Parameters
@@ -3363,7 +3387,7 @@ def eval_valid(self, feval=None):
         Returns
         -------
         result : list
-            List with evaluation results.
+            List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
         """
         return [item for i in range(1, self.__num_dataset)
                 for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
@@ -3969,7 +3993,12 @@ def add(root):
         else:
             return hist, bin_edges
 
-    def __inner_eval(self, data_name, data_idx, feval=None):
+    def __inner_eval(
+        self,
+        data_name: str,
+        data_idx: int,
+        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
+    ) -> List[_LGBM_BoosterEvalMethodResultType]:
         """Evaluate training or validation data."""
         if data_idx >= self.__num_dataset:
             raise ValueError("Data_idx should be smaller than number of dataset")
diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 05539b6396ac..c1412e424e8a 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -4,10 +4,10 @@
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple, Union
 
-from .basic import _ConfigAliases, _log_info, _log_warning
+from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
 
 _EvalResultTuple = Union[
-    List[Tuple[str, str, float, bool]],
+    List[_LGBM_BoosterEvalMethodResultType],
     List[Tuple[str, str, float, bool, float]]
 ]
 
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 1f320886ebfc..cb7b14036929 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -6,7 +6,8 @@
 
 import numpy as np
 
-from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
+from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _LGBM_EvalFunctionResultType,
+                    _log_warning)
 from .callback import record_evaluation
 from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
                      _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
@@ -14,8 +15,6 @@
                      dt_DataTable, pd_DataFrame)
 from .engine import train
 
-_EvalResultType = Tuple[str, float, bool]
-
 _LGBM_ScikitCustomObjectiveFunction = Union[
     Callable[
         [np.ndarray, np.ndarray],
@@ -33,15 +32,15 @@
 _LGBM_ScikitCustomEvalFunction = Union[
     Callable[
         [np.ndarray, np.ndarray],
-        Union[_EvalResultType, List[_EvalResultType]]
+        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
     ],
     Callable[
         [np.ndarray, np.ndarray, np.ndarray],
-        Union[_EvalResultType, List[_EvalResultType]]
+        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
     ],
     Callable[
         [np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-        Union[_EvalResultType, List[_EvalResultType]]
+        Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]
     ],
 ]