LARC clipping+documentation (#6)
* Proper implementation of LARC clipping
* Documentation of LARC class
* Modification of FP16_Optimizer to absorb the optimizer instance being wrapped instead of creating a new optimizer instance of the same class.
raulpuric authored and mcarilli committed Jul 3, 2018
1 parent 3458238 commit 88effd5
Showing 2 changed files with 51 additions and 7 deletions.
2 changes: 1 addition & 1 deletion apex/fp16_utils/fp16_optimizer.py
@@ -184,7 +184,7 @@ def __init__(self,
             self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
 
-        self.optimizer = init_optimizer.__class__(init_optimizer.param_groups)
+        self.optimizer = optimizer
 
         if dynamic_loss_scale:
             self.dynamic_loss_scale = True
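
The one-line change above makes `FP16_Optimizer` reuse the optimizer instance it is handed rather than constructing a fresh instance of the same class from its param_groups. Rebuilding via `__class__(param_groups)` only works when the wrapped object can be constructed from parameter groups alone, which is not the case for a wrapper such as `LARC` (below), whose constructor expects an optimizer. A minimal sketch of the two behaviors, assuming a toy model and `torch.optim.Adam` as the inner optimizer:

```python
# Sketch only (not apex code): contrast of the old and new FP16_Optimizer behavior.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
inner = torch.optim.Adam(model.parameters(), lr=1e-3)

# Old behavior: build a fresh optimizer of the same class from the param_groups.
# Fine for plain torch.optim classes, whose first argument accepts param group dicts.
rebuilt = inner.__class__(inner.param_groups)

# New behavior: simply keep the instance that was handed in.
absorbed = inner

# A wrapper such as LARC is constructed from an optimizer, not from param_groups,
# so `wrapped.__class__(wrapped.param_groups)` would not yield a working wrapper;
# absorbing the instance is what lets FP16_Optimizer(LARC(inner)) work as documented below.
```
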
56 changes: 50 additions & 6 deletions apex/parallel/LARC.py
@@ -4,11 +4,45 @@
 from torch.nn.parameter import Parameter
 
 class LARC(object):
-    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
+    """
+    :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,
+    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
+    local learning rate for each individual parameter. The algorithm is designed to improve
+    convergence of large batch training.
+
+    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
+
+    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
+    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
+
+    ```
+    model = ...
+    optim = torch.optim.Adam(model.parameters(), lr=...)
+    optim = LARC(optim)
+    ```
+
+    It can even be used in conjunction with apex.fp16_utils.FP16_Optimizer.
+
+    ```
+    model = ...
+    optim = torch.optim.Adam(model.parameters(), lr=...)
+    optim = LARC(optim)
+    optim = apex.fp16_utils.FP16_Optimizer(optim)
+    ```
+
+    Args:
+        optimizer: Pytorch optimizer to wrap and modify learning rate for.
+        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
+        clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
+        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
+    """
+
+    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
         self.param_groups = optimizer.param_groups
         self.optim = optimizer
         self.trust_coefficient = trust_coefficient
-        self.eps = epsilon
+        self.eps = eps
+        self.clip = clip
 
     def __getstate__(self):
         return self.optim.__getstate__()
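
Between the two hunks, a tiny numeric sketch of the `clip=True` and `clip=False` modes described in the docstring above; the norms and hyperparameters below are made-up values, not taken from the commit:

```python
# Hypothetical numbers to illustrate clip=True vs. clip=False for one parameter tensor.
trust_coefficient = 0.02
lr = 0.1            # the optimizer's learning rate for this param group
weight_decay = 0.0
eps = 1e-8

param_norm = 4.0    # ||w||
grad_norm = 0.5     # ||g||

# local (adaptive) lr as computed in step() below: eta * ||w|| / (||g|| + wd * ||w|| + eps)
local_lr = trust_coefficient * param_norm / (grad_norm + param_norm * weight_decay + eps)
print(local_lr)     # ~0.16

# clip=False (scaling): the gradient is multiplied by local_lr,
# so the effective step size becomes local_lr * lr = ~0.016.
scale = local_lr

# clip=True (clipping): the gradient is multiplied by min(local_lr / lr, 1),
# so the effective step size becomes min(local_lr, lr) = 0.1 here.
clip_scale = min(local_lr / lr, 1)
print(scale, clip_scale)   # ~0.16 1.0
```
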
@@ -43,10 +77,20 @@ def step(self):
                     if p.grad is None:
                         continue
                     param_norm = torch.norm(p.data)
-                    # calculate adaptive lr + weight decay
-                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
-                    p.grad.data += weight_decay * p.data
-                    p.grad.data *= self.trust_coefficient * adaptive_lr
+                    grad_norm = torch.norm(p.grad.data)
+
+                    if param_norm != 0 and grad_norm != 0:
+                        # calculate adaptive lr + weight decay
+                        adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)
+
+                        # clip learning rate for LARC
+                        if self.clip:
+                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
+                            adaptive_lr = min(adaptive_lr/group['lr'], 1)
+
+                        p.grad.data += weight_decay * p.data
+                        p.grad.data *= adaptive_lr
+
         self.optim.step()
         # return weight decay control to optimizer
         for i, group in enumerate(self.optim.param_groups):
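
Finally, a minimal single-training-step sketch that uses the wrapper as documented in the class docstring; the model, data, and SGD hyperparameters are placeholders, and gradients are zeroed on the inner optimizer to stay within the API visible in this diff:

```python
import torch
import torch.nn as nn
from apex.parallel.LARC import LARC   # module path as in this commit's diff

model = nn.Linear(16, 1)
base_optim = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
optim = LARC(base_optim, trust_coefficient=0.02, clip=True)

x = torch.randn(32, 16)
y = torch.randn(32, 1)

loss = nn.functional.mse_loss(model(x), y)
base_optim.zero_grad()   # zero grads on the inner optimizer
loss.backward()
optim.step()             # folds weight decay into each grad, rescales it by the (clipped)
                         # local lr, then calls SGD.step()
```
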
