Lint
Yard1 committed Feb 17, 2022
1 parent 6a0b91b · commit 7ca8b55
Showing 1 changed file with 45 additions and 28 deletions.

python-package/lightgbm/callback.py: 45 additions & 28 deletions
@@ -237,60 +237,69 @@ def __init__(self, stopping_rounds: int, first_metric_only: bool = False, verbos
         self.first_metric_only = first_metric_only
         self.verbose = verbose
         self.min_delta = min_delta
 
+        self.enabled = True
+
+        self._reset_storages()
+
     def _reset_storages(self) -> None:
+        # reset storages
         self.best_score = []
         self.best_iter = []
-        self.best_score_list: list = []
+        self.best_score_list = []
         self.cmp_op = []
+        self.enabled = True
         self.first_metric = ''
 
     def _init(self, env: CallbackEnv) -> None:
         self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-            in _ConfigAliases.get("boosting"))
+                               in _ConfigAliases.get("boosting"))
         if not self.enabled:
             _log_warning('Early stopping is not available in dart mode')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
-                'at least one dataset and eval metric is required for evaluation')
+                             'at least one dataset and eval metric is required for evaluation')
 
         if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+            raise ValueError(
+                "stopping_rounds should be greater than zero.")
 
         if self.verbose:
-            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
+            _log_info(
+                f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
 
-        # reset storages
-        self.best_score = []
-        self.best_iter = []
-        self.best_score_list = []
-        self.cmp_op = []
-        self.first_metric = ''
+        self._reset_storages()
 
         n_metrics = len(set(m[1] for m in env.evaluation_result_list))
         n_datasets = len(env.evaluation_result_list) // n_metrics
         if isinstance(self.min_delta, list):
             if not all(t >= 0 for t in self.min_delta):
-                raise ValueError('Values for early stopping min_delta must be non-negative.')
+                raise ValueError(
+                    'Values for early stopping min_delta must be non-negative.')
             if len(self.min_delta) == 0:
                 if self.verbose:
                     _log_info('Disabling min_delta for early stopping.')
                 deltas = [0.0] * n_datasets * n_metrics
             elif len(self.min_delta) == 1:
                 if self.verbose:
-                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
+                    _log_info(
+                        f'Using {self.min_delta[0]} as min_delta for all metrics.')
                 deltas = self.min_delta * n_datasets * n_metrics
             else:
                 if len(self.min_delta) != n_metrics:
-                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                    raise ValueError(
+                        'Must provide a single value for min_delta or as many as metrics.')
                 if self.first_metric_only and self.verbose:
-                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
+                    _log_info(
+                        f'Using only {self.min_delta[0]} as early stopping min_delta.')
                 deltas = self.min_delta * n_datasets
         else:
             if self.min_delta < 0:
-                raise ValueError('Early stopping min_delta must be non-negative.')
+                raise ValueError(
+                    'Early stopping min_delta must be non-negative.')
             if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
-                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
+                _log_info(
+                    f'Using {self.min_delta} as min_delta for all metrics.')
             deltas = [self.min_delta] * n_datasets * n_metrics
 
         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
@@ -309,12 +318,14 @@ def _init(self, env: CallbackEnv) -> None:
     def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         if env.iteration == env.end_iteration - 1:
             if self.verbose:
-                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                best_score_str = '\t'.join(
+                    [_format_eval_result(x) for x in self.best_score_list[i]])
                 _log_info('Did not meet early stopping. '
-                    f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
+                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                 if self.first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+            raise EarlyStopException(
+                self.best_iter[i], self.best_score_list[i])
 
     def __call__(self, env: CallbackEnv) -> None:
         if env.iteration == env.begin_iteration:
@@ -328,20 +339,26 @@ def __call__(self, env: CallbackEnv) -> None:
                 self.best_iter[i] = env.iteration
                 self.best_score_list[i] = env.evaluation_result_list
             # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
+            eval_name_splitted = env.evaluation_result_list[i][1].split(
+                " ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
             if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                    or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                 self._final_iteration_check(env, eval_name_splitted, i)
-                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+                # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+                continue
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
-                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
-                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
+                    eval_result_str = '\t'.join(
+                        [_format_eval_result(x) for x in self.best_score_list[i]])
+                    _log_info(
+                        f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                 if self.first_metric_only:
-                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+                    _log_info(
+                        f"Evaluated only: {eval_name_splitted[-1]}")
+                raise EarlyStopException(
+                    self.best_iter[i], self.best_score_list[i])
             self._final_iteration_check(env, eval_name_splitted, i)
 
 _early_stopping_callback.order = 30  # type: ignore
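
A note on the min_delta handling in _init above: the branching that turns a scalar or list min_delta into one threshold per dataset/metric pair can be read on its own. The standalone sketch below mirrors that branching; the helper name _broadcast_min_delta is ours for illustration and is not part of the library.

from typing import List, Union

def _broadcast_min_delta(min_delta: Union[float, List[float]],
                         n_datasets: int, n_metrics: int) -> List[float]:
    # Mirrors the validation and broadcasting performed in _init above.
    if isinstance(min_delta, list):
        if not all(d >= 0 for d in min_delta):
            raise ValueError('Values for early stopping min_delta must be non-negative.')
        if len(min_delta) == 0:
            # An empty list disables min_delta entirely.
            return [0.0] * n_datasets * n_metrics
        if len(min_delta) == 1:
            # A single value is shared by every dataset/metric pair.
            return min_delta * n_datasets * n_metrics
        if len(min_delta) != n_metrics:
            raise ValueError('Must provide a single value for min_delta or as many as metrics.')
        # One value per metric, tiled across datasets.
        return min_delta * n_datasets
    if min_delta < 0:
        raise ValueError('Early stopping min_delta must be non-negative.')
    return [min_delta] * n_datasets * n_metrics

For example, _broadcast_min_delta(0.5, n_datasets=2, n_metrics=3) returns six copies of 0.5, while a list of three values is accepted only when n_metrics is 3 and is then repeated once per dataset.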

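For end-to-end context, the callback being linted here backs the public lightgbm.early_stopping API. A minimal usage sketch on synthetic data (parameter values are illustrative, and passing min_delta assumes a LightGBM release that ships this parameter):

import lightgbm as lgb
import numpy as np

# Synthetic regression data, for illustration only.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 10))
y = X[:, 0] + rng.normal(scale=0.1, size=1000)

train_set = lgb.Dataset(X[:800], label=y[:800])
valid_set = lgb.Dataset(X[800:], label=y[800:], reference=train_set)

booster = lgb.train(
    {"objective": "regression", "metric": "l2", "verbosity": -1},
    train_set,
    num_boost_round=500,
    valid_sets=[valid_set],
    callbacks=[
        # Stop when the validation l2 has not improved by at least
        # 1e-4 over the last 20 rounds.
        lgb.early_stopping(stopping_rounds=20, min_delta=1e-4),
    ],
)
print(f"best iteration: {booster.best_iteration}")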