From c720878805e41b57961369bd1b08f92b312d8da3 Mon Sep 17 00:00:00 2001
From: Apoorv Gupta
Date: Tue, 3 Dec 2024 23:25:30 -0800
Subject: [PATCH] enable special remat for neuron

---
 axlearn/common/attention.py            | 10 ++--
 axlearn/common/attention_test.py       | 70 +++++++++++++++++++++++++-
 axlearn/experiments/text/gpt/common.py | 11 +---
 axlearn/experiments/text/gpt/fuji.py   | 35 ++++++++++++-
 4 files changed, 111 insertions(+), 15 deletions(-)

diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 37baf3d8b..09924b0d9 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -3956,15 +3956,19 @@ def policy(prim, *_, **params):
     return policy


-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*\.?(k|q|v)_proj"
+    LINEAR1_X = r".*\.?linear1_[01]"
+    SELF_ATTENTION = ".*([qkvo]_proj|context)"
+    FEED_FORWARD = ".*linear[12]_.*"


 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
     offload_pattern: _SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index 1e188ecc0..41d2513c8 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -41,7 +41,6 @@

 from axlearn.common import attention, attention_bias, test_utils, utils
 from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     BaseTransformerLayer,
     BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
     PipelinedTransformerLayer,
     QKVLinear,
     QLinear,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
     StackedTransformerLayer,
@@ -3420,7 +3420,7 @@ def f(x, layer_params):
             jax.remat(
                 f,
                 policy=_save_and_offload_only_these_names_regex(
-                    names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+                    names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
                     names_which_can_be_offloaded=None,
                     offload_src="device",
                     offload_dst="pinned_host",
@@ -3875,6 +3875,72 @@ def f(x, layer_params):
             5,
         )

+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(_save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                        RematRegexSavePatterns.ATTENTION_OUTPUT.value,
+                        RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Eliminated the remat of qkv_proj, o_proj and linear1_0 = 5 dots. This assumes
+        # FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            5,
+        )
+

 class TestStackModel(BaseLayer):
     """A dummy transformer stack."""
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb9234..c01789a39 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
     BaseQKVLinear,
     MultiheadAttention,
     RepeatedTransformerLayer,
+    StackedTransformerLayer,
     TransformerLayer,
     build_remat_spec,
     set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.

-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.

-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )

     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
@@ -277,7 +270,7 @@ def model_config(
     layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
+    if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
         update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..61a8ff073 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -24,8 +24,10 @@
     FusedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    _save_and_offload_only_these_names_regex,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -85,7 +87,6 @@ class Version(enum.Enum):
     Version.V3: 5e5,
 }

-
 # Mapping from Fuji versions to total number of tokens used in training.
 TOTAL_TOKENS = {
     Version.V1: {
@@ -417,6 +418,38 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            _save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved="|".join(
+                                                [
+                                                    RematRegexSavePatterns.QKV_PROJ.value,
+                                                    RematRegexSavePatterns.LINEAR1_X.value,
+                                                    RematRegexSavePatterns.RESIDUAL_ADD.value,
+                                                    RematRegexSavePatterns.MLP_RESIDUAL.value,
+                                                ]
+                                            ),
+                                            names_which_can_be_offloaded=None,
+                                            offload_src=None,
+                                            offload_dst=None,
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
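
Usage note (illustrative; not part of the diff above): the Neuron mesh rule builds its remat policy from _save_and_offload_only_these_names_regex and the RematRegexSavePatterns regexes added in this patch. Below is a minimal standalone sketch of the same construction, using only names that appear in the patch; the specific patterns chosen, prevent_cse=True, and the variable name example_remat_spec are assumptions for illustration rather than values any particular config must use.

# Illustrative sketch only: compose a save-only remat policy from the patterns above.
from axlearn.common.attention import (
    RematRegexSavePatterns,
    _save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function

# Keep (do not recompute) the Q/K/V projections and the first feed-forward linear;
# everything else is rematerialized in the backward pass. No host offloading.
example_remat_spec = RematSpec(
    prevent_cse=True,
    policy=config_for_function(_save_and_offload_only_these_names_regex).set(
        names_which_can_be_saved="|".join(
            [
                RematRegexSavePatterns.QKV_PROJ.value,
                RematRegexSavePatterns.LINEAR1_X.value,
            ]
        ),
        names_which_can_be_offloaded=None,
        offload_src=None,
        offload_dst=None,
    ),
)

The resulting spec is assigned to a layer config's remat_spec, which is what update_model_remat_config and the RematSpecModifier entry in this patch do.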