Skip to content

Commit

Permalink
Merge pull request #936 from scap3yvt/935-feature-add-the-ademamix-op…
Browse files Browse the repository at this point in the history
…timizer

Added the ademamix optimizer
  • Loading branch information
sarthakpati authored Sep 11, 2024
2 parents b678958 + 1115527 commit f3cdff8
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 16 deletions.
3 changes: 3 additions & 0 deletions GANDLF/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from .wrap_monai import novograd_wrapper

from .ademamix import ademamix_wrapper

global_optimizer_dict = {
"sgd": sgd,
"asgd": asgd,
Expand All @@ -29,6 +31,7 @@
"radam": radam,
"novograd": novograd_wrapper,
"nadam": nadam,
"ademamix": ademamix_wrapper,
}


Expand Down
204 changes: 204 additions & 0 deletions GANDLF/optimizers/ademamix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import math
from typing import Callable, Iterable, List, Optional, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer


class AdEMAMix(Optimizer):
r"""Adapted from https://github.com/frgfm/Holocron/blob/main/holocron/optim/ademamix.py
Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/pdf/2409.03137>`_.
The estimation of momentums is described as follows, :math:`\forall t \geq 1`:
.. math::
m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\
m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon
where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`.
Then we correct their biases using:
.. math::
\hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\
\hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}
And finally the update step is performed using the following rule:
.. math::
\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon}
where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value),
:math:`\eta` is the learning rate, :math:`\alpha > 0` :math:`\epsilon > 0`.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate
betas (Tuple[float, float, float], optional): coefficients used for running averages (default: (0.9, 0.999, 0.9999))
alpha (float, optional): the exponential decay rate of the second moment estimates (default: 5.0)
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional): whether to use the AMSGrad variant (default: False)
"""

def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
eps: float = 1e-8,
weight_decay: float = 0.0,
) -> None:
assert lr >= 0.0, f"Invalid learning rate: {lr}"
assert eps >= 0.0, f"Invalid epsilon value: {eps}"
assert all(
0.0 <= beta < 1.0 for beta in betas
), f"Invalid beta parameters: {betas}"
defaults = {
"lr": lr,
"betas": betas,
"alpha": alpha,
"eps": eps,
"weight_decay": weight_decay,
}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avgs_slow = []
exp_avg_sqs = []
state_steps = []

for p in group["params"]:
if p.grad is not None:
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
f"{self.__class__.__name__} does not support sparse gradients"
)
grads.append(p.grad)

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["exp_avg_slow"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

exp_avgs.append(state["exp_avg"])
exp_avgs_slow.append(state["exp_avg_slow"])
exp_avg_sqs.append(state["exp_avg_sq"])

# update the steps for each param group update
state["step"] += 1
# record the step after step update
state_steps.append(state["step"])

beta1, beta2, beta3 = group["betas"]
_update_ademamix(
params_with_grad,
grads,
exp_avgs,
exp_avgs_slow,
exp_avg_sqs,
state_steps,
beta1,
beta2,
beta3,
group["alpha"],
group["lr"],
group["weight_decay"],
group["eps"],
)
return loss


def _update_ademamix(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avgs_slow: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[int],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
lr: float,
weight_decay: float,
eps: float,
) -> None:
r"""Functional API that performs AdaBelief algorithm computation.
See :class:`~holocron.optim.AdaBelief` for details.
"""
for i, param in enumerate(params):
grad = grads[i]
m1 = exp_avgs[i]
m2 = exp_avgs_slow[i]
nu = exp_avg_sqs[i]
step = state_steps[i]

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step

if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)

# Decay the first and second moment running average coefficient
m1.mul_(beta1).add_(grad, alpha=1 - beta1)
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m2.mul_(beta3).add_(grad, alpha=1 - beta3)

denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)

param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)


def ademamix_wrapper(parameters: dict) -> torch.optim.Optimizer:
"""
Creates an AdEMAMix optimizer from the PyTorch `torch.optim` module using the input parameters.
Args:
parameters (dict): A dictionary containing the input parameters for the optimizer.
Returns:
torch.optim.Optimizer: An AdEMAMix optimizer.
"""

return AdEMAMix(
params=parameters["model_parameters"],
lr=parameters.get("learning_rate", 1e-3),
betas=parameters.get("betas", (0.9, 0.999, 0.9999)),
alpha=parameters.get("alpha", 5.0),
eps=parameters.get("eps", 1e-8),
weight_decay=parameters.get("weight_decay", 0.0),
)
5 changes: 3 additions & 2 deletions GANDLF/optimizers/wrap_monai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import monai
from monai.optimizers import Novograd


def novograd_wrapper(parameters):
def novograd_wrapper(parameters: dict) -> monai.optimizers.Novograd:
return Novograd(
parameters["model_parameters"],
lr=parameters.get("learning_rate"),
lr=parameters.get("learning_rate", 1e-3),
betas=parameters["optimizer"].get("betas", (0.9, 0.999)),
eps=parameters["optimizer"].get("eps", 1e-8),
weight_decay=parameters["optimizer"].get("weight_decay", 3e-05),
Expand Down
27 changes: 13 additions & 14 deletions GANDLF/optimizers/wrap_torch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from torch.optim import (
SGD,
ASGD,
Expand All @@ -14,7 +15,7 @@
)


def sgd(parameters):
def sgd(parameters: dict) -> torch.optim.SGD:
"""
Creates a Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -26,7 +27,7 @@ def sgd(parameters):
"""
# Create the optimizer using the input parameters
optimizer = SGD(
return SGD(
parameters["model_parameters"],
lr=parameters.get("learning_rate"),
momentum=parameters["optimizer"].get("momentum", 0.99),
Expand All @@ -35,10 +36,8 @@ def sgd(parameters):
nesterov=parameters["optimizer"].get("nesterov", True),
)

return optimizer


def asgd(parameters):
def asgd(parameters: dict) -> torch.optim.ASGD:
"""
Creates an Averaged Stochastic Gradient Descent optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -60,7 +59,7 @@ def asgd(parameters):
)


def adam(parameters, opt_type="normal"):
def adam(parameters: dict, opt_type: str = "normal") -> torch.optim.Adam:
"""
Creates an Adam or AdamW optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand Down Expand Up @@ -91,7 +90,7 @@ def adam(parameters, opt_type="normal"):
)


def adamw(parameters):
def adamw(parameters: dict) -> torch.optim.AdamW:
"""
Creates an AdamW optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -105,7 +104,7 @@ def adamw(parameters):
return adam(parameters, opt_type="AdamW")


def adamax(parameters):
def adamax(parameters: dict) -> torch.optim.Adamax:
"""
Creates an Adamax optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand Down Expand Up @@ -141,7 +140,7 @@ def adamax(parameters):
# )


def rprop(parameters):
def rprop(parameters: dict) -> torch.optim.Rprop:
"""
Creates a Resilient Backpropagation optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -161,7 +160,7 @@ def rprop(parameters):
)


def adadelta(parameters):
def adadelta(parameters: dict) -> torch.optim.Adadelta:
"""
Creates an Adadelta optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -182,7 +181,7 @@ def adadelta(parameters):
)


def adagrad(parameters):
def adagrad(parameters: dict) -> torch.optim.Adagrad:
"""
Creates an Adagrad optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -204,7 +203,7 @@ def adagrad(parameters):
)


def rmsprop(parameters):
def rmsprop(parameters: dict) -> torch.optim.RMSprop:
"""
Creates an RMSprop optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -227,7 +226,7 @@ def rmsprop(parameters):
)


def radam(parameters):
def radam(parameters: dict) -> torch.optim.RAdam:
"""
Creates a RAdam optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand All @@ -248,7 +247,7 @@ def radam(parameters):
)


def nadam(parameters):
def nadam(parameters: dict) -> torch.optim.NAdam:
"""
Creates a NAdam optimizer from the PyTorch `torch.optim` module using the input parameters.
Expand Down

0 comments on commit f3cdff8

Please sign in to comment.