[Auto] Add early stopping strategies #1641

Merged: 5 commits, Mar 27, 2021
4 changes: 2 additions & 2 deletions gluoncv/auto/estimators/base_estimator.py
@@ -302,7 +302,7 @@ def save(self, filename):
"""
with open(filename, 'wb') as fid:
pickle.dump(self, fid)
self._logger.info('Pickled to %s', filename)
self._logger.debug('Pickled to %s', filename)

@classmethod
def load(cls, filename, ctx='auto'):
@@ -324,7 +324,7 @@ def load(cls, filename, ctx='auto'):
"""
with open(filename, 'rb') as fid:
obj = pickle.load(fid)
obj._logger.info('Unpickled from %s', filename)
obj._logger.debug('Unpickled from %s', filename)
new_ctx = _suggest_load_context(obj.net, ctx, obj.ctx)
obj.reset_ctx(new_ctx)
return obj
13 changes: 12 additions & 1 deletion gluoncv/auto/estimators/center_net/center_net.py
@@ -1,5 +1,5 @@
"""CenterNet Estimator"""
# pylint: disable=unused-variable,missing-function-docstring,abstract-method,logging-format-interpolation,arguments-differ
# pylint: disable=unused-variable,missing-function-docstring,abstract-method,logging-format-interpolation,arguments-differ,logging-not-lazy
import os
import math
import time
@@ -26,6 +26,7 @@
from .default import CenterNetCfg
from ...data.dataset import ObjectDetectionDataset
from ..conf import _BEST_CHECKPOINT_FILE
from ..utils import EarlyStopperOnPlateau

__all__ = ['CenterNetEstimator']

@@ -142,6 +143,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
center_reg_metric = mx.metric.Loss('CenterRegL1')

self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch))
early_stopper = EarlyStopperOnPlateau(
patience=self._cfg.train.early_stop_patience,
min_delta=self._cfg.train.early_stop_min_delta,
baseline_value=self._cfg.train.early_stop_baseline,
max_value=self._cfg.train.early_stop_max_value)
mean_ap = [-1]
cp_name = ''
self._time_elapsed += time.time() - start_tic
@@ -152,6 +158,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
if self._best_map >= 1.0:
self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch)
break
should_stop, stop_message = early_stopper.get_early_stop_advice()
if should_stop:
self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
break
wh_metric.reset()
center_reg_metric.reset()
heatmap_loss_metric.reset()
@@ -225,6 +235,7 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
self._best_map = current_map
if self._reporter:
self._reporter(epoch=epoch, map_reward=current_map)
early_stopper.update(current_map, epoch=epoch)
self._time_elapsed += time.time() - post_tic
# map on train data
tic = time.time()
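Taken together, the hunks above add one pattern that this PR repeats in every estimator: build the stopper from the new TrainCfg fields, ask it for advice at the top of each epoch, and feed it the validation metric after each evaluation. The following is a paraphrased sketch of that pattern, not the literal diff; `train_one_epoch`, `evaluate`, `cfg_train`, `epochs` and `logger` are illustrative placeholders:

from gluoncv.auto.estimators.utils import EarlyStopperOnPlateau

def train_with_early_stopping(cfg_train, train_one_epoch, evaluate, epochs, logger):
    # Condensed version of the loop added to each estimator's _train_loop.
    early_stopper = EarlyStopperOnPlateau(
        patience=cfg_train.early_stop_patience,
        min_delta=cfg_train.early_stop_min_delta,
        baseline_value=cfg_train.early_stop_baseline,
        max_value=cfg_train.early_stop_max_value)
    for epoch in range(epochs):
        should_stop, stop_message = early_stopper.get_early_stop_advice()
        if should_stop:
            logger.info('[Epoch {}] '.format(epoch) + stop_message)
            break
        train_one_epoch(epoch)
        current_map = evaluate()  # validation mAP here; top-1 accuracy in the classification estimator
        early_stopper.update(current_map, epoch=epoch)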
5 changes: 5 additions & 0 deletions gluoncv/auto/estimators/center_net/default.py
@@ -40,6 +40,11 @@ class TrainCfg:
momentum : float = 0.9 # SGD momentum
wd : float = 1e-4 # weight decay
log_interval : int = 100 # logging interval
early_stop_patience : int = -1 # epochs with no improvement after which training is stopped early; a negative value disables patience-based stopping
early_stop_min_delta : float = 0.001 # metric improvements smaller than min_delta are ignored
# baseline value for the metric; epochs that do not improve on this baseline count toward the patience
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # stop immediately once the metric exceeds this value

@dataclass
class ValidCfg:
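The four early_stop_* fields above are added to the TrainCfg of every estimator in this PR, so they can be tuned through the usual nested config dict. A minimal sketch, assuming the estimator accepts the standard nested config; the estimator choice and the values are illustrative:

from gluoncv.auto.estimators import CenterNetEstimator

# Stop after 10 epochs without a 0.001 mAP improvement (improvements are measured
# against the 0.3 baseline), and stop at once if mAP ever exceeds 0.95.
config = {'train': {'early_stop_patience': 10,
                    'early_stop_min_delta': 1e-3,
                    'early_stop_baseline': 0.3,
                    'early_stop_max_value': 0.95}}
estimator = CenterNetEstimator(config)
# estimator.fit(train_data, val_data)  # training now honors the early-stop settings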
5 changes: 5 additions & 0 deletions gluoncv/auto/estimators/faster_rcnn/default.py
@@ -157,6 +157,11 @@ class TrainCfg:
# but may speed up throughput. Note that when horovod is used,
# it is set to 1.
executor_threads : int = 4
early_stop_patience : int = -1 # epochs with no improvement after which training is stopped early; a negative value disables patience-based stopping
early_stop_min_delta : float = 0.001 # metric improvements smaller than min_delta are ignored
# baseline value for the metric; epochs that do not improve on this baseline count toward the patience
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # stop immediately once the metric exceeds this value


@dataclass
15 changes: 13 additions & 2 deletions gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py
@@ -24,6 +24,7 @@
from .utils import _get_lr_at_iter, _get_dataloader, _split_and_load
from ...data.dataset import ObjectDetectionDataset
from ..conf import _BEST_CHECKPOINT_FILE
from ..utils import EarlyStopperOnPlateau

try:
import horovod.mxnet as hvd
@@ -127,6 +128,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
self.net.collect_params().reset_ctx(self.ctx)
self.net.target_generator.collect_params().reset_ctx(self.ctx)

early_stopper = EarlyStopperOnPlateau(
patience=self._cfg.train.early_stop_patience,
min_delta=self._cfg.train.early_stop_min_delta,
baseline_value=self._cfg.train.early_stop_baseline,
max_value=self._cfg.train.early_stop_max_value)
mean_ap = [-1]
cp_name = ''
self._time_elapsed += time.time() - start_tic
@@ -137,6 +143,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
if self._best_map >= 1.0:
self._logger.info('[Epoch %d] Early stopping as mAP is reaching 1.0', epoch)
break
should_stop, stop_message = early_stopper.get_early_stop_advice()
if should_stop:
self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
break
rcnn_task = ForwardBackwardTask(self.net, self.trainer, rpn_cls_loss, rpn_box_loss,
rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0,
amp_enabled=self._cfg.faster_rcnn.amp)
@@ -231,8 +241,9 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
self.epoch, current_map, self._best_map, cp_name)
self.save(cp_name)
self._best_map = current_map
if self._reporter:
self._reporter(epoch=epoch, map_reward=current_map)
if self._reporter:
self._reporter(epoch=epoch, map_reward=current_map)
early_stopper.update(current_map, epoch=epoch)
self._time_elapsed += time.time() - post_tic
# map on train data
tic = time.time()
5 changes: 5 additions & 0 deletions gluoncv/auto/estimators/image_classification/default.py
@@ -53,6 +53,11 @@ class TrainCfg:
start_epoch : int = 0
transfer_lr_mult : float = 0.01 # reduce the backbone lr_mult to avoid quickly destroying the features
output_lr_mult : float = 0.1 # the learning rate multiplier for last fc layer if trained with transfer learning
early_stop_patience : int = -1 # epochs with no improvement after which training is stopped early; a negative value disables patience-based stopping
early_stop_min_delta : float = 0.001 # metric improvements smaller than min_delta are ignored
# baseline value for the metric; epochs that do not improve on this baseline count toward the patience
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # stop immediately once the metric exceeds this value

@dataclass
class ValidCfg:
gluoncv/auto/estimators/image_classification/image_classification.py
@@ -1,5 +1,5 @@
"""Classification Estimator"""
# pylint: disable=unused-variable,bad-whitespace,missing-function-docstring,logging-format-interpolation,arguments-differ
# pylint: disable=unused-variable,bad-whitespace,missing-function-docstring,logging-format-interpolation,arguments-differ,logging-not-lazy
import time
import os
import math
@@ -23,6 +23,7 @@
from .default import ImageClassificationCfg
from ...data.dataset import ImageClassificationDataset
from ..conf import _BEST_CHECKPOINT_FILE
from ..utils import EarlyStopperOnPlateau

__all__ = ['ImageClassificationEstimator']

@@ -134,6 +135,11 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf):
self.teacher.hybridize(static_alloc=True, static_shape=True)

self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch))
early_stopper = EarlyStopperOnPlateau(
patience=self._cfg.train.early_stop_patience,
min_delta=self._cfg.train.early_stop_min_delta,
baseline_value=self._cfg.train.early_stop_baseline,
max_value=self._cfg.train.early_stop_max_value)
train_metric_score = -1
cp_name = ''
self._time_elapsed += time.time() - start_tic
@@ -142,6 +148,10 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf):
if self._best_acc >= 1.0:
self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch))
break
should_stop, stop_message = early_stopper.get_early_stop_advice()
if should_stop:
self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
break
tic = time.time()
last_tic = time.time()
mx.nd.waitall()
@@ -217,6 +227,7 @@ def _train_loop(self, train_data, val_data, time_limit=math.inf):
throughput = int(self.batch_size * i /(time.time() - tic))

top1_val, top5_val = self._evaluate(val_data)
early_stopper.update(top1_val)  # no epoch argument: the stopper counts one epoch per call

self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score)
self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic)
6 changes: 6 additions & 0 deletions gluoncv/auto/estimators/ssd/default.py
@@ -57,6 +57,12 @@ class TrainCfg:
# Currently supports only COCO.
dali : bool = False

early_stop_patience : int = -1 # epochs with no improvement after which training is stopped early; a negative value disables patience-based stopping
early_stop_min_delta : float = 0.001 # metric improvements smaller than min_delta are ignored
# baseline value for the metric; epochs that do not improve on this baseline count toward the patience
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # stop immediately once the metric exceeds this value


@dataclass
class ValidCfg:
17 changes: 14 additions & 3 deletions gluoncv/auto/estimators/ssd/ssd.py
@@ -1,5 +1,5 @@
"""SSD Estimator."""
# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ
# pylint: disable=logging-format-interpolation,abstract-method,arguments-differ,logging-not-lazy
import os
import math
import time
@@ -27,6 +27,7 @@
from .default import SSDCfg
from ...data.dataset import ObjectDetectionDataset
from ..conf import _BEST_CHECKPOINT_FILE
from ..utils import EarlyStopperOnPlateau

try:
import horovod.mxnet as hvd
@@ -131,6 +132,11 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
self._logger.info('Start training from [Epoch %d]', max(self._cfg.train.start_epoch, self.epoch))

self.net.collect_params().reset_ctx(self.ctx)
early_stopper = EarlyStopperOnPlateau(
patience=self._cfg.train.early_stop_patience,
min_delta=self._cfg.train.early_stop_min_delta,
baseline_value=self._cfg.train.early_stop_baseline,
max_value=self._cfg.train.early_stop_max_value)
mean_ap = [-1]
cp_name = ''
self._time_elapsed += time.time() - start_tic
@@ -141,6 +147,10 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
if self._best_map >= 1.0:
self._logger.info('[Epoch {}] Early stopping as mAP is reaching 1.0'.format(epoch))
break
should_stop, stop_message = early_stopper.get_early_stop_advice()
if should_stop:
self._logger.info('[Epoch {}] '.format(epoch) + stop_message)
break
while lr_steps and epoch >= lr_steps[0]:
new_lr = self.trainer.learning_rate * lr_decay
lr_steps.pop(0)
@@ -218,8 +228,9 @@ def _train_loop(self, train_data, val_data, train_eval_data, time_limit=math.inf
self.epoch, current_map, self._best_map, cp_name)
self.save(cp_name)
self._best_map = current_map
if self._reporter:
self._reporter(epoch=epoch, map_reward=current_map)
if self._reporter:
self._reporter(epoch=epoch, map_reward=current_map)
early_stopper.update(current_map, epoch=epoch)
self._time_elapsed += time.time() - post_tic
# map on train data
tic = time.time()
98 changes: 97 additions & 1 deletion gluoncv/auto/estimators/utils.py
@@ -1,6 +1,102 @@
"""Utils for deep learning framework related functions"""
import numpy as np

__all__ = ['_suggest_load_context']
__all__ = ['EarlyStopperOnPlateau', '_suggest_load_context']


class EarlyStopperOnPlateau:
"""Early stopping on plateau helper.

Parameters
----------
patience : int, default is 10
How many epochs with no improvement to wait before training is stopped early.
Negative patience means infinite patience.
metric_fn : function, default is None
An optional function applied to the metric value before comparison; for example,
`metric_fn` can negate a loss so that a lower loss maps to a higher metric.
`min_delta`, `baseline_value` and `max_value` are all interpreted on the output of `metric_fn`.
min_delta : float, default is 1e-4
Changes smaller than `min_delta` are ignored, so tiny metric fluctuations do not count as improvement.
baseline_value : float, default is None
The baseline metric value used as the initial best value; epochs that fail to improve on it count toward the patience.
max_value : float, default is np.Inf
Stop immediately once the metric exceeds this value.

"""
def __init__(self, patience=10, metric_fn=None,
min_delta=1e-4, baseline_value=None, max_value=np.Inf):
self.patience = patience if patience > 0 else np.Inf
self.metric_fn = metric_fn
self.min_delta = np.abs(min_delta)
self.baseline_value = baseline_value
self.max_value = max_value
self.reset()

def reset(self):
"""reset the early stopper"""
self.last_epoch = 0
self.wait = 0
self._should_stop = False
self._message = ''
if self.baseline_value is not None:
self.best = self.baseline_value
else:
self.best = -np.Inf

def update(self, metric_value, epoch=None):
"""Update with end of epoch metric.

Parameters
----------
metric_value : float
The end-of-epoch metric value.
epoch : int, optional
The actual epoch number, used when `update` is not called every epoch.

"""
if _is_real_number(epoch):
if _is_real_number(self.last_epoch):
diff_epoch = epoch - self.last_epoch
else:
diff_epoch = 1
self.last_epoch = epoch
else:
diff_epoch = 1
if not _is_real_number(metric_value):
return
if self.metric_fn is not None:
metric_value = self.metric_fn(metric_value)

if metric_value > self.max_value:
self._should_stop = True
self._message = 'EarlyStop given {} vs. max {}'.format(metric_value, self.max_value)
else:
if metric_value - self.min_delta > self.best:
self.best = metric_value
self.wait = 0
else:
self.wait += diff_epoch
if self.wait >= self.patience:
self._should_stop = True
self._message = 'EarlyStop after {} epochs: no better than {}'.format(self.patience, self.best)

def get_early_stop_advice(self):
"""Get the early stop advice.

Returns
-------
(bool, str)
should_stop : bool
Whether the stopper suggests that training should stop (a plateau was detected or the metric exceeded `max_value`).
message : str
The detailed reason why early stopping is suggested; empty if `should_stop` is False.
"""
return self._should_stop, self._message

def _is_real_number(x):
"""Check if x is a real number"""
return isinstance(x, (int, float, complex)) and not isinstance(x, bool)

def _suggest_load_context(model, mode, orig_ctx):
"""Get the correct context given the mode"""
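For reference, a minimal standalone sketch of how the new helper behaves; the accuracy values and the patience of 2 are made up:

from gluoncv.auto.estimators.utils import EarlyStopperOnPlateau

stopper = EarlyStopperOnPlateau(patience=2, min_delta=1e-3,
                                baseline_value=0.0, max_value=1.0)
val_accs = [0.60, 0.61, 0.6105, 0.6108, 0.6109]  # made-up validation scores
for epoch, acc in enumerate(val_accs):
    should_stop, msg = stopper.get_early_stop_advice()
    if should_stop:
        # prints '[Epoch 4] EarlyStop after 2 epochs: no better than 0.61'
        print('[Epoch {}] {}'.format(epoch, msg))
        break
    # ... one epoch of training would run here ...
    stopper.update(acc, epoch=epoch)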
5 changes: 5 additions & 0 deletions gluoncv/auto/estimators/yolo/default.py
@@ -74,6 +74,11 @@ class TrainCfg:
no_mixup_epochs : int = 20
# Use label smoothing.
label_smooth : bool = False
early_stop_patience : int = -1 # epochs with no improvement after which training is stopped early; a negative value disables patience-based stopping
early_stop_min_delta : float = 0.001 # metric improvements smaller than min_delta are ignored
# baseline value for the metric; epochs that do not improve on this baseline count toward the patience
early_stop_baseline : Union[float, int] = 0.0
early_stop_max_value : Union[float, int] = 1.0 # stop immediately once the metric exceeds this value


@dataclass