diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..2f72fbac0 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -15,6 +15,7 @@
 import itertools
 from typing import Any, Optional, Union
 
+import jax
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
 
 from axlearn.common import causal_lm, config
@@ -22,6 +23,7 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
@@ -174,6 +176,12 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+            )
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -192,6 +200,12 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+            )
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -287,6 +301,14 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+                (
+                    "neuron-(trn1|trn1n).32xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=8),
+                ),
             ),
         )
     elif model_size == "8B":
@@ -367,6 +389,10 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -417,12 +443,17 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
             ),
         )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
+    model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],
@@ -473,7 +504,9 @@ def model_config(
         ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256)
     if num_kv_heads:
         atten_cfg = GroupedQueryAttention.default_config()
-        atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads)
+        backend = jax.default_backend()
+        qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
+        atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
     else:
         atten_cfg = MultiheadAttention.default_config()
         atten_input_linear = FusedQKVLinear.default_config()
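
Note on the mesh rules above: `mesh_shape_from_axes(fsdp=-1, model=4)` leaves `fsdp` as a -1 wildcard
that is resolved against the device count when the mesh is built. A minimal sketch of that arithmetic
(a hypothetical helper, not axlearn's actual resolution code), assuming the 64 cores implied by the
`...48xlarge-64` instance suffix:

    # Hypothetical illustration of -1 wildcard resolution; the real logic
    # lives in axlearn's mesh utilities.
    import math

    def resolve_mesh_shape(shape: dict[str, int], num_devices: int) -> dict[str, int]:
        """Replace a single -1 axis with num_devices // product(explicit axes)."""
        explicit = math.prod(v for v in shape.values() if v != -1)
        return {k: num_devices // explicit if v == -1 else v for k, v in shape.items()}

    print(resolve_mesh_shape({"fsdp": -1, "model": 4}, num_devices=64))
    # -> {'fsdp': 16, 'model': 4}: 16-way FSDP x 4-way tensor parallelism

The `GroupedQKVLinear` and `StackedTransformerLayer` defaults selected on the `neuron` backend are
functionally equivalent to `FusedGroupedQKVLinear` and `RepeatedTransformerLayer`; the swap presumably
avoids the fused QKV projection and scan-over-layers patterns that the Neuron compiler handles poorly.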