[MXNET-535] Fix bugs in LR Schedulers and add warmup #11234
@@ -17,6 +17,7 @@
"""Scheduling learning rate."""
import logging
from math import cos, pi

class LRScheduler(object):
    """Base class of a learning rate scheduler.

@@ -29,8 +30,31 @@ class LRScheduler(object):
    base_lr : float, optional
        The initial learning rate.
    """
    def __init__(self, base_lr=0.01):
    def __init__(self, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
        self.base_lr = base_lr
        assert isinstance(warmup_steps, int)
        self.warmup_steps = warmup_steps

        self.warmup_final_lr = base_lr
        self.warmup_begin_lr = warmup_begin_lr
        if self.warmup_begin_lr > self.warmup_final_lr:
            raise ValueError("Base lr has to be higher than warmup_begin_lr")
        if self.warmup_steps < 0:
            raise ValueError("Warmup steps has to be positive or 0")
        if warmup_mode not in ['linear', 'constant']:
            raise ValueError("Supports only linear and constant modes of warmup")
        self.warmup_mode = warmup_mode

    def get_warmup_lr(self, num_update):
        assert num_update < self.warmup_steps
        if self.warmup_mode == 'linear':
            increase = (self.warmup_final_lr - self.warmup_begin_lr) \
                       * float(num_update)/float(self.warmup_steps)
            return self.warmup_begin_lr + increase
        elif self.warmup_mode == 'constant':
            return self.warmup_begin_lr
        else:
            raise ValueError("Invalid warmup mode %s"%self.warmup_mode)

    def __call__(self, num_update):
        """Return a new learning rate.

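Not part of the diff: a minimal standalone sketch of the linear warmup rule that get_warmup_lr implements above, with invented numbers, to show how the learning rate ramps from warmup_begin_lr to base_lr:

# Illustration only: reproduces the 'linear' branch of get_warmup_lr above.
warmup_begin_lr, base_lr, warmup_steps = 0.0, 0.1, 5

def warmup_lr(num_update):
    # LR rises linearly from warmup_begin_lr towards base_lr over warmup_steps updates
    increase = (base_lr - warmup_begin_lr) * float(num_update) / float(warmup_steps)
    return warmup_begin_lr + increase

print([round(warmup_lr(i), 3) for i in range(warmup_steps)])  # [0.0, 0.02, 0.04, 0.06, 0.08]
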
@@ -66,8 +90,9 @@ class FactorScheduler(LRScheduler):
    stop_factor_lr : float, optional
        Stop updating the learning rate if it is less than this value.
    """
    def __init__(self, step, factor=1, stop_factor_lr=1e-8):
        super(FactorScheduler, self).__init__()
    def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
        super(FactorScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
        if step < 1:
            raise ValueError("Schedule step must be greater or equal than 1 round")
        if factor > 1.0:

@@ -78,6 +103,9 @@ def __init__(self, step, factor=1, stop_factor_lr=1e-8):
        self.count = 0

    def __call__(self, num_update):
        if num_update < self.warmup_steps:
            return self.get_warmup_lr(num_update)

        # NOTE: use while rather than if (for continuing training via load_epoch)
        while num_update > self.count + self.step:
            self.count += self.step

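Again not part of the diff: a hypothetical usage sketch of the FactorScheduler signature shown above. It assumes a build that already includes this PR (the warmup arguments do not exist before it); the numbers are invented:

from mxnet.lr_scheduler import FactorScheduler  # assumes this PR's branch

sched = FactorScheduler(step=1000, factor=0.5, base_lr=0.1,
                        warmup_steps=200, warmup_begin_lr=0.01, warmup_mode='linear')
print(sched(100))   # ~0.055: halfway through the linear ramp from 0.01 to 0.1
print(sched(2500))  # 0.025: past two 1000-update boundaries, i.e. 0.1 * 0.5**2
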
@@ -109,8 +137,10 @@ class MultiFactorScheduler(LRScheduler):
    factor: float
        The factor to change the learning rate.
    """
    def __init__(self, step, factor=1):
        super(MultiFactorScheduler, self).__init__()
    def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0,
                 warmup_mode='linear'):
        super(MultiFactorScheduler, self).__init__(base_lr, warmup_steps,
                                                   warmup_begin_lr, warmup_mode)
        assert isinstance(step, list) and len(step) >= 1
        for i, _step in enumerate(step):
            if i != 0 and step[i] <= step[i-1]:

@@ -125,6 +155,9 @@ def __init__(self, step, factor=1):
        self.count = 0

    def __call__(self, num_update):
        if num_update < self.warmup_steps:
            return self.get_warmup_lr(num_update)

        # NOTE: use while rather than if (for continuing training via load_epoch)
        while self.cur_step_ind <= len(self.step)-1:
            if num_update > self.step[self.cur_step_ind]:

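A similar invented-numbers sketch for MultiFactorScheduler, whose step argument is a list of update boundaries rather than a fixed interval:

from mxnet.lr_scheduler import MultiFactorScheduler  # assumes this PR's branch

sched = MultiFactorScheduler(step=[3000, 6000, 9000], factor=0.1, base_lr=0.01,
                             warmup_steps=500, warmup_begin_lr=0.001)
# updates 0-499: linear ramp from 0.001 to 0.01; afterwards the LR is multiplied
# by 0.1 each time num_update crosses 3000, 6000 and 9000.
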
@@ -138,33 +171,73 @@ def __call__(self, num_update):
        return self.base_lr

class PolyScheduler(LRScheduler):
    """ Reduce the learning rate according to a polynomial of given power.

    Calculate the new learning rate by::

       final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr
       if nup < max_nup, 0 otherwise.

    Parameters
    ----------
    max_update: maximum number of updates before the decay reaches final learning rate.
    base_lr: base learning rate to start from
    pwr: power of the decay term as a function of the current number of updates.
    final_lr: final learning rate after all steps
    warmup_steps: number of warmup steps used before this scheduler starts decay
    """

    def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
        super(PolyScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
        assert isinstance(max_update, int)
        if max_update < 1:
            raise ValueError("maximum number of updates must be strictly positive")
        self.power = pwr
        self.base_lr_orig = self.base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.max_steps = self.max_update - self.warmup_steps

    def __call__(self, num_update):
        if num_update < self.warmup_steps:
            return self.get_warmup_lr(num_update)
        if num_update <= self.max_update:
            self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
                pow(1 - float(num_update - self.warmup_steps) / float(self.max_steps), self.power)
        return self.base_lr

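For reference (not part of the diff), a standalone sketch of the polynomial decay that PolyScheduler.__call__ computes once warmup is over, with made-up values:

base_lr, final_lr, power = 0.1, 0.0, 2
warmup_steps, max_update = 100, 1100
max_steps = max_update - warmup_steps

def poly_lr(num_update):
    # mirrors PolyScheduler.__call__ after the warmup phase
    frac = float(num_update - warmup_steps) / float(max_steps)
    return final_lr + (base_lr - final_lr) * pow(1 - frac, power)

print(poly_lr(100), poly_lr(600), poly_lr(1100))  # 0.1, 0.025, 0.0
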
class CosineScheduler(LRScheduler):
    """ Reduce the learning rate by given a list of steps.

    Calculate the new learning rate by::

       base_lr * (1-nup/max_nup)^pwr
       final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2
       if nup < max_nup, 0 otherwise.

    Parameters
    ----------
    max_update: maximum number of updates before the decay reaches 0.
    max_update: maximum number of updates before the decay reaches 0
    base_lr: base learning rate
    pwr: power of the decay term as a funtion of the current number of updates.

    final_lr: final learning rate after all steps
    warmup_steps: number of warmup steps used before this scheduler starts decay
    """

    def __init__(self, max_update, base_lr=0.01, pwr=2):

Review comment: don't remove base_lr, it will break API.

        super(PolyScheduler, self).__init__(base_lr)
    def __init__(self, max_update, base_lr=0.01, final_lr=0,

Review comment: why did you remove pwr? This is API breakage.
Reply: I've not removed it. Git is getting confused :/ It thinks I've changed PolyScheduler to CosineScheduler when in fact I've modified PolyScheduler and added a new CosineScheduler.
Reply: Please refer #11234 (comment)

                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
        super(CosineScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
        assert isinstance(max_update, int)
        if max_update < 1:
            raise ValueError("maximum number of updates must be strictly positive")
        self.base_lr_orig = self.base_lr
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.power = pwr
        self.base_lr = self.base_lr_orig
        self.final_lr = final_lr
        self.max_steps = self.max_update - self.warmup_steps

    def __call__(self, num_update):
        if num_update < self.warmup_steps:
            return self.get_warmup_lr(num_update)
        if num_update <= self.max_update:
            self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
                                                   self.power)
            self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
                (1 + cos(pi * (num_update - self.warmup_steps) / self.max_steps)) / 2
        return self.base_lr

Review comment: Do you mind adding documentation for warmup_begin_lr?
Reply: There was some for the inherited classes, but not for this base abstract class. Anyway, now added for all. Please check.
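
Finally, not part of the diff: an analogous sketch of the cosine decay computed by CosineScheduler.__call__ above, again with invented values:

from math import cos, pi

base_lr, final_lr = 0.1, 0.001
warmup_steps, max_update = 100, 1100
max_steps = max_update - warmup_steps

def cosine_lr(num_update):
    # mirrors CosineScheduler.__call__ once warmup is over
    frac = float(num_update - warmup_steps) / float(max_steps)
    return final_lr + (base_lr - final_lr) * (1 + cos(pi * frac)) / 2

print(cosine_lr(100), cosine_lr(600), cosine_lr(1100))  # ~0.1, ~0.0505, 0.001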