Optimizer offloading through weight-only offload
hanzhi713 committed Dec 4, 2024
1 parent 1af2ba8 commit edb04df
Showing 6 changed files with 183 additions and 40 deletions.
7 changes: 4 additions & 3 deletions axlearn/common/factorized_rms_test.py
@@ -12,11 +12,12 @@
from axlearn.common import factorized_rms
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.optimizer_base import (
NestedOptStateSpec,
Nested,
OptParam,
OptStateSpec,
PartitionedGradientTransformation,
)
from axlearn.common.optimizers import OptStateSpec, with_partition_fn
from axlearn.common.optimizers import with_partition_fn
from axlearn.common.test_utils import TestCase
from axlearn.common.utils import PartitionSpec, flatten_items

@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

# The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
# factorization spec.
exp_partition: NestedOptStateSpec = exp.partition(param_specs)
exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
# Used for `count`.
count_spec = OptStateSpec(
dtype=jnp.int32,
8 changes: 3 additions & 5 deletions axlearn/common/optimizer_base.py
@@ -16,14 +16,13 @@
- weight_decay_scale: control the weight decay rate.
"""
import dataclasses
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Union

import optax
import typing_extensions

from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
from axlearn.common.utils import Tensor, TensorSpec
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.utils import Nested, Tensor, TensorSpec


@dataclasses.dataclass
@@ -66,8 +65,7 @@ def __call__(

# Specification of an optimizer state array.
OptStateSpec = TensorSpec
NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


class PartitionedGradientTransformation(NamedTuple):
125 changes: 115 additions & 10 deletions axlearn/common/optimizers.py
@@ -36,10 +36,11 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
@@ -51,8 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
MemoryKind,
Nested,
NestedPartitionSpec,
NestedTensor,
NestedTree,
Tensor,
@@ -139,10 +140,24 @@ def update_fn(
return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def copy_partition(
param_specs: Nested[ParameterSpec], *, memory_kind: Optional[MemoryKind] = None
) -> Nested[OptStateSpec]:
"""Creates OptStateSpec from ParameterSpec with possibly a different memory kind.
Args:
param_specs: Nested[ParameterSpec] to copy from.
memory_kind: New memory kind. Default to None, which means the memory kind of `param_specs`
is kept in the result.
Returns:
A Nested[OptStateSpec] with possibly a different memory kind.
"""
return jax.tree.map(
lambda param_spec: OptStateSpec(
dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
dtype=param_spec.dtype,
shape=param_spec.shape,
mesh_axes=param_spec.mesh_axes,
memory_kind=memory_kind or param_spec.memory_kind,
),
param_specs,
)
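For illustration only (not part of this commit), here is a minimal sketch of how the new `memory_kind` argument to `copy_partition` could be used. It reuses the `ParameterSpec` construction that appears in `optimizers_test.py` below; treat the concrete shape and mesh axes as assumptions.

```python
# Hypothetical usage sketch: copy a parameter spec's partitioning into an
# optimizer-state spec while overriding the memory kind.
from axlearn.common.base_layer import ParameterSpec
from axlearn.common.optimizers import copy_partition
from axlearn.common.utils import PartitionSpec

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
opt_spec = copy_partition(param_spec, memory_kind="pinned_host")
# opt_spec keeps the dtype/shape/mesh_axes of param_spec but is placed on the
# "pinned_host" memory kind, which is how offload_optimizer (below) builds its
# partition specs.
```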
@@ -151,7 +166,7 @@ def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def trace_partition(
base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.TraceState(trace=copy_partition(param_specs))

return with_partition_fn(base, partition_fn)
@@ -160,7 +175,7 @@ def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
state: optax.ScaleByAdamState = base.init({})

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.ScaleByAdamState(
count=OptStateSpec(
dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +965,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
)
return updates, new_state

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
# Store momentum in accumulator_dtype if it is set and p is not scalar.
if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1427,7 @@ def _is_valid_step(
drop_stats=new_drop_stats,
)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
if use_adaptive_drop_norm:
one = jnp.ones([], jnp.float32)
dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1586,7 @@ def update_fn(updates, state, params):
)
return updates, ParamEmaState(count=count_inc, ema=new_ema)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return ParamEmaState(
count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
ema=copy_partition(param_specs),
@@ -1617,7 +1632,7 @@ def update_fn(updates, state, params=None):
updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
return updates, ScaleByLionState(count=count_inc, mu=mu)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
mu_specs = param_specs
if mu_dtype is not None:
mu_specs = jax.tree.map(
@@ -1993,3 +2008,93 @@ def _update2(u: Tensor, param: OptParam):
partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
)
return named_chain(**tx)


def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
offload_src: Optional[MemoryKind] = "device",
offload_dst: Optional[MemoryKind] = "pinned_host",
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer to `offload_dst`.
Args:
optimizer: The optimizer to offload.
offload_src: Offload-from memory kind. Defaults to "device".
offload_dst: Offload-to memory kind. Defaults to "pinned_host".
Returns:
An optimizer whose state is on `offload_dst` and that performs the same computation as `optimizer`.
Raises:
ValueError: when the `update` function of the returned optimizer is called outside of a
jit context.
This function returns a new `PartitionedGradientTransformation` that
1. Puts all states of the wrapped optimizer on `offload_dst` through the partition function
during state initialization in the trainer.
2. Copies the states to `offload_src` before `optimizer.update` is called.
3. Copies the updated states to `offload_dst` after `optimizer.update` is called.
The .update function of the returned `PartitionedGradientTransformation` must be called within
a jit function.
Example usage:
```python
your_opt = adamw_optimizer(...)
offloaded_opt = offload_optimizer(your_opt)
```
Only wrap the optimizer that you actually want to offload with this function to avoid
unnecessary overhead. This is usually the optimizer that occupies the most HBM. For example,
when you have chained optimizers:
```python
# Recommended
chain([
some_preprocessing(...),
clip_by_global_norm(...),
offload_optimizer(adamw_decoupled_optimizer(...)),
])
# Not recommended
offload_optimizer(chain([
some_preprocessing(...),
clip_by_global_norm(...),
adamw_decoupled_optimizer(...),
]))
```
When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the entire
`skip_and_clip_by_global_norm` inside it. Do not wrap only the inner optimizer of
`skip_and_clip_by_global_norm`, or you will get errors. Correct example:
```
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
```
The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
optimizer state and the updated optimizer state using `jnp.where`, which doesn't support tensors on
`pinned_host` memory space.
"""
optimizer = maybe_instantiate(optimizer)
if offload_src is None or offload_dst is None:
raise ValueError(
"offload_src and offload_dst cannot be None when using optimizer offloading."
)

logging.info("Optimizer offloading enabled.")

def init_fn(params: NestedOptParam):
return optimizer.init(params)

def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
# TransferToMemoryKind lets us change the memory kind of tensors without specifying the full
# sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
# it's specified in the API signature. Reference:
# https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
state = jax.device_put(state, TransferToMemoryKind(offload_src))
updates, state = optimizer.update(updates, state, params)
state = jax.device_put(state, TransferToMemoryKind(offload_dst), donate=True)
return updates, state

def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return copy_partition(optimizer.partition(param_spec), memory_kind=offload_dst)

return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
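As an aside (not part of this commit), the transfer mechanism that `update_fn` relies on can be exercised in isolation. The sketch below is an assumption-laden illustration of `jax.device_put` with `TransferToMemoryKind` under `jit`; it presumes a JAX runtime where the "pinned_host" memory kind (host offloading) is available.

```python
# Standalone sketch of the device <-> pinned_host transfer used by update_fn.
import jax
import jax.numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind


@jax.jit
def round_trip(x):
    # Move x to pinned host memory, then back to device memory for compute.
    x_host = jax.device_put(x, TransferToMemoryKind("pinned_host"))
    x_device = jax.device_put(x_host, TransferToMemoryKind("device"))
    return x_device * 2.0


# The transfers are visible in the jaxpr, which is how the unit test below
# asserts that offloading is actually wired in.
print(jax.make_jaxpr(round_trip)(jnp.ones([4], dtype=jnp.float32)))
```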
59 changes: 45 additions & 14 deletions axlearn/common/optimizers_test.py
@@ -40,6 +40,7 @@
ema,
l2_regularizer,
lion_optimizer,
offload_optimizer,
opt_param_values,
param_ema,
per_param_scale_by_path,
@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

def _test_optimizer(self, optimizer):
params = OptParam(
value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
factorization_spec=None,
weight_decay_scale=1.0,
)
state = optimizer.init(params)
self._test_optimizer_helper(optimizer, True)
self._test_optimizer_helper(optimizer, False)

def _test_optimizer_helper(self, optimizer, offload):
if offload:
optimizer = offload_optimizer(optimizer)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)

def create_opt_params(x):
return jax.tree.map(
lambda y: OptParam(
value=y,
factorization_spec=None,
weight_decay_scale=1.0,
),
x,
)

state = optimizer.init(create_opt_params(params))

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
state_partition_spec = optimizer.partition(param_spec)
@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

jax.tree.map(check_partition_spec, state_partition_spec, state)

def compute_loss(x):
return -jax.nn.log_softmax(x)[1]
@jax.jit
def jit_fn(params, state):
def compute_loss(x):
return -jax.nn.log_softmax(x)[1]

loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
new_loss = compute_loss(updated_params)
params = create_opt_params(params)
loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
return loss, compute_loss(updated_params)

if offload:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
str(jax.make_jaxpr(jit_fn)(params, state)),
)
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

@parameterized.product(
@@ -788,14 +812,17 @@ def loss_fn(x):
config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
),
offload=(True, False),
)
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
clip = skip_and_clip_by_global_norm(
inner=_counter(),
drop_norm=drop_norm,
max_norm=max_norm,
grad_norm_ema_decay=0.99,
)
if offload:
clip = offload_optimizer(clip)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
state = clip.init(params)
init_ema = state.grad_norm_ema
@@ -821,7 +848,11 @@ def loss_fn(x):
else:
is_valid_step = drop_norm is None or g_norm < drop_norm

updates, state = clip.update(grads, state=state, params=params)
@jax.jit
def jit_fn(grads, state, params):
return clip.update(grads, state=state, params=params)

updates, state = jit_fn(grads, state, params)
if is_valid_step:
if max_norm is None or g_norm < max_norm:
np.testing.assert_allclose(updates, grads, atol=1e-6)
12 changes: 6 additions & 6 deletions axlearn/common/trainer.py
@@ -50,10 +50,10 @@
HybridMeshShape,
MeshShape,
Nested,
NestedPartitionSpec,
NestedTensor,
PartitionSpec,
Tensor,
TensorSpec,
count_model_params,
flatten_items,
match_regex_rules,
Expand All @@ -62,9 +62,9 @@


class TrainerState(NamedTuple):
prng_key: Union[Tensor, NestedPartitionSpec]
model: Union[NestedTensor, NestedPartitionSpec]
learner: Union[NestedTensor, NestedPartitionSpec]
prng_key: Union[Tensor, TensorSpec, jax.sharding.NamedSharding]
model: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]
learner: Union[NestedTensor, Nested[TensorSpec], Nested[jax.sharding.NamedSharding]]


# pylint: disable-next=too-many-instance-attributes
@@ -309,8 +309,8 @@ def __init__(
model=self._model_param_specs,
learner=self._learner_state_partition_specs,
)
self._trainer_state_partition_specs = jax.tree.map(
lambda spec: spec.mesh_axes, self._trainer_state_specs
self._trainer_state_partition_specs: TrainerState = jax.tree.map(
lambda spec: spec.sharding, self._trainer_state_specs
)
# Create evalers, which depend on model_param_partition_specs.
self._evalers = {}