Skip to content

Commit

Permalink
enable special remat for neuron
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin authored and dgourab-aws committed Jan 6, 2025
1 parent 3ae8f9f commit c720878
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 15 deletions.
10 changes: 7 additions & 3 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3956,15 +3956,19 @@ def policy(prim, *_, **params):
return policy


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
# Regex patterns for matching remat names
class RematRegexSavePatterns(enum.Enum):
QKV_PROJ = r".*\.?(k|q|v)_proj"
LINEAR1_X = r".*\.?linear1_[01]"
SELF_ATTENTION = ".*([qkvo]_proj|context)"
FEED_FORWARD = ".*linear[12]_.*"


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
offload_pattern: _SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
Expand Down
70 changes: 68 additions & 2 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from axlearn.common import attention, attention_bias, test_utils, utils
from axlearn.common.attention import (
FEED_FORWARD_SAVE_PATTERN,
BaseStackedTransformerLayer,
BaseTransformerLayer,
BottleNeckAdapterTransformerLayer,
Expand All @@ -58,6 +57,7 @@
PipelinedTransformerLayer,
QKVLinear,
QLinear,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer,
Expand Down Expand Up @@ -3420,7 +3420,7 @@ def f(x, layer_params):
jax.remat(
f,
policy=_save_and_offload_only_these_names_regex(
names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
Expand Down Expand Up @@ -3875,6 +3875,72 @@ def f(x, layer_params):
5,
)

def test_build_remat_spec_neuron(self):
model_dim, num_heads = 6, 2
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.feed_forward.hidden_dim = model_dim * 4
cfg.vlog = 5

layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

batch_size, tgt_len = 2, 5
rng = np.random.default_rng(seed=123)
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

def f(x, layer_params):
forward_outputs, _ = F(
layer,
inputs=dict(
data=x,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return forward_outputs

# Ignore type errors.
spec: Any = build_remat_spec(mock.MagicMock())

policy = (
config_for_function(_save_and_offload_only_these_names_regex)
.set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
RematRegexSavePatterns.ATTENTION_OUTPUT.value,
RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
)
.instantiate()
)

_, default_policy_backward = jax.linearize(
jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
jnp.asarray(target),
layer_params,
)
_, full_remat_backward = jax.linearize(
jax.remat(f),
jnp.asarray(target),
layer_params,
)

# Eliminated the remat of qkv_proj, o_proj and linear1_0 = 5 dots. This assumes
# FlashAttention is not enabled.
self.assertEqual(
str(full_remat_backward).count(" dot_general")
- str(default_policy_backward).count(" dot_general"),
5,
)


class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
Expand Down
11 changes: 2 additions & 9 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BaseQKVLinear,
MultiheadAttention,
RepeatedTransformerLayer,
StackedTransformerLayer,
TransformerLayer,
build_remat_spec,
set_double_shard_weights_config,
Expand Down Expand Up @@ -190,20 +191,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.
Only applied if the stack_cfg is a RepeatedTransformerLayer.
Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkptoing offloading.
Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
Expand Down Expand Up @@ -277,7 +270,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
Expand Down
35 changes: 34 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
FusedQKVLinear,
GroupedQueryAttention,
MultiheadAttention,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
_save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
Expand Down Expand Up @@ -85,7 +87,6 @@ class Version(enum.Enum):
Version.V3: 5e5,
}


# Mapping from Fuji versions to total number of tokens used in training.
TOTAL_TOKENS = {
Version.V1: {
Expand Down Expand Up @@ -417,6 +418,38 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
_save_and_offload_only_these_names_regex
).set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
RematRegexSavePatterns.RESIDUAL_ADD.value,
RematRegexSavePatterns.MLP_RESIDUAL.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
),
),
}
),
],
),
),
),
)
else:
Expand Down

0 comments on commit c720878

Please sign in to comment.