diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py
index 06a7cf89bd4f..7338f8b1e3eb 100644
--- a/python-package/lightgbm/compat.py
+++ b/python-package/lightgbm/compat.py
@@ -75,9 +75,9 @@ def __init__(self, *args, **kwargs):
     from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
     try:
         from sklearn.exceptions import NotFittedError
-        from sklearn.model_selection import GroupKFold, StratifiedKFold
+        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
     except ImportError:
-        from sklearn.cross_validation import GroupKFold, StratifiedKFold
+        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
         from sklearn.utils.validation import NotFittedError
     try:
         from sklearn.utils.validation import _check_sample_weight
@@ -90,6 +90,7 @@ def _check_sample_weight(sample_weight, X, dtype=None):
             return sample_weight
 
     SKLEARN_INSTALLED = True
+    _LGBMBaseCrossValidator = BaseCrossValidator
     _LGBMModelBase = BaseEstimator
     _LGBMRegressorBase = RegressorMixin
     _LGBMClassifierBase = ClassifierMixin
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 71a9a115d342..2edf18435c17 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -4,19 +4,24 @@
 import copy
 from operator import attrgetter
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
 from . import callback
 from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning
-from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold
+from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
 
 _LGBM_CustomMetricFunction = Callable[
     [np.ndarray, Dataset],
     Tuple[str, float, bool]
 ]
 
+_LGBM_PreprocFunction = Callable[
+    [Dataset, Dataset, Dict[str, Any]],
+    Tuple[Dataset, Dataset, Dict[str, Any]]
+]
+
 
 def train(
     params: Dict[str, Any],
@@ -373,12 +378,25 @@ def _agg_cv_result(raw_results):
     return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
 
 
-def cv(params, train_set, num_boost_round=100,
-       folds=None, nfold=5, stratified=True, shuffle=True,
-       metrics=None, feval=None, init_model=None,
-       feature_name='auto', categorical_feature='auto',
-       fpreproc=None, seed=0, callbacks=None, eval_train_metric=False,
-       return_cvbooster=False):
+def cv(
+    params: Dict[str, Any],
+    train_set: Dataset,
+    num_boost_round: int = 100,
+    folds: Optional[Union[Iterable[Tuple[np.ndarray, np.ndarray]], _LGBMBaseCrossValidator]] = None,
+    nfold: int = 5,
+    stratified: bool = True,
+    shuffle: bool = True,
+    metrics: Optional[Union[str, List[str]]] = None,
+    feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
+    init_model: Optional[Union[str, Path, Booster]] = None,
+    feature_name: Union[str, List[str]] = 'auto',
+    categorical_feature: Union[str, List[str], List[int]] = 'auto',
+    fpreproc: Optional[_LGBM_PreprocFunction] = None,
+    seed: int = 0,
+    callbacks: Optional[List[Callable]] = None,
+    eval_train_metric: bool = False,
+    return_cvbooster: bool = False
+) -> Dict[str, Any]:
     """Perform the cross-validation with given parameters.
 
     Parameters