From 2983dc35edabafd19d8d0524747f3cd64f82038f Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 26 Mar 2021 16:19:26 -0700 Subject: [PATCH 1/5] add early stopping --- .../image_classification/default.py | 4 ++ .../image_classification.py | 11 +++++ gluoncv/auto/estimators/utils.py | 49 ++++++++++++++++++- 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/gluoncv/auto/estimators/image_classification/default.py b/gluoncv/auto/estimators/image_classification/default.py index 6b9faf8a2f..5775005021 100644 --- a/gluoncv/auto/estimators/image_classification/default.py +++ b/gluoncv/auto/estimators/image_classification/default.py @@ -53,6 +53,10 @@ class TrainCfg: start_epoch : int = 0 transfer_lr_mult : float = 0.01 # reduce the backbone lr_mult to avoid quickly destroying the features output_lr_mult : float = 0.1 # the learning rate multiplier for last fc layer if trained with transfer learning + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + early_stop_baseline : float = 0 # the baseline value for metric, training won't stop if not reaching baseline + early_stop_max_value : float = 1 # early stop if reaching max value instantly @dataclass class ValidCfg: diff --git a/gluoncv/auto/estimators/image_classification/image_classification.py b/gluoncv/auto/estimators/image_classification/image_classification.py index 3632c2db3a..915d69db08 100644 --- a/gluoncv/auto/estimators/image_classification/image_classification.py +++ b/gluoncv/auto/estimators/image_classification/image_classification.py @@ -23,6 +23,7 @@ from .default import ImageClassificationCfg from ...data.dataset import ImageClassificationDataset from ..conf import _BEST_CHECKPOINT_FILE +from ..utils import EarlyStopperOnPlateau __all__ = ['ImageClassificationEstimator'] @@ -134,6 +135,11 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf): self.teacher.hybridize(static_alloc=True, static_shape=True) self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) + early_stopper = EarlyStopperOnPlateau( + patience=self._cfg.train.early_stop_patience, + min_delta=self._cfg.train.early_stop_min_delta, + baseline_value=self._cfg.train.early_stop_baseline, + max_value=self._cfg.train.early_stop_max_value) train_metric_score = -1 cp_name = '' self._time_elapsed += time.time() - start_tic @@ -142,6 +148,10 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf): if self._best_acc >= 1.0: self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch)) break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break tic = time.time() last_tic = time.time() mx.nd.waitall() @@ -217,6 +227,7 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf): throughput = int(self.batch_size * i /(time.time() - tic)) top1_val, top5_val = self._evaluate(val_data) + early_stopper.update(top1_val) self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score) self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic) diff --git a/gluoncv/auto/estimators/utils.py b/gluoncv/auto/estimators/utils.py index 6654610833..af3a6bdffe 100644 --- a/gluoncv/auto/estimators/utils.py +++ b/gluoncv/auto/estimators/utils.py @@ -1,6 +1,53 @@ """Utils for deep learning framework related functions""" +import numpy as np + +__all__ = ['EarlyStopperOnPlateau', '_suggest_load_context'] + + +class EarlyStopperOnPlateau: + def __init__(self, patience=10, metric_fn=None, + min_delta=1e-4, baseline_value=None, max_value=np.Inf): + self.patience = patience if patience > 0 else np.Inf + self.metric_fn = metric_fn + self.min_delta = np.abs(min_delta) + self.baseline_value = baseline_value + self.max_value = max_value + self.reset() + + def reset(self): + self.last_epoch = 0 + self.wait = 0 + self._should_stop = False + self._message = '' + if self.baseline_value is not None: + self.best = self.baseline_value + else: + self.best = -np.Inf + + def update(self, metric_value, epoch=None): + if np.isreal(epoch): + self.last_epoch = epoch + if not np.isreal(metric_value): + return + if self.metric_fn is not None: + metric_value = self.metric_fn(metric_value) + + if metric_value > self.max_value: + self._should_stop = True + self._message = 'EarlyStop given {} vs. max {}'.format(metric_value, self.max_value) + else: + if metric_value - self.min_delta > self.best: + self.best = metric_value + self.wait = 0 + else: + self.wait += 1 + if self.wait >= self.patience: + self._should_stop = True + self._message = 'EarlyStop after {} epochs no better than {}'.format(self.patience, self.best) + + def get_early_stop_advice(self): + return self._should_stop, self._message -__all__ = ['_suggest_load_context'] def _suggest_load_context(model, mode, orig_ctx): """Get the correct context given the mode""" From eef365bdc5c854383b522f23907ecfcb98a6a53c Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 26 Mar 2021 16:29:53 -0700 Subject: [PATCH 2/5] fix default type --- gluoncv/auto/estimators/image_classification/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gluoncv/auto/estimators/image_classification/default.py b/gluoncv/auto/estimators/image_classification/default.py index 5775005021..0ec3dc2a43 100644 --- a/gluoncv/auto/estimators/image_classification/default.py +++ b/gluoncv/auto/estimators/image_classification/default.py @@ -55,8 +55,8 @@ class TrainCfg: output_lr_mult : float = 0.1 # the learning rate multiplier for last fc layer if trained with transfer learning early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics - early_stop_baseline : float = 0 # the baseline value for metric, training won't stop if not reaching baseline - early_stop_max_value : float = 1 # early stop if reaching max value instantly + early_stop_baseline : Union[float, int] = 0.0 # the baseline value for metric, training won't stop if not reaching baseline + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly @dataclass class ValidCfg: From 42ce6aadca1b15d2e5e9183cb9ab6493fa72c739 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 26 Mar 2021 17:03:58 -0700 Subject: [PATCH 3/5] add all --- gluoncv/auto/estimators/base_estimator.py | 4 ++-- gluoncv/auto/estimators/center_net/center_net.py | 11 +++++++++++ gluoncv/auto/estimators/center_net/default.py | 5 +++++ gluoncv/auto/estimators/faster_rcnn/default.py | 5 +++++ .../auto/estimators/faster_rcnn/faster_rcnn.py | 15 +++++++++++++-- .../estimators/image_classification/default.py | 3 ++- gluoncv/auto/estimators/ssd/default.py | 6 ++++++ gluoncv/auto/estimators/ssd/ssd.py | 15 +++++++++++++-- gluoncv/auto/estimators/utils.py | 10 ++++++++-- gluoncv/auto/estimators/yolo/default.py | 5 +++++ gluoncv/auto/estimators/yolo/yolo.py | 15 +++++++++++++-- 11 files changed, 83 insertions(+), 11 deletions(-) diff --git a/gluoncv/auto/estimators/base_estimator.py b/gluoncv/auto/estimators/base_estimator.py index 6fa6a90bb6..2b3242cacd 100644 --- a/gluoncv/auto/estimators/base_estimator.py +++ b/gluoncv/auto/estimators/base_estimator.py @@ -302,7 +302,7 @@ def save(self, filename): """ with open(filename, 'wb') as fid: pickle.dump(self, fid) - self._logger.info('Pickled to %s', filename) + self._logger.debug('Pickled to %s', filename) @classmethod def load(cls, filename, ctx='auto'): @@ -324,7 +324,7 @@ def load(cls, filename, ctx='auto'): """ with open(filename, 'rb') as fid: obj = pickle.load(fid) - obj._logger.info('Unpickled from %s', filename) + obj._logger.debug('Unpickled from %s', filename) new_ctx = _suggest_load_context(obj.net, ctx, obj.ctx) obj.reset_ctx(new_ctx) return obj diff --git a/gluoncv/auto/estimators/center_net/center_net.py b/gluoncv/auto/estimators/center_net/center_net.py index 751c295749..b8079f1903 100644 --- a/gluoncv/auto/estimators/center_net/center_net.py +++ b/gluoncv/auto/estimators/center_net/center_net.py @@ -26,6 +26,7 @@ from .default import CenterNetCfg from ...data.dataset import ObjectDetectionDataset from ..conf import _BEST_CHECKPOINT_FILE +from ..utils import EarlyStopperOnPlateau __all__ = ['CenterNetEstimator'] @@ -142,6 +143,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf center_reg_metric = mx.metric.Loss('CenterRegL1') self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) + early_stopper = EarlyStopperOnPlateau( + patience=self._cfg.train.early_stop_patience, + min_delta=self._cfg.train.early_stop_min_delta, + baseline_value=self._cfg.train.early_stop_baseline, + max_value=self._cfg.train.early_stop_max_value) mean_ap = [-1] cp_name = '' self._time_elapsed += time.time() - start_tic @@ -152,6 +158,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf if self._best_map >= 1.0: self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch) break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break wh_metric.reset() center_reg_metric.reset() heatmap_loss_metric.reset() @@ -225,6 +235,7 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self._best_map = current_map if self._reporter: self._reporter(epoch=epoch, map_reward=current_map) + early_stopper.update(current_map, epoch=epoch) self._time_elapsed += time.time() - post_tic # map on train data tic = time.time() diff --git a/gluoncv/auto/estimators/center_net/default.py b/gluoncv/auto/estimators/center_net/default.py index 19562ba3e6..59585616da 100644 --- a/gluoncv/auto/estimators/center_net/default.py +++ b/gluoncv/auto/estimators/center_net/default.py @@ -40,6 +40,11 @@ class TrainCfg: momentum : float = 0.9 # SGD momentum wd : float = 1e-4 # weight decay log_interval : int = 100 # logging interval + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly @dataclass class ValidCfg: diff --git a/gluoncv/auto/estimators/faster_rcnn/default.py b/gluoncv/auto/estimators/faster_rcnn/default.py index bd20aa90cc..511c109b1a 100644 --- a/gluoncv/auto/estimators/faster_rcnn/default.py +++ b/gluoncv/auto/estimators/faster_rcnn/default.py @@ -157,6 +157,11 @@ class TrainCfg: # but may speed up throughput. Note that when horovod is used, # it is set to 1. executor_threads : int = 4 + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly @dataclass diff --git a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py index 2977a16622..1f94dccbc0 100644 --- a/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py +++ b/gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py @@ -24,6 +24,7 @@ from .utils import _get_lr_at_iter, _get_dataloader, _split_and_load from ...data.dataset import ObjectDetectionDataset from ..conf import _BEST_CHECKPOINT_FILE +from ..utils import EarlyStopperOnPlateau try: import horovod.mxnet as hvd @@ -127,6 +128,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self.net.collect_params().reset_ctx(self.ctx) self.net.target_generator.collect_params().reset_ctx(self.ctx) + early_stopper = EarlyStopperOnPlateau( + patience=self._cfg.train.early_stop_patience, + min_delta=self._cfg.train.early_stop_min_delta, + baseline_value=self._cfg.train.early_stop_baseline, + max_value=self._cfg.train.early_stop_max_value) mean_ap = [-1] cp_name = '' self._time_elapsed += time.time() - start_tic @@ -137,6 +143,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf if self._best_map >= 1.0: self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch) break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break rcnn_task = ForwardBackwardTask(self.net, self.trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0, amp_enabled=self._cfg.faster_rcnn.amp) @@ -231,8 +241,9 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self.epoch, current_map, self._best_map, cp_name) self.save(cp_name) self._best_map = current_map - if self._reporter: - self._reporter(epoch=epoch, map_reward=current_map) + if self._reporter: + self._reporter(epoch=epoch, map_reward=current_map) + early_stopper.update(current_map, epoch=epoch) self._time_elapsed += time.time() - post_tic # map on train data tic = time.time() diff --git a/gluoncv/auto/estimators/image_classification/default.py b/gluoncv/auto/estimators/image_classification/default.py index 0ec3dc2a43..0a64d3d894 100644 --- a/gluoncv/auto/estimators/image_classification/default.py +++ b/gluoncv/auto/estimators/image_classification/default.py @@ -55,7 +55,8 @@ class TrainCfg: output_lr_mult : float = 0.1 # the learning rate multiplier for last fc layer if trained with transfer learning early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics - early_stop_baseline : Union[float, int] = 0.0 # the baseline value for metric, training won't stop if not reaching baseline + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly @dataclass diff --git a/gluoncv/auto/estimators/ssd/default.py b/gluoncv/auto/estimators/ssd/default.py index c8320d5a52..e0c8670936 100644 --- a/gluoncv/auto/estimators/ssd/default.py +++ b/gluoncv/auto/estimators/ssd/default.py @@ -57,6 +57,12 @@ class TrainCfg: # Currently supports only COCO. dali : bool = False + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly + @dataclass class ValidCfg: diff --git a/gluoncv/auto/estimators/ssd/ssd.py b/gluoncv/auto/estimators/ssd/ssd.py index f6e916c952..5421308869 100644 --- a/gluoncv/auto/estimators/ssd/ssd.py +++ b/gluoncv/auto/estimators/ssd/ssd.py @@ -27,6 +27,7 @@ from .default import SSDCfg from ...data.dataset import ObjectDetectionDataset from ..conf import _BEST_CHECKPOINT_FILE +from ..utils import EarlyStopperOnPlateau try: import horovod.mxnet as hvd @@ -131,6 +132,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) self.net.collect_params().reset_ctx(self.ctx) + early_stopper = EarlyStopperOnPlateau( + patience=self._cfg.train.early_stop_patience, + min_delta=self._cfg.train.early_stop_min_delta, + baseline_value=self._cfg.train.early_stop_baseline, + max_value=self._cfg.train.early_stop_max_value) mean_ap = [-1] cp_name = '' self._time_elapsed += time.time() - start_tic @@ -141,6 +147,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf if self._best_map >= 1.0: self._logger.info('[Epoch {}] Early stopping as mAP is reaching 1.0'.format(epoch)) break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break while lr_steps and epoch >= lr_steps[0]: new_lr = self.trainer.learning_rate * lr_decay lr_steps.pop(0) @@ -218,8 +228,9 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self.epoch, current_map, self._best_map, cp_name) self.save(cp_name) self._best_map = current_map - if self._reporter: - self._reporter(epoch=epoch, map_reward=current_map) + if self._reporter: + self._reporter(epoch=epoch, map_reward=current_map) + early_stopper.update(current_map, epoch=epoch) self._time_elapsed += time.time() - post_tic # map on train data tic = time.time() diff --git a/gluoncv/auto/estimators/utils.py b/gluoncv/auto/estimators/utils.py index af3a6bdffe..9839dca456 100644 --- a/gluoncv/auto/estimators/utils.py +++ b/gluoncv/auto/estimators/utils.py @@ -26,7 +26,13 @@ def reset(self): def update(self, metric_value, epoch=None): if np.isreal(epoch): + if np.isreal(self.last_epoch): + diff_epoch = epoch - self.last_epoch + else: + diff_epoch = 1 self.last_epoch = epoch + else: + diff_epoch = 1 if not np.isreal(metric_value): return if self.metric_fn is not None: @@ -40,10 +46,10 @@ def update(self, metric_value, epoch=None): self.best = metric_value self.wait = 0 else: - self.wait += 1 + self.wait += diff_epoch if self.wait >= self.patience: self._should_stop = True - self._message = 'EarlyStop after {} epochs no better than {}'.format(self.patience, self.best) + self._message = 'EarlyStop after {} epochs: no better than {}'.format(self.patience, self.best) def get_early_stop_advice(self): return self._should_stop, self._message diff --git a/gluoncv/auto/estimators/yolo/default.py b/gluoncv/auto/estimators/yolo/default.py index 97383fd63a..3c8bff941f 100644 --- a/gluoncv/auto/estimators/yolo/default.py +++ b/gluoncv/auto/estimators/yolo/default.py @@ -74,6 +74,11 @@ class TrainCfg: no_mixup_epochs : int = 20 # Use label smoothing. label_smooth : bool = False + early_stop_patience : int = -1 # epochs with no improvement after which train is early stopped, negative: disabled + early_stop_min_delta : float = 0.001 # ignore changes less than min_delta for metrics + # the baseline value for metric, training won't stop if not reaching baseline + early_stop_baseline : Union[float, int] = 0.0 + early_stop_max_value : Union[float, int] = 1.0 # early stop if reaching max value instantly @dataclass diff --git a/gluoncv/auto/estimators/yolo/yolo.py b/gluoncv/auto/estimators/yolo/yolo.py index c689c9d392..da86f6fbe1 100644 --- a/gluoncv/auto/estimators/yolo/yolo.py +++ b/gluoncv/auto/estimators/yolo/yolo.py @@ -26,6 +26,7 @@ from .utils import _get_dataloader from ...data.dataset import ObjectDetectionDataset from ..conf import _BEST_CHECKPOINT_FILE +from ..utils import EarlyStopperOnPlateau try: import horovod.mxnet as hvd @@ -123,6 +124,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf cls_metrics = mx.metric.Loss('ClassLoss') trainer = self.trainer self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch)) + early_stopper = EarlyStopperOnPlateau( + patience=self._cfg.train.early_stop_patience, + min_delta=self._cfg.train.early_stop_min_delta, + baseline_value=self._cfg.train.early_stop_baseline, + max_value=self._cfg.train.early_stop_max_value) mean_ap = [-1] cp_name = '' self._time_elapsed += time.time() - start_tic @@ -131,6 +137,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf if self._best_map >= 1.0: self._logger.info('[Epoch {}] Early stopping as mAP is reaching 1.0'.format(epoch)) break + should_stop, stop_message = early_stopper.get_early_stop_advice() + if should_stop: + self._logger.info('[Epoch {}] '.format(epoch) + stop_message) + break tic = time.time() last_tic = time.time() if self._cfg.train.mixup: @@ -216,8 +226,9 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf self.epoch, current_map, self._best_map, cp_name) self.save(cp_name) self._best_map = current_map - if self._reporter: - self._reporter(epoch=epoch, map_reward=current_map) + if self._reporter: + self._reporter(epoch=epoch, map_reward=current_map) + early_stopper.update(current_map, epoch=epoch) self._time_elapsed += time.time() - post_tic # map on train data From 98bff2167226b30c4368daa7a37d718a1d2becff Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 26 Mar 2021 17:24:59 -0700 Subject: [PATCH 4/5] fix --- gluoncv/auto/estimators/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gluoncv/auto/estimators/utils.py b/gluoncv/auto/estimators/utils.py index 9839dca456..cdbbba9821 100644 --- a/gluoncv/auto/estimators/utils.py +++ b/gluoncv/auto/estimators/utils.py @@ -25,15 +25,15 @@ def reset(self): self.best = -np.Inf def update(self, metric_value, epoch=None): - if np.isreal(epoch): - if np.isreal(self.last_epoch): + if _is_real_number(epoch): + if _is_real_number(self.last_epoch): diff_epoch = epoch - self.last_epoch else: diff_epoch = 1 self.last_epoch = epoch else: diff_epoch = 1 - if not np.isreal(metric_value): + if not _is_real_number(metric_value): return if self.metric_fn is not None: metric_value = self.metric_fn(metric_value) @@ -54,6 +54,8 @@ def update(self, metric_value, epoch=None): def get_early_stop_advice(self): return self._should_stop, self._message +def _is_real_number(x): + return isinstance(x, (int, float, complex)) and not isinstance(x, bool) def _suggest_load_context(model, mode, orig_ctx): """Get the correct context given the mode""" From 33b1d69ffb2a9b23625629e9864e0bee1a590592 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 26 Mar 2021 17:59:39 -0700 Subject: [PATCH 5/5] fix pylint --- .../auto/estimators/center_net/center_net.py | 2 +- .../image_classification.py | 2 +- gluoncv/auto/estimators/ssd/ssd.py | 2 +- gluoncv/auto/estimators/utils.py | 41 +++++++++++++++++++ gluoncv/auto/estimators/yolo/yolo.py | 2 +- 5 files changed, 45 insertions(+), 4 deletions(-) diff --git a/gluoncv/auto/estimators/center_net/center_net.py b/gluoncv/auto/estimators/center_net/center_net.py index b8079f1903..67a7e566b5 100644 --- a/gluoncv/auto/estimators/center_net/center_net.py +++ b/gluoncv/auto/estimators/center_net/center_net.py @@ -1,5 +1,5 @@ """CenterNet Estimator""" -# pylint: disable=unused-variable,missing-function-docstring,abstract-method,logging-format-interpolation,arguments-differ +# pylint: disable=unused-variable,missing-function-docstring,abstract-method,logging-format-interpolation,arguments-differ,logging-not-lazy import os import math import time diff --git a/gluoncv/auto/estimators/image_classification/image_classification.py b/gluoncv/auto/estimators/image_classification/image_classification.py index 915d69db08..982a19dc28 100644 --- a/gluoncv/auto/estimators/image_classification/image_classification.py +++ b/gluoncv/auto/estimators/image_classification/image_classification.py @@ -1,5 +1,5 @@ """Classification Estimator""" -# pylint: disable=unused-variable,bad-whitespace,missing-function-docstring,logging-format-interpolation,arguments-differ +# pylint: disable=unused-variable,bad-whitespace,missing-function-docstring,logging-format-interpolation,arguments-differ,logging-not-lazy import time import os import math diff --git a/gluoncv/auto/estimators/ssd/ssd.py b/gluoncv/auto/estimators/ssd/ssd.py index 5421308869..9bc917e88c 100644 --- a/gluoncv/auto/estimators/ssd/ssd.py +++ b/gluoncv/auto/estimators/ssd/ssd.py @@ -1,5 +1,5 @@ """SSD Estimator.""" -# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ +# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ,logging-not-lazy import os import math import time diff --git a/gluoncv/auto/estimators/utils.py b/gluoncv/auto/estimators/utils.py index cdbbba9821..4b05a8365b 100644 --- a/gluoncv/auto/estimators/utils.py +++ b/gluoncv/auto/estimators/utils.py @@ -5,6 +5,25 @@ class EarlyStopperOnPlateau: + """Early stopping on plateau helper. + + Parameters + ---------- + patience : int, default is -1 + How many epochs with no improvement after which train will be early stopped. + Negative patience means infinite petience. + metric_fn : function, default is None + The function to apply to metric value if any. For example, you can use + the `metric_fn` to cast loss to negative values where lower loss is better. + `min_delta`, `baseline_value` and `max_value` are all based on output of `metric_fn`. + min_delta : float, default is 1e-4 + Early stopper ignores changes less than `min_delta` for metrics to ignore tiny fluctuates. + baseline_value : float, default is 0.0 + The baseline metric value to be considered. + max_value : float, default is 1.0 + Instantly early stop if reaching max value. + + """ def __init__(self, patience=10, metric_fn=None, min_delta=1e-4, baseline_value=None, max_value=np.Inf): self.patience = patience if patience > 0 else np.Inf @@ -15,6 +34,7 @@ def __init__(self, patience=10, metric_fn=None, self.reset() def reset(self): + """reset the early stopper""" self.last_epoch = 0 self.wait = 0 self._should_stop = False @@ -25,6 +45,16 @@ def reset(self): self.best = -np.Inf def update(self, metric_value, epoch=None): + """Update with end of epoch metric. + + Parameters + ---------- + metric_value : float + The end of epoch metric. + epoch : int, optional + The real epoch in case the update function is not called in every epoch. + + """ if _is_real_number(epoch): if _is_real_number(self.last_epoch): diff_epoch = epoch - self.last_epoch @@ -52,9 +82,20 @@ def update(self, metric_value, epoch=None): self._message = 'EarlyStop after {} epochs: no better than {}'.format(self.patience, self.best) def get_early_stop_advice(self): + """Get the early stop advice. + + Returns + ------- + (bool, str) + should_stop : bool + Whether the stopper suggest OnPlateau pattern is active. + message : str + The detailed message why early stop is suggested, if `should_stop` is True. + """ return self._should_stop, self._message def _is_real_number(x): + """Check if x is a real number""" return isinstance(x, (int, float, complex)) and not isinstance(x, bool) def _suggest_load_context(model, mode, orig_ctx): diff --git a/gluoncv/auto/estimators/yolo/yolo.py b/gluoncv/auto/estimators/yolo/yolo.py index da86f6fbe1..94b7a55f98 100644 --- a/gluoncv/auto/estimators/yolo/yolo.py +++ b/gluoncv/auto/estimators/yolo/yolo.py @@ -1,5 +1,5 @@ """YOLO Estimator.""" -# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ +# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ,logging-not-lazy import os import math import time