-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlr.py
30 lines (25 loc) · 1.11 KB
/
lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch.optim.lr_scheduler import _LRScheduler
class PolynomialDecayLR(_LRScheduler):
    """Linear warmup followed by polynomial decay down to a floor learning rate.

    The schedule is driven by the scheduler's ``_step_count``. The base class
    increments it on every ``step()`` call, including the one made during
    construction, so the first value is computed with ``_step_count == 1``.

    * ``_step_count <= warmup_updates``: ``lr * _step_count / warmup_updates``
    * ``_step_count >= tot_updates``: ``end_lr``
    * otherwise: ``(lr - end_lr) * (1 - progress) ** power + end_lr``, where
      ``progress = (_step_count - warmup_updates) / (tot_updates - warmup_updates)``

    Every parameter group receives the same value. The groups' own initial
    learning rates are ignored in favour of ``lr``.

    Args:
        optimizer: The wrapped optimizer.
        warmup_updates: Number of linear-warmup steps. 0 skips warmup.
        tot_updates: Step count at which the rate reaches ``end_lr``.
        lr: Peak learning rate, reached at the end of warmup.
        end_lr: Learning rate held once ``tot_updates`` is reached.
        power: Exponent of the polynomial decay (1.0 gives linear decay).
        last_epoch: Forwarded to the base scheduler.
        verbose: Forwarded to the base scheduler only when true. Recent
            PyTorch versions deprecate this argument and warn whenever it is
            passed, even as ``False``.
    """

    def __init__(
        self,
        optimizer,
        warmup_updates: int,
        tot_updates: int,
        lr: float,
        end_lr: float,
        power: float,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        # These must be set before the base __init__, because it calls step()
        # (and therefore get_lr) immediately.
        self.warmup_updates = warmup_updates
        self.tot_updates = tot_updates
        self.lr = lr
        self.end_lr = end_lr
        self.power = power
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Omitting the argument keeps the base default (False on older
            # torch) and avoids the deprecation warning on newer torch.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate to apply to each param group at the current step."""
        step = self._step_count
        if step <= self.warmup_updates:
            # Linear ramp. The factor is also stored on the instance; it is
            # only updated while warming up.
            self.warmup_factor = step / float(self.warmup_updates)
            lr = self.warmup_factor * self.lr
        elif step >= self.tot_updates:
            lr = self.end_lr
        else:
            # Reaching this branch implies warmup_updates < step < tot_updates,
            # so the denominator below is positive.
            warmup = self.warmup_updates
            pct_remaining = 1 - (step - warmup) / (self.tot_updates - warmup)
            lr = (self.lr - self.end_lr) * pct_remaining ** self.power + self.end_lr
        return [lr] * len(self.optimizer.param_groups)

    def _get_closed_form_lr(self):
        # In PyTorch's LRScheduler this hook is consulted when step() is called
        # with an explicit epoch. This schedule does not support that path.
        # Raise explicitly instead of `assert False`, which `python -O` strips.
        raise NotImplementedError(
            "PolynomialDecayLR does not support step(epoch); call step() without arguments."
        )