from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""PyTorch implementation of the Lookahead optimizer wrapper.

    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="none"):
        """optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        pullback_momentum (str): how to adjust the inner optimizer's momentum buffer on the
            interpolation update; one of "reset", "pullback", or "none".
        """
        self.optimizer = optimizer
        self._la_step = 0  # counts inner optimizer steps since the last lookahead update
        self.la_alpha = la_alpha
        self._total_la_steps = la_steps
        pullback_momentum = pullback_momentum.lower()
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum
        self.state = defaultdict(dict)

        # Cache the current optimizer parameters; these act as the slow weights.
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)
    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.la_alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_la_step(self):
        return self._la_step

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']
    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self._la_step += 1

        if self._la_step >= self._total_la_steps:
            self._la_step = 0
            # Lookahead and cache the current optimizer parameters
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
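                    # Pull the fast weights toward the cached slow weights:
                    # p <- la_alpha * p + (1 - la_alpha) * slow, i.e. slow + la_alpha * (fast - slow)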
                    p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)

                    if self.pullback_momentum == "pullback":
                        # Interpolate the inner optimizer's momentum buffer toward the cached slow momentum.
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.la_alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss
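

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it wraps a standard
# torch.optim.SGD optimizer with Lookahead and shows how the backup/restore
# helpers can be used to evaluate on the slow weights. The model, data, and
# hyperparameters below are illustrative placeholders only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 1)  # placeholder model
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base_optimizer, la_steps=5, la_alpha=0.8,
                          pullback_momentum="pullback")
    criterion = nn.MSELoss()

    for _ in range(20):  # toy training loop on random data
        x, y = torch.randn(32, 10), torch.randn(32, 1)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()  # inner step, plus the lookahead update every la_steps steps

    # Evaluate on the slow (cached) weights, then restore the fast weights.
    optimizer._backup_and_load_cache()
    with torch.no_grad():
        eval_loss = criterion(model(torch.randn(32, 10)), torch.randn(32, 1))
    optimizer._clear_and_load_backup()
    print(f"eval loss on slow weights: {eval_loss.item():.4f}")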