diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 3f37ad3ac591..dc983ea6f8d4 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -49,7 +49,8 @@ def _get_lr_scheduler(args, kv): steps = [epoch_size * (x - begin_epoch) for x in step_epochs if x - begin_epoch > 0] if steps: - return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor)) + return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor, + base_lr=args.lr)) else: return (lr, None) diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py index 963560d17853..51e0bfffbbcc 100644 --- a/python/mxnet/lr_scheduler.py +++ b/python/mxnet/lr_scheduler.py @@ -17,6 +17,7 @@ """Scheduling learning rate.""" import logging +from math import cos, pi class LRScheduler(object): """Base class of a learning rate scheduler. @@ -29,8 +30,31 @@ class LRScheduler(object): base_lr : float, optional The initial learning rate. """ - def __init__(self, base_lr=0.01): + def __init__(self, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'): self.base_lr = base_lr + assert isinstance(warmup_steps, int) + self.warmup_steps = warmup_steps + + self.warmup_final_lr = base_lr + self.warmup_begin_lr = warmup_begin_lr + if self.warmup_begin_lr > self.warmup_final_lr: + raise ValueError("Base lr has to be higher than warmup_begin_lr") + if self.warmup_steps < 0: + raise ValueError("Warmup steps has to be positive or 0") + if warmup_mode not in ['linear', 'constant']: + raise ValueError("Supports only linear and constant modes of warmup") + self.warmup_mode = warmup_mode + + def get_warmup_lr(self, num_update): + assert num_update < self.warmup_steps + if self.warmup_mode == 'linear': + increase = (self.warmup_final_lr - self.warmup_begin_lr) \ + * float(num_update)/float(self.warmup_steps) + return self.warmup_begin_lr + increase + elif self.warmup_mode == 'constant': + return self.warmup_begin_lr + else: + raise ValueError("Invalid warmup mode %s"%self.warmup_mode) def __call__(self, num_update): """Return a new learning rate. @@ -66,8 +90,9 @@ class FactorScheduler(LRScheduler): stop_factor_lr : float, optional Stop updating the learning rate if it is less than this value. """ - def __init__(self, step, factor=1, stop_factor_lr=1e-8): - super(FactorScheduler, self).__init__() + def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01, + warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'): + super(FactorScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode) if step < 1: raise ValueError("Schedule step must be greater or equal than 1 round") if factor > 1.0: @@ -78,6 +103,9 @@ def __init__(self, step, factor=1, stop_factor_lr=1e-8): self.count = 0 def __call__(self, num_update): + if num_update < self.warmup_steps: + return self.get_warmup_lr(num_update) + # NOTE: use while rather than if (for continuing training via load_epoch) while num_update > self.count + self.step: self.count += self.step @@ -109,8 +137,10 @@ class MultiFactorScheduler(LRScheduler): factor: float The factor to change the learning rate. """ - def __init__(self, step, factor=1): - super(MultiFactorScheduler, self).__init__() + def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0, + warmup_mode='linear'): + super(MultiFactorScheduler, self).__init__(base_lr, warmup_steps, + warmup_begin_lr, warmup_mode) assert isinstance(step, list) and len(step) >= 1 for i, _step in enumerate(step): if i != 0 and step[i] <= step[i-1]: @@ -125,6 +155,9 @@ def __init__(self, step, factor=1): self.count = 0 def __call__(self, num_update): + if num_update < self.warmup_steps: + return self.get_warmup_lr(num_update) + # NOTE: use while rather than if (for continuing training via load_epoch) while self.cur_step_ind <= len(self.step)-1: if num_update > self.step[self.cur_step_ind]: @@ -138,33 +171,73 @@ def __call__(self, num_update): return self.base_lr class PolyScheduler(LRScheduler): + """ Reduce the learning rate according to a polynomial of given power. + + Calculate the new learning rate by:: + + final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr + if nup < max_nup, 0 otherwise. + + Parameters + ---------- + max_update: maximum number of updates before the decay reaches final learning rate. + base_lr: base learning rate to start from + pwr: power of the decay term as a function of the current number of updates. + final_lr: final learning rate after all steps + warmup_steps: number of warmup steps used before this scheduler starts decay + """ + + def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0, + warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'): + super(PolyScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode) + assert isinstance(max_update, int) + if max_update < 1: + raise ValueError("maximum number of updates must be strictly positive") + self.power = pwr + self.base_lr_orig = self.base_lr + self.max_update = max_update + self.final_lr = final_lr + self.max_steps = self.max_update - self.warmup_steps + + def __call__(self, num_update): + if num_update < self.warmup_steps: + return self.get_warmup_lr(num_update) + if num_update <= self.max_update: + self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \ + pow(1 - float(num_update - self.warmup_steps) / float(self.max_steps), self.power) + return self.base_lr + +class CosineScheduler(LRScheduler): """ Reduce the learning rate by given a list of steps. Calculate the new learning rate by:: - base_lr * (1-nup/max_nup)^pwr + final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2 if nup < max_nup, 0 otherwise. Parameters ---------- - max_update: maximum number of updates before the decay reaches 0. + max_update: maximum number of updates before the decay reaches 0 base_lr: base learning rate - pwr: power of the decay term as a funtion of the current number of updates. - + final_lr: final learning rate after all steps + warmup_steps: number of warmup steps used before this scheduler starts decay """ - def __init__(self, max_update, base_lr=0.01, pwr=2): - super(PolyScheduler, self).__init__(base_lr) + def __init__(self, max_update, base_lr=0.01, final_lr=0, + warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'): + super(CosineScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode) assert isinstance(max_update, int) if max_update < 1: raise ValueError("maximum number of updates must be strictly positive") - self.base_lr_orig = self.base_lr + self.base_lr_orig = base_lr self.max_update = max_update - self.power = pwr - self.base_lr = self.base_lr_orig + self.final_lr = final_lr + self.max_steps = self.max_update - self.warmup_steps def __call__(self, num_update): + if num_update < self.warmup_steps: + return self.get_warmup_lr(num_update) if num_update <= self.max_update: - self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update), - self.power) + self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \ + (1 + cos(pi * (num_update - self.warmup_steps) / self.max_steps)) / 2 return self.base_lr diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index 90762f7620ff..b0658de6b690 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -1033,6 +1033,55 @@ def test_adagrad(): w_stype='row_sparse', g_stype='row_sparse') +def test_factor_scheduler(): + base_lr = 1 + step = 100 + factor = 0.1 + sched = mx.lr_scheduler.FactorScheduler(step, factor, stop_factor_lr=1e-4, base_lr=base_lr, + warmup_steps=20, warmup_begin_lr=0.1, warmup_mode='constant') + assert (sched(0) == 0.1) + np.testing.assert_almost_equal(sched(10), 0.1) + assert (sched(21) == base_lr), sched(21) + np.testing.assert_almost_equal(sched(101), base_lr * factor) + np.testing.assert_almost_equal(sched(201), base_lr * factor * factor) + np.testing.assert_almost_equal(sched(1000), 1e-4) + +def test_multifactor_scheduler(): + base_lr = 0.1 + steps = [15, 25] + factor = 0.1 + sched = mx.lr_scheduler.MultiFactorScheduler(steps, factor, base_lr=base_lr, + warmup_steps=10, warmup_begin_lr=0.05, warmup_mode='linear') + assert sched(0) == 0.05 + np.testing.assert_almost_equal(sched(5), 0.05 + (base_lr - 0.05)/2) + np.testing.assert_almost_equal(sched(15), base_lr) + np.testing.assert_almost_equal(sched(16), base_lr * factor) + np.testing.assert_almost_equal(sched(20), base_lr * factor) + np.testing.assert_almost_equal(sched(26), base_lr * factor * factor) + np.testing.assert_almost_equal(sched(100), base_lr * factor * factor) + +def test_poly_scheduler(): + base_lr = 3 + final_lr = 0 + steps = 1000 + poly_sched = mx.lr_scheduler.PolyScheduler(steps, base_lr=base_lr, pwr=2, final_lr=final_lr, + warmup_steps=100, warmup_begin_lr=0, warmup_mode='linear') + np.testing.assert_almost_equal(poly_sched(0), 0) + np.testing.assert_almost_equal(poly_sched(50), float(base_lr)/2) + np.testing.assert_almost_equal(poly_sched(100), base_lr) + assert (poly_sched(101) < poly_sched(100)) + assert (poly_sched(500) < 1.6) + np.testing.assert_almost_equal(poly_sched(steps), final_lr) + +def test_cosine_scheduler(): + # also tests case without warmup + base_lr = 3 + final_lr = 0.1 + steps = 1000 + cosine_sched = mx.lr_scheduler.CosineScheduler(steps, base_lr=base_lr, final_lr=final_lr) + np.testing.assert_almost_equal(cosine_sched(0), base_lr) + np.testing.assert_almost_equal(cosine_sched(steps), final_lr) + assert (cosine_sched(500) > 1.5) if __name__ == '__main__': import nose