diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 092dfb463f8e..45e21f298480 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -128,15 +128,15 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
     """
     if not isinstance(eval_result, dict):
         raise TypeError('eval_result should be a dictionary')
-    eval_result.clear()
 
     def _init(env: CallbackEnv) -> None:
+        eval_result.clear()
         for data_name, eval_name, _, _ in env.evaluation_result_list:
             eval_result.setdefault(data_name, collections.OrderedDict())
             eval_result[data_name].setdefault(eval_name, [])
 
     def _callback(env: CallbackEnv) -> None:
-        if not eval_result:
+        if env.iteration == env.begin_iteration:
             _init(env)
         for data_name, eval_name, result, _ in env.evaluation_result_list:
             eval_result[data_name][eval_name].append(result)
@@ -221,7 +221,6 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
     best_score_list: list = []
     cmp_op = []
     enabled = True
-    inited = False
     first_metric = ''
 
     def _init(env: CallbackEnv) -> None:
@@ -230,7 +229,6 @@ def _init(env: CallbackEnv) -> None:
         nonlocal best_score_list
         nonlocal cmp_op
         nonlocal enabled
-        nonlocal inited
         nonlocal first_metric
         enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                           in _ConfigAliases.get("boosting"))
@@ -249,7 +247,6 @@ def _init(env: CallbackEnv) -> None:
         best_iter = []
         best_score_list = []
         cmp_op = []
-        inited = True
         first_metric = ''
 
         n_metrics = len(set(m[1] for m in env.evaluation_result_list))
@@ -293,7 +290,6 @@ def _init(env: CallbackEnv) -> None:
     def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         nonlocal best_iter
         nonlocal best_score_list
-        nonlocal inited
         if env.iteration == env.end_iteration - 1:
             if verbose:
                 best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
@@ -301,7 +297,6 @@ def _init(env: CallbackEnv) -> None:
                           f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                 if first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            inited = False
             raise EarlyStopException(best_iter[i], best_score_list[i])
 
     def _callback(env: CallbackEnv) -> None:
@@ -310,9 +305,8 @@ def _callback(env: CallbackEnv) -> None:
         nonlocal best_score_list
         nonlocal cmp_op
         nonlocal enabled
-        nonlocal inited
         nonlocal first_metric
-        if not inited:
+        if env.iteration == env.begin_iteration:
             _init(env)
         if not enabled:
             return
@@ -336,7 +330,6 @@ def _callback(env: CallbackEnv) -> None:
                     _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                     if first_metric_only:
                         _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                inited = False
                 raise EarlyStopException(best_iter[i], best_score_list[i])
             _final_iteration_check(env, eval_name_splitted, i)
     _callback.order = 30  # type: ignore
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index acc3b9c512f6..1f998b13621b 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -294,20 +294,23 @@ def test_stacking_regressor():
 def test_grid_search():
     X, y = load_iris(return_X_y=True)
     y = y.astype(str)  # utilize label encoder at it's max power
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
-                                                         random_state=42)
-    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1,
-                                                      random_state=42)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
     params = dict(subsample=0.8,
                   subsample_freq=1)
     grid_params = dict(boosting_type=['rf', 'gbdt'],
                        n_estimators=[4, 6],
                        reg_alpha=[0.01, 0.005])
-    fit_params = dict(eval_set=[(X_val, y_val)],
-                      eval_metric=constant_metric,
-                      callbacks=[lgb.early_stopping(2)])
-    grid = GridSearchCV(estimator=lgb.LGBMClassifier(**params), param_grid=grid_params,
-                        cv=2)
+    evals_result = {}
+    fit_params = dict(
+        eval_set=[(X_val, y_val)],
+        eval_metric=constant_metric,
+        callbacks=[
+            lgb.early_stopping(2),
+            lgb.record_evaluation(evals_result)
+        ]
+    )
+    grid = GridSearchCV(estimator=lgb.LGBMClassifier(**params), param_grid=grid_params, cv=2)
     grid.fit(X_train, y_train, **fit_params)
     score = grid.score(X_test, y_test)  # utilizes GridSearchCV default refit=True
     assert grid.best_params_['boosting_type'] in ['rf', 'gbdt']
@@ -319,6 +322,7 @@ def test_grid_search():
     assert grid.best_estimator_.best_score_['valid_0']['error'] == 0
     assert score >= 0.2
     assert score <= 1.
+    assert evals_result == grid.best_estimator_.evals_result_
 
 
 def test_random_search():
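Note (illustration, not part of the patch): with these changes both record_evaluation() and early_stopping() rebuild their internal state whenever env.iteration == env.begin_iteration, so the same callback objects can be passed to repeated training runs, such as the refits performed by GridSearchCV, without carrying results over from a previous run. Below is a minimal sketch of that behavior using the scikit-learn API; the breast_cancer dataset, n_estimators=50, stopping_rounds=5, and the eval_history name are illustrative assumptions, not taken from the patch.

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

eval_history = {}  # illustrative name; filled in place by record_evaluation()
callbacks = [lgb.early_stopping(5), lgb.record_evaluation(eval_history)]

clf = lgb.LGBMClassifier(n_estimators=50)

# First fit: each callback initializes its storage at env.begin_iteration.
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=callbacks)

# Second fit with the very same callback objects (this is what repeated
# GridSearchCV fits do): the storages are reset again, so eval_history
# holds only the latest run and early stopping starts from a clean slate.
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=callbacks)

# Mirrors the assertion added to test_grid_search above.
assert eval_history == clf.evals_result_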