From 7ca8b557572446888cf793c0082d9a7efd1e29a7 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 17 Feb 2022 16:35:02 +0000
Subject: [PATCH] Lint

---
 python-package/lightgbm/callback.py | 73 ++++++++++++++++++-----------
 1 file changed, 45 insertions(+), 28 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index e66f4376deec..6636eecce6f2 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -237,60 +237,69 @@ def __init__(self, stopping_rounds: int, first_metric_only: bool = False, verbos
         self.first_metric_only = first_metric_only
         self.verbose = verbose
         self.min_delta = min_delta
+
+        self.enabled = True
+
+        self._reset_storages()
+
+    def _reset_storages(self) -> None:
+        # reset storages
         self.best_score = []
         self.best_iter = []
-        self.best_score_list: list = []
+        self.best_score_list = []
         self.cmp_op = []
-        self.enabled = True
         self.first_metric = ''

     def _init(self, env: CallbackEnv) -> None:
         self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                               in _ConfigAliases.get("boosting"))
+            in _ConfigAliases.get("boosting"))
         if not self.enabled:
             _log_warning('Early stopping is not available in dart mode')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
-                             'at least one dataset and eval metric is required for evaluation')
+                'at least one dataset and eval metric is required for evaluation')

         if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+            raise ValueError(
+                "stopping_rounds should be greater than zero.")

         if self.verbose:
-            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
+            _log_info(
+                f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

-        # reset storages
-        self.best_score = []
-        self.best_iter = []
-        self.best_score_list = []
-        self.cmp_op = []
-        self.first_metric = ''
+        self._reset_storages()

         n_metrics = len(set(m[1] for m in env.evaluation_result_list))
         n_datasets = len(env.evaluation_result_list) // n_metrics
         if isinstance(self.min_delta, list):
             if not all(t >= 0 for t in self.min_delta):
-                raise ValueError('Values for early stopping min_delta must be non-negative.')
+                raise ValueError(
+                    'Values for early stopping min_delta must be non-negative.')
             if len(self.min_delta) == 0:
                 if self.verbose:
                     _log_info('Disabling min_delta for early stopping.')
                 deltas = [0.0] * n_datasets * n_metrics
             elif len(self.min_delta) == 1:
                 if self.verbose:
-                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
+                    _log_info(
+                        f'Using {self.min_delta[0]} as min_delta for all metrics.')
                 deltas = self.min_delta * n_datasets * n_metrics
             else:
                 if len(self.min_delta) != n_metrics:
-                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                    raise ValueError(
+                        'Must provide a single value for min_delta or as many as metrics.')
                 if self.first_metric_only and self.verbose:
-                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
+                    _log_info(
+                        f'Using only {self.min_delta[0]} as early stopping min_delta.')
                 deltas = self.min_delta * n_datasets
         else:
             if self.min_delta < 0:
-                raise ValueError('Early stopping min_delta must be non-negative.')
+                raise ValueError(
+                    'Early stopping min_delta must be non-negative.')
             if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
-                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
+                _log_info(
+                    f'Using {self.min_delta} as min_delta for all metrics.')
             deltas = [self.min_delta] * n_datasets * n_metrics

         # split is needed for " " case (e.g. "train l1")
@@ -309,12 +318,14 @@ def _init(self, env: CallbackEnv) -> None:
     def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         if env.iteration == env.end_iteration - 1:
             if self.verbose:
-                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                best_score_str = '\t'.join(
+                    [_format_eval_result(x) for x in self.best_score_list[i]])
                 _log_info('Did not meet early stopping. '
-                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
+                    f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                 if self.first_metric_only:
                     _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+            raise EarlyStopException(
+                self.best_iter[i], self.best_score_list[i])

     def __call__(self, env: CallbackEnv) -> None:
         if env.iteration == env.begin_iteration:
@@ -328,20 +339,26 @@ def __call__(self, env: CallbackEnv) -> None:
                 self.best_iter[i] = env.iteration
                 self.best_score_list[i] = env.evaluation_result_list
             # split is needed for " " case (e.g. "train l1")
-            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
+            eval_name_splitted = env.evaluation_result_list[i][1].split(
+                " ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
             if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                    or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                 self._final_iteration_check(env, eval_name_splitted, i)
-                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+                # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+                continue
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
-                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
-                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
+                    eval_result_str = '\t'.join(
+                        [_format_eval_result(x) for x in self.best_score_list[i]])
+                    _log_info(
+                        f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                     if self.first_metric_only:
-                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+                        _log_info(
+                            f"Evaluated only: {eval_name_splitted[-1]}")
+                raise EarlyStopException(
+                    self.best_iter[i], self.best_score_list[i])
             self._final_iteration_check(env, eval_name_splitted, i)

 _early_stopping_callback.order = 30  # type: ignore
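
Note: the following is a minimal usage sketch, not part of the patch. It exercises the early-stopping callback touched above, assuming a LightGBM build in which the public lgb.early_stopping() factory exposes the min_delta argument that this callback reads; the data, parameter values, and variable names are illustrative only.

# illustrative example; requires numpy and lightgbm installed
import numpy as np
import lightgbm as lgb

# synthetic regression data: y depends mostly on the first feature
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 10))
y = X[:, 0] + rng.normal(scale=0.1, size=1000)

train = lgb.Dataset(X[:800], label=y[:800])
valid = lgb.Dataset(X[800:], label=y[800:], reference=train)

booster = lgb.train(
    {"objective": "regression", "metric": "l1", "verbose": -1},
    train,
    num_boost_round=500,
    valid_sets=[valid],
    callbacks=[
        # stop if the validation l1 metric fails to improve by at least
        # 1e-3 for 20 consecutive rounds (min_delta as handled in _init above)
        lgb.early_stopping(stopping_rounds=20, min_delta=1e-3, verbose=True),
    ],
)
print("best iteration:", booster.best_iteration)

Passing a list for min_delta (one value per metric) follows the same branch of _init that validates len(min_delta) against the number of metrics.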