-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #936 from scap3yvt/935-feature-add-the-ademamix-op…
…timizer Added the ademamix optimizer
- Loading branch information
Showing
4 changed files
with
223 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import math | ||
from typing import Callable, Iterable, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.optim import Optimizer | ||
|
||
|
||
class AdEMAMix(Optimizer): | ||
r"""Adapted from https://github.com/frgfm/Holocron/blob/main/holocron/optim/ademamix.py | ||
Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/pdf/2409.03137>`_. | ||
The estimation of momentums is described as follows, :math:`\forall t \geq 1`: | ||
.. math:: | ||
m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\ | ||
m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\ | ||
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon | ||
where :math:`g_t` is the gradient of :math:`\theta_t`, | ||
:math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients, | ||
:math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`. | ||
Then we correct their biases using: | ||
.. math:: | ||
\hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\ | ||
\hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t} | ||
And finally the update step is performed using the following rule: | ||
.. math:: | ||
\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon} | ||
where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value), | ||
:math:`\eta` is the learning rate, :math:`\alpha > 0` :math:`\epsilon > 0`. | ||
Args: | ||
params (iterable): iterable of parameters to optimize or dicts defining parameter groups | ||
lr (float, optional): learning rate | ||
betas (Tuple[float, float, float], optional): coefficients used for running averages (default: (0.9, 0.999, 0.9999)) | ||
alpha (float, optional): the exponential decay rate of the second moment estimates (default: 5.0) | ||
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) | ||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | ||
amsgrad (bool, optional): whether to use the AMSGrad variant (default: False) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
params: Iterable[torch.nn.Parameter], | ||
lr: float = 1e-3, | ||
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), | ||
alpha: float = 5.0, | ||
eps: float = 1e-8, | ||
weight_decay: float = 0.0, | ||
) -> None: | ||
assert lr >= 0.0, f"Invalid learning rate: {lr}" | ||
assert eps >= 0.0, f"Invalid epsilon value: {eps}" | ||
assert all( | ||
0.0 <= beta < 1.0 for beta in betas | ||
), f"Invalid beta parameters: {betas}" | ||
defaults = { | ||
"lr": lr, | ||
"betas": betas, | ||
"alpha": alpha, | ||
"eps": eps, | ||
"weight_decay": weight_decay, | ||
} | ||
super().__init__(params, defaults) | ||
|
||
@torch.no_grad() | ||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override] | ||
"""Performs a single optimization step. | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
with torch.enable_grad(): | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
params_with_grad = [] | ||
grads = [] | ||
exp_avgs = [] | ||
exp_avgs_slow = [] | ||
exp_avg_sqs = [] | ||
state_steps = [] | ||
|
||
for p in group["params"]: | ||
if p.grad is not None: | ||
params_with_grad.append(p) | ||
if p.grad.is_sparse: | ||
raise RuntimeError( | ||
f"{self.__class__.__name__} does not support sparse gradients" | ||
) | ||
grads.append(p.grad) | ||
|
||
state = self.state[p] | ||
# Lazy state initialization | ||
if len(state) == 0: | ||
state["step"] = 0 | ||
# Exponential moving average of gradient values | ||
state["exp_avg"] = torch.zeros_like( | ||
p, memory_format=torch.preserve_format | ||
) | ||
state["exp_avg_slow"] = torch.zeros_like( | ||
p, memory_format=torch.preserve_format | ||
) | ||
# Exponential moving average of squared gradient values | ||
state["exp_avg_sq"] = torch.zeros_like( | ||
p, memory_format=torch.preserve_format | ||
) | ||
|
||
exp_avgs.append(state["exp_avg"]) | ||
exp_avgs_slow.append(state["exp_avg_slow"]) | ||
exp_avg_sqs.append(state["exp_avg_sq"]) | ||
|
||
# update the steps for each param group update | ||
state["step"] += 1 | ||
# record the step after step update | ||
state_steps.append(state["step"]) | ||
|
||
beta1, beta2, beta3 = group["betas"] | ||
_update_ademamix( | ||
params_with_grad, | ||
grads, | ||
exp_avgs, | ||
exp_avgs_slow, | ||
exp_avg_sqs, | ||
state_steps, | ||
beta1, | ||
beta2, | ||
beta3, | ||
group["alpha"], | ||
group["lr"], | ||
group["weight_decay"], | ||
group["eps"], | ||
) | ||
return loss | ||
|
||
|
||
def _update_ademamix( | ||
params: List[Tensor], | ||
grads: List[Tensor], | ||
exp_avgs: List[Tensor], | ||
exp_avgs_slow: List[Tensor], | ||
exp_avg_sqs: List[Tensor], | ||
state_steps: List[int], | ||
beta1: float, | ||
beta2: float, | ||
beta3: float, | ||
alpha: float, | ||
lr: float, | ||
weight_decay: float, | ||
eps: float, | ||
) -> None: | ||
r"""Functional API that performs AdaBelief algorithm computation. | ||
See :class:`~holocron.optim.AdaBelief` for details. | ||
""" | ||
for i, param in enumerate(params): | ||
grad = grads[i] | ||
m1 = exp_avgs[i] | ||
m2 = exp_avgs_slow[i] | ||
nu = exp_avg_sqs[i] | ||
step = state_steps[i] | ||
|
||
bias_correction1 = 1 - beta1**step | ||
bias_correction2 = 1 - beta2**step | ||
|
||
if weight_decay != 0: | ||
grad = grad.add(param, alpha=weight_decay) | ||
|
||
# Decay the first and second moment running average coefficient | ||
m1.mul_(beta1).add_(grad, alpha=1 - beta1) | ||
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) | ||
m2.mul_(beta3).add_(grad, alpha=1 - beta3) | ||
|
||
denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps) | ||
|
||
param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr) | ||
|
||
|
||
def ademamix_wrapper(parameters: dict) -> torch.optim.Optimizer: | ||
""" | ||
Creates an AdEMAMix optimizer from the PyTorch `torch.optim` module using the input parameters. | ||
Args: | ||
parameters (dict): A dictionary containing the input parameters for the optimizer. | ||
Returns: | ||
torch.optim.Optimizer: An AdEMAMix optimizer. | ||
""" | ||
|
||
return AdEMAMix( | ||
params=parameters["model_parameters"], | ||
lr=parameters.get("learning_rate", 1e-3), | ||
betas=parameters.get("betas", (0.9, 0.999, 0.9999)), | ||
alpha=parameters.get("alpha", 5.0), | ||
eps=parameters.get("eps", 1e-8), | ||
weight_decay=parameters.get("weight_decay", 0.0), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters