
Commit 0e70af3: boilerplate

apoorvtintin committed Dec 11, 2024
1 parent c20387c
Showing 1 changed file with 34 additions and 1 deletion.

axlearn/experiments/text/gpt/fuji.py
@@ -15,13 +15,15 @@
 import itertools
 from typing import Any, Optional, Union
 
+import jax
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
 
 from axlearn.common import causal_lm, config
 from axlearn.common.attention import (
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RepeatedTransformerLayer,
@@ -174,6 +176,12 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+            )
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -192,6 +200,12 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+            )
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -287,6 +301,14 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
+                (
+                    "neuron-(trn1|trn1n).32xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=8),
+                ),
             ),
         )
     elif model_size == "8B":
@@ -367,6 +389,10 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -417,12 +443,17 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    mesh_shape_from_axes(fsdp=-1, model=4),
+                ),
             ),
         )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
+    model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
     trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
         max_step=trainer_kwargs["max_step"],
@@ -473,7 +504,9 @@ def model_config(
         ffn_dim = scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=256)
     if num_kv_heads:
         atten_cfg = GroupedQueryAttention.default_config()
-        atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads)
+        backend = jax.default_backend()
+        qkv_linear = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
+        atten_input_linear = qkv_linear.default_config().set(num_kv_heads=num_kv_heads)
     else:
         atten_cfg = MultiheadAttention.default_config()
         atten_input_linear = FusedQKVLinear.default_config()
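
Note: the hunks above register the same Trainium mesh rule under each model size. A mesh_rules entry is a (regex, mesh_shape) pair; at configuration time the trainer matches a selector string describing the hardware (e.g. "neuron-trn2.48xlarge-64") against each regex in order and uses the first hit, with an axis size of -1 absorbing all remaining devices. A minimal sketch of that first-match selection, where select_mesh_shape is a hypothetical stand-in for axlearn's internal rule resolution:

import re

# Hypothetical helper standing in for axlearn's internal rule resolution;
# it only illustrates the first-match-wins semantics of mesh_rules.
def select_mesh_shape(mesh_selector, mesh_rules, default):
    for regex, mesh_shape in mesh_rules:
        if re.fullmatch(regex, mesh_selector):
            return mesh_shape
    return default

mesh_rules = (
    # (regex over the hardware descriptor, mesh axis sizes)
    ("neuron-(trn2|trn2n).48xlarge-64", dict(fsdp=-1, model=4)),
    ("neuron-(trn1|trn1n).32xlarge-64", dict(fsdp=-1, model=8)),
)

# A 64-node trn2.48xlarge cluster matches the first rule: model parallelism
# spans 4 devices and fsdp (-1) absorbs all remaining devices.
assert select_mesh_shape(
    "neuron-trn2.48xlarge-64", mesh_rules, default=dict(data=-1, fsdp=8)
) == dict(fsdp=-1, model=4)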
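The model_config hunk makes the QKV projection backend-dependent: FusedGroupedQKVLinear fuses the query/key/value projections into a single matmul, while the unfused GroupedQKVLinear computes the same values as separate projections that compile more readily on Neuron. The stack_cfg default added in get_trainer_kwargs follows the same pattern, preferring StackedTransformerLayer over the repeated stack when targeting Neuron. A small sketch of the dispatch, assuming jax.default_backend() reports "neuron" under the Neuron JAX plugin (the class names are axlearn's; the head count is illustrative):

import jax
from axlearn.common.attention import FusedGroupedQKVLinear, GroupedQKVLinear

# jax.default_backend() reports the default device platform, e.g. "cpu",
# "gpu", "tpu" -- or "neuron" when the Neuron JAX plugin is active.
backend = jax.default_backend()

# Both layers compute the same projections; only the kernel layout differs:
# one fused matmul off-Neuron, separate Q/K/V matmuls on Neuron.
qkv_cls = FusedGroupedQKVLinear if backend != "neuron" else GroupedQKVLinear
atten_input_linear = qkv_cls.default_config().set(num_kv_heads=8)  # illustrative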
