[python] make early_stopping callback pickleable (#5012)
* Turn `early_stopping` into a Callable class

* Fix

* Lint

* Remove print

* Fix order

* Revert "Lint"

This reverts commit 7ca8b55.

* Apply suggestion from code review

* Nit

* Lint

* Move callable class outside the func for pickling

* Move _pickle and _unpickle to tests utils

* Add early stopping callback picklability test

* Nit

* Fix

* Lint

* Improve type hint

* Lint

* Lint

* Add cloudpickle to test_windows

* Update tests/python_package_test/test_engine.py

* Fix

* Apply suggestions from code review
Yard1 authored Mar 17, 2022
1 parent eb686a7 commit f77e0ad
Showing 5 changed files with 184 additions and 161 deletions.
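
The commit messages above describe turning the closure-based `early_stopping` factory into a callable class defined at module level ("Turn `early_stopping` into a Callable class", "Move callable class outside the func for pickling"). The underlying reason is a general Python constraint: the standard `pickle` module serializes functions by qualified name, so a nested function that captures `nonlocal` state cannot be pickled, while an instance of a top-level class, which keeps its state in attributes, can. A minimal, self-contained sketch of that behaviour follows; the names are illustrative only, not LightGBM code.

import pickle


def make_closure_callback(stopping_rounds: int):
    """Factory in the pre-commit style: state lives in the enclosing scope."""
    best_iter = 0

    def _callback(env) -> None:
        nonlocal best_iter
        best_iter += 1  # placeholder for the real early-stopping logic
    return _callback


class CallableCallback:
    """Factory result in the post-commit style: state lives on the instance."""

    def __init__(self, stopping_rounds: int) -> None:
        self.stopping_rounds = stopping_rounds
        self.best_iter = 0

    def __call__(self, env) -> None:
        self.best_iter += 1  # placeholder for the real early-stopping logic


pickle.dumps(CallableCallback(stopping_rounds=5))      # works: the class is importable by qualified name
try:
    pickle.dumps(make_closure_callback(stopping_rounds=5))
except (pickle.PicklingError, AttributeError) as err:  # local functions cannot be pickled
    print(f"not picklable: {err}")
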
2 changes: 1 addition & 1 deletion .ci/test_windows.ps1
@@ -50,7 +50,7 @@ if ($env:TASK -eq "swig") {
   Exit 0
 }
 
-conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
+conda install -q -y -n $env:CONDA_ENV cloudpickle joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
 # python-graphviz has to be installed separately to prevent conda from downgrading to pypy
 conda install -q -y -n $env:CONDA_ENV libxml2 python-graphviz ; Check-Output $?
 
245 changes: 123 additions & 122 deletions python-package/lightgbm/callback.py
@@ -12,14 +12,6 @@
 ]
 
 
-def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score > best_score + delta
-
-
-def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score < best_score - delta
-
-
 class EarlyStopException(Exception):
     """Exception of early stopping."""
 
@@ -199,156 +191,165 @@ def _callback(env: CallbackEnv) -> None:
     return _callback
 
 
-def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
-    """Create a callback that activates early stopping.
-    Activates early stopping.
-    The model will train until the validation score doesn't improve by at least ``min_delta``.
-    Validation score needs to improve at least every ``stopping_rounds`` round(s)
-    to continue training.
-    Requires at least one validation data and one metric.
-    If there's more than one, will check all of them. But the training data is ignored anyway.
-    To check only the first metric set ``first_metric_only`` to True.
-    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
-    Parameters
-    ----------
-    stopping_rounds : int
-        The possible number of rounds without the trend occurrence.
-    first_metric_only : bool, optional (default=False)
-        Whether to use only the first metric for early stopping.
-    verbose : bool, optional (default=True)
-        Whether to log message with early stopping information.
-        By default, standard output resource is used.
-        Use ``register_logger()`` function to register a custom logger.
-    min_delta : float or list of float, optional (default=0.0)
-        Minimum improvement in score to keep training.
-        If float, this single value is used for all metrics.
-        If list, its length should match the total number of metrics.
-    Returns
-    -------
-    callback : callable
-        The callback that activates early stopping.
-    """
-    best_score = []
-    best_iter = []
-    best_score_list: list = []
-    cmp_op = []
-    enabled = True
-    first_metric = ''
-
-    def _init(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
-        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                          in _ConfigAliases.get("boosting"))
-        if not enabled:
+class _EarlyStoppingCallback:
+    """Internal early stopping callable class."""
+
+    def __init__(
+        self,
+        stopping_rounds: int,
+        first_metric_only: bool = False,
+        verbose: bool = True,
+        min_delta: Union[float, List[float]] = 0.0
+    ) -> None:
+        self.order = 30
+        self.before_iteration = False
+
+        self.stopping_rounds = stopping_rounds
+        self.first_metric_only = first_metric_only
+        self.verbose = verbose
+        self.min_delta = min_delta
+
+        self.enabled = True
+        self._reset_storages()
+
+    def _reset_storages(self) -> None:
+        self.best_score = []
+        self.best_iter = []
+        self.best_score_list = []
+        self.cmp_op = []
+        self.first_metric = ''
+
+    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score > best_score + delta
+
+    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score < best_score - delta
+
+    def _init(self, env: CallbackEnv) -> None:
+        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
+                               in _ConfigAliases.get("boosting"))
+        if not self.enabled:
             _log_warning('Early stopping is not available in dart mode')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
                              'at least one dataset and eval metric is required for evaluation')
 
-        if stopping_rounds <= 0:
+        if self.stopping_rounds <= 0:
             raise ValueError("stopping_rounds should be greater than zero.")
 
-        if verbose:
-            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
+        if self.verbose:
+            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
 
-        # reset storages
-        best_score = []
-        best_iter = []
-        best_score_list = []
-        cmp_op = []
-        first_metric = ''
+        self._reset_storages()
 
         n_metrics = len(set(m[1] for m in env.evaluation_result_list))
         n_datasets = len(env.evaluation_result_list) // n_metrics
-        if isinstance(min_delta, list):
-            if not all(t >= 0 for t in min_delta):
+        if isinstance(self.min_delta, list):
+            if not all(t >= 0 for t in self.min_delta):
                 raise ValueError('Values for early stopping min_delta must be non-negative.')
-            if len(min_delta) == 0:
-                if verbose:
+            if len(self.min_delta) == 0:
+                if self.verbose:
                     _log_info('Disabling min_delta for early stopping.')
                 deltas = [0.0] * n_datasets * n_metrics
-            elif len(min_delta) == 1:
-                if verbose:
-                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
-                deltas = min_delta * n_datasets * n_metrics
+            elif len(self.min_delta) == 1:
+                if self.verbose:
+                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
+                deltas = self.min_delta * n_datasets * n_metrics
             else:
-                if len(min_delta) != n_metrics:
+                if len(self.min_delta) != n_metrics:
                     raise ValueError('Must provide a single value for min_delta or as many as metrics.')
-                if first_metric_only and verbose:
-                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
-                deltas = min_delta * n_datasets
+                if self.first_metric_only and self.verbose:
+                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
+                deltas = self.min_delta * n_datasets
         else:
-            if min_delta < 0:
+            if self.min_delta < 0:
                 raise ValueError('Early stopping min_delta must be non-negative.')
-            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
-                _log_info(f'Using {min_delta} as min_delta for all metrics.')
-            deltas = [min_delta] * n_datasets * n_metrics
+            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
+                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
+            deltas = [self.min_delta] * n_datasets * n_metrics
 
         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
+        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
         for eval_ret, delta in zip(env.evaluation_result_list, deltas):
-            best_iter.append(0)
-            best_score_list.append(None)
+            self.best_iter.append(0)
+            self.best_score_list.append(None)
             if eval_ret[3]:  # greater is better
-                best_score.append(float('-inf'))
-                cmp_op.append(partial(_gt_delta, delta=delta))
+                self.best_score.append(float('-inf'))
+                self.cmp_op.append(partial(self._gt_delta, delta=delta))
             else:
-                best_score.append(float('inf'))
-                cmp_op.append(partial(_lt_delta, delta=delta))
+                self.best_score.append(float('inf'))
+                self.cmp_op.append(partial(self._lt_delta, delta=delta))
 
-    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
-        nonlocal best_iter
-        nonlocal best_score_list
+    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         if env.iteration == env.end_iteration - 1:
-            if verbose:
-                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
+            if self.verbose:
+                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
                 _log_info('Did not meet early stopping. '
-                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
-                if first_metric_only:
+                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
+                if self.first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            raise EarlyStopException(best_iter[i], best_score_list[i])
+            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
 
-    def _callback(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
+    def __call__(self, env: CallbackEnv) -> None:
         if env.iteration == env.begin_iteration:
-            _init(env)
-        if not enabled:
+            self._init(env)
+        if not self.enabled:
             return
         for i in range(len(env.evaluation_result_list)):
             score = env.evaluation_result_list[i][2]
-            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
-                best_score[i] = score
-                best_iter[i] = env.iteration
-                best_score_list[i] = env.evaluation_result_list
+            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
+                self.best_score[i] = score
+                self.best_iter[i] = env.iteration
+                self.best_score_list[i] = env.evaluation_result_list
             # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
-            if first_metric_only and first_metric != eval_name_splitted[-1]:
+            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
             if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
-                _final_iteration_check(env, eval_name_splitted, i)
+                    or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                self._final_iteration_check(env, eval_name_splitted, i)
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
-            elif env.iteration - best_iter[i] >= stopping_rounds:
-                if verbose:
-                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
-                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
-                    if first_metric_only:
+            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
+                if self.verbose:
+                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
+                    if self.first_metric_only:
                         _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                raise EarlyStopException(best_iter[i], best_score_list[i])
-            _final_iteration_check(env, eval_name_splitted, i)
-    _callback.order = 30  # type: ignore
-    return _callback
+                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+            self._final_iteration_check(env, eval_name_splitted, i)
 
 
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
+    """Create a callback that activates early stopping.
+    Activates early stopping.
+    The model will train until the validation score doesn't improve by at least ``min_delta``.
+    Validation score needs to improve at least every ``stopping_rounds`` round(s)
+    to continue training.
+    Requires at least one validation data and one metric.
+    If there's more than one, will check all of them. But the training data is ignored anyway.
+    To check only the first metric set ``first_metric_only`` to True.
+    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
+    Parameters
+    ----------
+    stopping_rounds : int
+        The possible number of rounds without the trend occurrence.
+    first_metric_only : bool, optional (default=False)
+        Whether to use only the first metric for early stopping.
+    verbose : bool, optional (default=True)
+        Whether to log message with early stopping information.
+        By default, standard output resource is used.
+        Use ``register_logger()`` function to register a custom logger.
+    min_delta : float or list of float, optional (default=0.0)
+        Minimum improvement in score to keep training.
+        If float, this single value is used for all metrics.
+        If list, its length should match the total number of metrics.
+    Returns
+    -------
+    callback : _EarlyStoppingCallback
+        The callback that activates early stopping.
+    """
+    return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
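
Because `early_stopping` now returns an `_EarlyStoppingCallback` instance, the object passed to `lgb.train` survives a serialization round trip, e.g. when a distributed wrapper ships callbacks to worker processes. A rough usage sketch follows; the synthetic data, parameter values, and variable names are placeholders, not part of this commit.

import pickle

import lightgbm as lgb
import numpy as np

# toy regression data, split into train and validation parts
X = np.random.rand(500, 10)
y = np.random.rand(500)
train_set = lgb.Dataset(X[:400], label=y[:400])
valid_set = lgb.Dataset(X[400:], label=y[400:], reference=train_set)

callback = lgb.early_stopping(stopping_rounds=5, min_delta=1e-3)
# round-trip through pickle before use, as would happen when sending the callback to another process
callback = pickle.loads(pickle.dumps(callback))

booster = lgb.train(
    {"objective": "regression", "verbose": -1},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    callbacks=[callback],
)
print(booster.best_iteration)
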
22 changes: 22 additions & 0 deletions tests/python_package_test/test_callback.py
@@ -0,0 +1,22 @@
+# coding: utf-8
+import pytest
+
+import lightgbm as lgb
+
+from .utils import pickle_obj, unpickle_obj
+
+
+@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
+def test_early_stopping_callback_is_picklable(serializer, tmp_path):
+    callback = lgb.early_stopping(stopping_rounds=5)
+    tmp_file = tmp_path / "early_stopping.pkl"
+    pickle_obj(
+        obj=callback,
+        filepath=tmp_file,
+        serializer=serializer
+    )
+    callback_from_disk = unpickle_obj(
+        filepath=tmp_file,
+        serializer=serializer
+    )
+    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
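
The new test imports `pickle_obj` and `unpickle_obj` from the test package's `utils` module (the commit message above notes that `_pickle` and `_unpickle` were moved to the tests utils), but that file's diff is not shown on this page. A hypothetical sketch of what such helpers could look like, assuming only the three serializer names used in the parametrized test:

# Hypothetical sketch of the pickle_obj / unpickle_obj helpers imported above.
# The real implementations live in tests/python_package_test/utils.py, whose
# diff is not loaded on this page; only the serializer names come from the test.
import pickle

import cloudpickle
import joblib


def pickle_obj(obj, filepath, serializer):
    if serializer == "pickle":
        with open(filepath, "wb") as f:
            pickle.dump(obj, f)
    elif serializer == "joblib":
        joblib.dump(obj, filepath)
    elif serializer == "cloudpickle":
        with open(filepath, "wb") as f:
            cloudpickle.dump(obj, f)
    else:
        raise ValueError(f"Unrecognized serializer type: {serializer}")


def unpickle_obj(filepath, serializer):
    if serializer == "pickle":
        with open(filepath, "rb") as f:
            return pickle.load(f)
    elif serializer == "joblib":
        return joblib.load(filepath)
    elif serializer == "cloudpickle":
        with open(filepath, "rb") as f:
            return cloudpickle.load(f)
    else:
        raise ValueError(f"Unrecognized serializer type: {serializer}")
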
