diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index cc9798afd..d02ce5eb1 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -54,6 +54,7 @@ from axlearn.common.utils import ( NestedTensor, Tensor, + maybe_shard, partial_with_fn_metadata, with_sharding_constraint, ) @@ -331,6 +332,10 @@ class Config(BaseNormalizationLayer.Config): eps: float = 1e-8 # Cast input to this dtype for the 'forward' call. If None, do not cast. forward_dtype: Optional[jnp.dtype] = jnp.float32 + # If not None, how to partition input activation values. + input_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition output activation values. + output_partition_spec: Optional[tuple[Optional[str]]] = None def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config @@ -341,6 +346,7 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor: del paddings # paddings do not affect LayerNorm results cfg = self.config + x = maybe_shard(x, cfg.input_partition_spec) x_dtype = x.dtype if cfg.forward_dtype is not None: x = x.astype(cfg.forward_dtype) @@ -348,6 +354,7 @@ def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor: x = x * jax.lax.rsqrt(moment2 + cfg.eps) x = x.astype(x_dtype) x = x * self.parameters["scale"] + x = maybe_shard(x, cfg.output_partition_spec) return x @@ -780,6 +787,12 @@ class Config(BaseLayer.Config): num_embeddings: Required[int] = REQUIRED # Maximum number of embeddings in table. dim: Required[int] = REQUIRED # Embedding vector dimensionality. + # If not None, how to partition input activation values. + input_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition embedding table. + embedding_partition_spec: Optional[tuple[Optional[str]]] = None + # If not None, how to partition output activation values. + output_partition_spec: Optional[tuple[Optional[str]]] = None @classmethod def default_config(cls): @@ -814,8 +827,13 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: ) def forward(self, x: Tensor) -> Tensor: + cfg = self.config + x = maybe_shard(x, cfg.input_partition_spec) emb = self.parameters["weight"] - return emb[x] + emb = maybe_shard(emb, cfg.embedding_partition_spec) + activation = emb[x] + activation = maybe_shard(activation, cfg.output_partition_spec) + return activation def attend(self, x: Tensor) -> Tensor: """Apply query array 'x' to the embedding weight array. diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index c8b4c11f4..18e600851 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -18,6 +18,7 @@ from collections.abc import Sequence from functools import partial from typing import Optional, Union +from unittest import mock import jax.random import numpy as np @@ -507,6 +508,57 @@ def test_rms_norm(self): # The output_norm should be close to 2 * sqrt(dim). assert_allclose(output_norm, np.ones_like(output_norm) * 2.0 * math.sqrt(dim)) + @mock.patch("axlearn.common.utils.with_sharding_constraint") + def test_rms_norm_partition_specs_constraint(self, mock_with_sharding_constraint): + # Configure mock to return its input. + mock_with_sharding_constraint.side_effect = lambda x, *args: x + + dim = 6 + cfg = RMSNorm.default_config().set( + name="norm", + input_dim=dim, + input_partition_spec=("fsdp", "model", None), + output_partition_spec=("fsdp", None, None), + ) + layer: RMSNorm = cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 3, dim]) + + # Run forward pass. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + + # Verify with_sharding_constraint calls. + calls = mock_with_sharding_constraint.call_args_list + # Should be called twice - once for input, once for output. + self.assertEqual(len(calls), 2) + + # 1. Input tensor constraint. + input_spec = calls[0].args[1] + self.assertEqual(input_spec, ("fsdp", "model", None)) + self.assertEqual(calls[0].args[0].shape, (2, 3, dim)) + self.assertEqual(calls[0].args[0].dtype, jnp.float32) + np.testing.assert_array_equal(calls[0].args[0], inputs) + + # 2. Output tensor constraint. + output_spec = calls[1].args[1] + self.assertEqual(output_spec, ("fsdp", None, None)) + self.assertEqual(calls[1].args[0].shape, (2, 3, dim)) + self.assertEqual(calls[1].args[0].dtype, jnp.float32) + np.testing.assert_array_equal(calls[1].args[0], outputs) + def test_l2_norm(self): cfg = L2Norm.default_config().set(name="norm") layer: L2Norm = cfg.instantiate(parent=None) @@ -1207,6 +1259,59 @@ def test_embed_attend(self, seq_len, dim, num_embeddings, is_training): )[0] assert_allclose(jnp.dot(x, state["weight"].T), actual_attends) + @mock.patch("axlearn.common.utils.with_sharding_constraint") + def test_embed_partition_specs_constraint(self, mock_with_sharding_constraint): + # Configure mock to return its input. + mock_with_sharding_constraint.side_effect = lambda x, *args: x + + dim = 16 + num_embeddings = 100 + seq_len = 5 + rng = jax.random.PRNGKey(1) + + # Configure embedding with partition specs. + cfg = Embedding.default_config().set( + name="embed", + dim=dim, + num_embeddings=num_embeddings, + input_partition_spec=("fsdp", None), + output_partition_spec=("fsdp", "model"), + embedding_partition_spec=("model", "fsdp"), + ) + + # Instantiate embedding. + emb = cfg.instantiate(parent=None) + state = emb.initialize_parameters_recursively(rng) + + # Test lookup functionality. + ixs = jax.random.randint(rng, minval=0, maxval=num_embeddings, shape=(3, seq_len)) + actual_embeds, _ = module.functional(emb, rng, state=state, inputs=[ixs], is_training=True) + + # Verify with_sharding_constraint was called in correct order with proper specs. + calls = mock_with_sharding_constraint.call_args_list + self.assertEqual(len(calls), 3) + + # 1. Input activation constraint (indices tensor). + input_spec = calls[0].args[1] + self.assertEqual(input_spec, ("fsdp", None)) + self.assertEqual(calls[0].args[0].shape, (3, seq_len)) + self.assertEqual(calls[0].args[0].dtype, jnp.int32) + np.testing.assert_array_equal(calls[0].args[0], ixs) + + # 2. Embedding weight constraint. + weight_spec = calls[1].args[1] + self.assertEqual(weight_spec, ("model", "fsdp")) + self.assertEqual(calls[1].args[0].shape, (num_embeddings, dim)) + self.assertEqual(calls[1].args[0].dtype, jnp.float32) + np.testing.assert_array_equal(calls[1].args[0], state["weight"]) + + # 3. Output activation constraint (after lookup). + output_spec = calls[2].args[1] + self.assertEqual(output_spec, ("fsdp", "model")) + self.assertEqual(calls[2].args[0].shape, (3, seq_len, dim)) + self.assertEqual(calls[2].args[0].dtype, jnp.float32) + np.testing.assert_array_equal(calls[2].args[0], actual_embeds) + class BiasLayer(BaseLayer): """A test layer with bias.""" diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 1401337f8..13c5e30d4 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -444,6 +444,12 @@ def with_sharding_constraint(x, shardings): return jax.lax.with_sharding_constraint(x, shardings) +def maybe_shard(x: NestedTensor, partition_spec: Optional[PartitionSpec]) -> NestedTensor: + if partition_spec is None: + return x + return with_sharding_constraint(x, PartitionSpec(*partition_spec)) + + def replicate_to_local_data(x: NestedTensor) -> NestedTensor: """Replicates and converts Tensors in `x` to local DeviceArrays.