Add SophiaG optimizer #844

Open · wants to merge 18 commits into main
2 changes: 2 additions & 0 deletions src/fairseq2/optim/__init__.py
@@ -10,9 +10,11 @@
from fairseq2.optim.dynamic_loss_scaler import DynamicLossScaler as DynamicLossScaler
from fairseq2.optim.dynamic_loss_scaler import LossScaleResult as LossScaleResult
from fairseq2.optim.factory import AdamWConfig as AdamWConfig
from fairseq2.optim.factory import SophiaGConfig as SophiaGConfig
from fairseq2.optim.factory import create_adamw_optimizer as create_adamw_optimizer
from fairseq2.optim.factory import create_optimizer as create_optimizer
from fairseq2.optim.factory import optimizer_factories as optimizer_factories
from fairseq2.optim.factory import optimizer_factory as optimizer_factory
from fairseq2.optim.optimizer import AbstractOptimizer as AbstractOptimizer
from fairseq2.optim.optimizer import ParameterCollection as ParameterCollection
from fairseq2.optim.sophiag import SophiaG as SophiaG
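With these re-exports in place, both the optimizer and its configuration are importable from the package root alongside the existing factory helpers. A minimal sketch of the resulting import surface, using only names that appear in this diff:

# SophiaG and SophiaGConfig are the names added by this PR; create_optimizer was
# already exported and is what resolves the "sophiag" key used by the recipe preset.
from fairseq2.optim import SophiaG, SophiaGConfig, create_optimizer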
44 changes: 44 additions & 0 deletions src/fairseq2/optim/factory.py
@@ -15,6 +15,7 @@
from fairseq2.factory_registry import ConfigBoundFactoryRegistry
from fairseq2.optim.adamw import AdamW
from fairseq2.optim.optimizer import ParameterCollection
from fairseq2.optim.sophiag import SophiaG

optimizer_factories = ConfigBoundFactoryRegistry[[ParameterCollection], Optimizer]()

@@ -95,3 +96,46 @@ def create_adamw_optimizer(config: AdamWConfig, params: ParameterCollection) ->
impl=config.impl,
use_fp32=config.use_fp32,
)


@dataclass(kw_only=True)
class SophiaGConfig:
"""Holds the configuration of a :class:`SophiaG`."""

lr: float = 1e-4
"""The learning rate."""

    betas: tuple[float, float] = (0.965, 0.99)
    """The coefficients used for computing running averages of the gradient and
    its square."""

    rho: float = 0.04
    """The clipping threshold applied to the per-coordinate update."""

    k: int = 10
    """The number of optimizer steps between updates of the Hessian estimate."""

    weight_decay: float = 1e-1
    """The weight decay coefficient."""

    maximize: bool = False
    """If ``True``, maximizes the objective with respect to the parameters,
    instead of minimizing."""

    capturable: bool = False
    """If ``True``, it is safe to capture this instance in a CUDA graph."""


@optimizer_factory("sophiag")
def create_sophia_optimizer(
config: SophiaGConfig, params: ParameterCollection
) -> SophiaG:
return SophiaG(
params,
lr=config.lr,
betas=config.betas,
rho=config.rho,
k=config.k,
weight_decay=config.weight_decay,
maximize=config.maximize,
capturable=config.capturable,
)
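For reference, a minimal sketch of building the optimizer through the new config, using only names introduced above. The toy model and hyperparameter values are placeholders, and a plain parameter iterator is assumed to be an acceptable ParameterCollection, as it is for the existing AdamW factory. The optimizer_factory("sophiag") registration also makes the optimizer selectable by name through create_optimizer, which is how the recipe preset later in this PR picks it up.

from torch.nn import Linear

from fairseq2.optim.factory import SophiaGConfig, create_sophia_optimizer

# Placeholder model; any ParameterCollection accepted by fairseq2 works here.
model = Linear(16, 16)

config = SophiaGConfig(lr=2e-4, rho=0.04, k=10, weight_decay=0.1)

optimizer = create_sophia_optimizer(config, model.parameters())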
213 changes: 213 additions & 0 deletions src/fairseq2/optim/sophiag.py
@@ -0,0 +1,213 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Union

import torch
from torch import Tensor

from fairseq2.optim.optimizer import AbstractOptimizer, ParameterCollection


class SophiaG(AbstractOptimizer):
"""Represents a SophiaG optimizer."""

def __init__(
self,
params: ParameterCollection,
lr: float = 1e-4,
betas: tuple[float, float] = (0.965, 0.99),
k: int = 10,
rho: float = 0.04,
weight_decay: float = 1e-1,
*,
maximize: bool = False,
capturable: bool = False,
) -> None:
"""
:param params:
The parameters to optimize.
:param lr:
The learning rate.
        :param betas:
            The coefficients used for computing running averages of the gradient
            and its square.
        :param rho:
            The clipping threshold applied to the per-coordinate update.
        :param k:
            The number of optimizer steps between updates of the Hessian estimate.
        :param weight_decay:
            The weight decay coefficient.
        :param maximize:
            If ``True``, maximizes the objective with respect to the parameters,
            instead of minimizing.
:param capturable:
If ``True``, it is safe to capture this instance in a CUDA graph.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= rho:
raise ValueError(f"Invalid rho parameter: {rho}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay parameter: {weight_decay}")

defaults = {
"lr": lr,
"betas": betas,
"rho": rho,
"k": k,
"weight_decay": weight_decay,
"maximize": maximize,
"capturable": capturable,
"differentiable": False,
}

super().__init__(params, defaults)

    def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)

for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("capturable", False)

state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)

if not step_is_tensor:
for value in state_values:
value["step"] = torch.tensor(float(value["step"]))

    # `bs` is the batch size term of the Sophia update; it scales the Hessian
    # estimate in the denominator of the clipped step.
    def _do_step(self, bs: Union[int, Tensor] = 5120) -> None:
for group in self.param_groups:
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
hessian: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = group["betas"]

for p in group["params"]:
if p.grad is None:
continue

params_with_grad.append(p)

if p.grad.is_sparse:
raise RuntimeError("Hero does not support sparse gradients")

grads.append(p.grad)
state = self.state[p]

# State initialization.
if len(state) == 0:
state["step"] = (
torch.zeros((1,), dtype=torch.float, device=p.device)
if self.defaults["capturable"]
else torch.tensor(0.0)
)
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["hessian"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

if "hessian" not in state.keys():
state["hessian"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

# Hessian update.
if int(state["step"]) % group["k"] == 0:
state["hessian"].mul_(beta2).addcmul_(
p.grad, p.grad, value=1 - beta2
)

exp_avgs.append(state["exp_avg"])
state_steps.append(state["step"])
hessian.append(state["hessian"])

            # In capturable mode, `bs` must be a tensor on the same device as
            # the parameters.
            if self.defaults["capturable"] and params_with_grad:
                device = params_with_grad[0].device

                bs = torch.ones((1,), dtype=torch.float, device=device) * bs

sophiag(
params_with_grad,
grads,
exp_avgs,
hessian,
state_steps,
bs=bs,
beta1=beta1,
beta2=beta2,
rho=group["rho"],
lr=group["lr"],
weight_decay=group["weight_decay"],
maximize=group["maximize"],
capturable=group["capturable"],
)


def sophiag(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    hessian: list[Tensor],
    state_steps: list[Tensor],
capturable: bool = False,
*,
bs: Union[int, Tensor],
beta1: float,
beta2: float,
rho: float,
lr: Union[float, Tensor],
weight_decay: float,
maximize: bool,
) -> None:
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)

for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
hess = hessian[i]
step_t = state_steps[i]

if capturable:
assert param.is_cuda and step_t.is_cuda and bs.is_cuda # type: ignore[union-attr]

if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
hess = torch.view_as_real(hess)
param = torch.view_as_real(param)

step_t += 1

        # Perform decoupled weight decay.
param.mul_(1 - lr * weight_decay)

        # Update the exponential moving average of the gradient (first moment).
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

if capturable:
step_size = lr
step_size_neg = step_size.neg() # type: ignore[union-attr]

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
param.addcmul_(exp_avg.sign(), ratio, value=float(step_size_neg))
else:
step_size_neg = -lr

ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
param.addcmul_(exp_avg.sign(), ratio, value=float(step_size_neg))
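To make the update rule concrete: on every call, sophiag() applies decoupled weight decay (param *= 1 - lr * weight_decay), updates the gradient EMA, and takes the clipped step param -= lr * sign(exp_avg) * min(|exp_avg| / (rho * bs * hessian + 1e-15), 1); the diagonal Hessian estimate itself is refreshed from grad * grad only on every k-th step, and bs defaults to the batch-size term 5120 in _do_step(). Below is a self-contained usage sketch; the toy model, data, and hyperparameters are placeholders, and passing a plain parameter iterator is assumed to work (the unit test further down passes parameter groups instead).

import torch
from torch.nn import Linear, MSELoss

from fairseq2.optim.sophiag import SophiaG

model = Linear(8, 1)
loss_fn = MSELoss()

optimizer = SophiaG(model.parameters(), lr=1e-4, rho=0.04, k=10)

x, y = torch.randn(32, 8), torch.randn(32, 1)

for _ in range(20):
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    # step() is inherited from AbstractOptimizer and is expected to dispatch to
    # _do_step(); the Hessian estimate is refreshed on every k-th call.
    optimizer.step()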
13 changes: 12 additions & 1 deletion src/fairseq2/recipes/lm/instruction_finetune.py
@@ -38,7 +38,7 @@
)
from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
from fairseq2.optim import AdamWConfig, create_optimizer
from fairseq2.optim import AdamWConfig, SophiaGConfig, create_optimizer
from fairseq2.optim.lr_scheduler import CosineAnnealingLRConfig, create_lr_scheduler
from fairseq2.recipes.common_metrics import SequenceMetricBag
from fairseq2.recipes.evaluator import AbstractEvalUnit
@@ -281,6 +281,17 @@ def _llama2_70b_chat() -> InstructionFinetuneConfig:
return config


@instruction_finetune_preset("llama3_2_1b_instruct_sophiag")
def _llama3_2_1b_instruct_sophiag() -> InstructionFinetuneConfig:
config = InstructionFinetuneConfig()

config.model = "llama3_2_1b_instruct"
config.optimizer = "sophiag"
config.optimizer_config = SophiaGConfig()

return config


def load_instruction_finetuner(
config: InstructionFinetuneConfig, output_dir: Path
) -> Trainer[SequenceBatch]:
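For completeness, a hedged sketch of wiring up the same settings as the new preset programmatically, assuming load_instruction_finetuner can be called directly with a config as in the signature above. The output directory is a placeholder, and running the returned trainer is left to the recipe runner.

from pathlib import Path

from fairseq2.optim import SophiaGConfig
from fairseq2.recipes.lm.instruction_finetune import (
    InstructionFinetuneConfig,
    load_instruction_finetuner,
)

config = InstructionFinetuneConfig()
config.model = "llama3_2_1b_instruct"
config.optimizer = "sophiag"
config.optimizer_config = SophiaGConfig(lr=1e-4, weight_decay=0.1)

# Placeholder output directory for checkpoints and logs.
trainer = load_instruction_finetuner(config, Path("/tmp/sophiag_finetune"))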
68 changes: 68 additions & 0 deletions tests/unit/optim/test_sophiag.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import pytest
import torch
from torch import Tensor, tensor
from torch.nn import Conv2d, Module
from torch.nn.functional import relu

from fairseq2.optim import SophiaG
from fairseq2.utils.rng import temporary_manual_seed
from tests.common import assert_close, device


class SophiaGTestNet(Module):
def __init__(self) -> None:
super().__init__()

self.conv1 = Conv2d(4, 2, 1, device=device, dtype=torch.float32)
self.conv2 = Conv2d(2, 1, 1, device=device, dtype=torch.float32)

def forward(self, x: Tensor) -> Tensor:
return self.conv2(relu(self.conv1(x))) # type: ignore[no-any-return]


class TestSophiaG:
@pytest.mark.skipif(device.type != "cpu", reason="requires CPU")
def test_step_updates_params_correctly(self) -> None:
net = self.run_step()
torch.set_printoptions(precision=5)

weights = [
[
[[[0.11474]], [[-0.11898]], [[0.13711]], [[-0.02554]]],
[[[0.21359]], [[0.11903]], [[-0.05746]], [[-0.40423]]],
],
[0.11416, -0.44267],
[[[[0.09293]], [[0.04699]]]],
[-0.15549],
]

        expected = [tensor(w, device=device) for w in weights]

for p, weight in zip(net.parameters(), expected):
assert_close(p.data, weight)

def run_step(self) -> Module:
with temporary_manual_seed(2, device):
net = SophiaGTestNet()
x = torch.randn((2, 4, 12, 4), device=device, dtype=torch.float32)

optimizer = SophiaG(
params=[ # type: ignore[arg-type]
{"params": net.conv1.parameters()},
{"params": net.conv2.parameters()},
],
)

out = net(x).sum()
out.backward()
optimizer.step()

return net