From fd6b372b36e617c511e1a19cd7718ff608347994 Mon Sep 17 00:00:00 2001 From: bailin_wang Date: Thu, 14 Nov 2024 14:53:57 +0800 Subject: [PATCH 1/3] add mamab2 --- axlearn/common/ssm.py | 1052 ++++++++++++++++- axlearn/common/ssm_kernels/ssd_kernels.py | 781 ++++++++++++ .../common/ssm_kernels/ssd_kernels_test.py | 389 ++++++ axlearn/common/ssm_test.py | 603 +++++++++- 4 files changed, 2822 insertions(+), 3 deletions(-) create mode 100644 axlearn/common/ssm_kernels/ssd_kernels.py create mode 100644 axlearn/common/ssm_kernels/ssd_kernels_test.py diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 7848058d9..25599d7c6 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -29,6 +29,7 @@ import jax import jax.ad_checkpoint +from einops import rearrange, repeat from jax import numpy as jnp from jax._src.mesh import thread_resources from jax.experimental.shard_map import shard_map @@ -43,10 +44,15 @@ ) from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class -from axlearn.common.layers import Conv1D, Linear, MultiLinear, RMSNorm +from axlearn.common.layers import Conv1D, GroupNorm, Linear, MultiLinear, NormType, RMSNorm from axlearn.common.module import Module from axlearn.common.param_init import FanAxes, Initializer, Shape, constant_initializer, uniform from axlearn.common.ssm_kernels.mamba_kernels import compute_mamba_scan +from axlearn.common.ssm_kernels.ssd_kernels import ( + ssd, + ssd_linear_scan_w_hidden_states, + ssd_linear_scan_w_timestep, +) from axlearn.common.utils import Nested, Tensor, with_sharding_constraint @@ -72,6 +78,7 @@ class Config(Initializer.Config): # Clamp dt projection's bias to at least this value. dt_init_floor: float = 1e-4 # One of 'random' or 'constant'. + # If 'constant', the projection matrix is initialized to a constant; otherwise, random. # pylint: disable=C0301 mode: str = "random" def initialize( @@ -1128,7 +1135,7 @@ class Config(BaseSSMLayer.Config): """Configures a Mamba block.""" norm: InstantiableConfig = RMSNorm.default_config() - mamba_layer: MambaMixerLayer.Config = MambaMixerLayer.default_config() + mamba_layer: BaseLayer.Config = MambaMixerLayer.default_config() residual_mode: BlockResidualMode = BlockResidualMode.FP32 def __init__(self, cfg: Config, *, parent: Module): @@ -1445,3 +1452,1044 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): for i in range(cfg.num_layers) ] super().__init__(cfg.set(layer=layers), parent=parent) + + +# Naming convention for Mamba2: +# * `SSD` is used to denote any kernel-specific parameters/functions (consistent with the kernel), +# * `Mamba2`` (where SSD is a sub-module) is used to denote layer-level parameters/functions. + + +class SSDdtBiasInitializer(Initializer): + """Initializes the bias of the dt projection in the SSD layer of Mamba2. + + The weight matrix of the dt projection is seperately constructed and initialized. + """ + + @config_class + class Config(Initializer.Config): + """Configures SSDdtBiasInitializer. + + The initialization is different from Mamba1 in that there is no low-rank parameterization. + and we only need to initialize the bias term. + + Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py. + """ + + # Initialization stddev is set to `dt_scale` * 1/sqrt{dt_rank} when random. + dt_scale: float = 1.0 + # Minimum value of the dt projection's bias after applying softplus. 
+ dt_min: float = 1e-3 + # Maximum value of the dt projection's bias after applying softplus. + dt_max: float = 1e-1 + # Clamp dt projection's bias to at least this value. + dt_init_floor: float = 1e-4 + + def initialize( + self, + name: str, + *, + prng_key: Tensor, + shape: Shape, + dtype: jnp.dtype, + axes: Optional[FanAxes] = None, + ) -> Tensor: + """Initializes the SSD dt projection bias following the official implementation.""" + if axes is not None: + raise ValueError("SSDdtBiasInitializer does not support FanAxes.") + cfg = self.config + assert 0 < cfg.dt_min < cfg.dt_max, "`dt_min` must be < `dt_max`." + dt = jnp.exp( + uniform(scale=1.0, dtype=dtype)(prng_key, shape) + * (math.log(cfg.dt_max) - math.log(cfg.dt_min)) + + math.log(cfg.dt_min) + ).astype( + dtype + ) # math.log may return float64, so we need to cast to dtype + dt = jnp.clip(dt, a_min=cfg.dt_init_floor) + # Get inverse of softplus. + inv_dt = dt + jnp.log(-jnp.expm1(-dt)) + return inv_dt + + +class SSDLLogAInitializer(Initializer): + """Initializes SSD's log-log A parameter, a = exp(-exp(llog_a)).""" + + @config_class + class Config(Initializer.Config): + """Configures SSDLLogAInitializer. + + Reference: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py. + """ + + # `A` will be initialized within the range of [a_min, a_max], usually not tuned. + a_min: int = 1 + a_max: int = 16 + + def initialize( + self, + name: str, + *, + prng_key: Tensor, + shape: Shape, + dtype: jnp.dtype, + axes: Optional[FanAxes] = None, + ) -> jnp.ndarray: + """Returns a [num_heads] shaped vector.""" + if axes is not None: + raise ValueError("SSDLLogAInitializer does not support FanAxes.") + + cfg = self.config + return jnp.log( + jax.random.uniform(prng_key, shape, dtype=dtype, minval=cfg.a_min, maxval=cfg.a_max) + ) + + +class BaseSSDRecurrence(BaseLayer): + """An abstract class representing a layer that computes the SSD recurrence.""" + + class Output(NamedTuple): + """Defines the output of the SSD recurrence.""" + + data: Tensor # [batch, num_heads, target_length, head_dim] + states: Tensor # [batch, num_heads, target_length, state_dim, head_dim] + + @config_class + class Config(BaseLayer.Config): + """Configures a BaseSSDRecurrence.""" + + output_mode: MambaRecurrenceOutputMode = MambaRecurrenceOutputMode.OUTPUTS + + def forward( + self, x: Tensor, *, log_a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> Output: + """Computes the Mamba2's SSD recurrence output given full-sequence inputs and parameters. + + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [num_heads] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [head_dim] + + Returns: + An instance of BaseSSDRecurrence.Output. + """ + raise NotImplementedError(type(self)) + + +class PallasSSDRecurrence(BaseSSDRecurrence): + """A layer that computes the Mamba2's SSD recurrence with a Pallas-based chunk-wise scan.""" + + @config_class + class Config(BaseSSDRecurrence.Config): + """Configures a PallasSSDRecurrence.""" + + mamba2_dim_to_partition_spec: dict[str, PartitionSpec] = { + "bhtd": PartitionSpec(None), + "bht": PartitionSpec(None), + } + + output_partition_spec: PartitionSpec = PartitionSpec(None) + + def forward( + self, x: Tensor, *, log_a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> BaseSSDRecurrence.Output: + """Computes Mamba2's SSD recurrence with a Pallas-based chunk-wise scan. 
+ + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [1, num_heads, 1] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [1, num_heads, 1, 1] + + Returns: + An BaseSSDRecurrence.Output instance, where .data is the same shape as x and .states is + None (no need to return hidden states during training). + + Unlike the Mamba recurrence, discretizations of parameters are not explicitly computed. + More specifically, \bar a (i.e., discretized a) is computed outside the kernel whereas + \bar b is computed implicitly via adding the delta term to the input + x -- \bar x = x * delta. + See the following line from the official repo for details - + https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/modules/ssd_minimal.py#L103 + + Note that `ssd` functions need to be wrapped, otherwise the following error will be raised: + ``NotImplementedError: Mosaic kernels cannot be automatically partitioned.`` + The current version of `ssd` function assumes that h0 is None, so there is no need to + provide its partition spec. + """ + cfg = self.config + + sharded_ssd = shard_map( + ssd, + mesh=thread_resources.env.physical_mesh, + in_specs=( + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bhtd"], + cfg.mamba2_dim_to_partition_spec["bht"], + ), + out_specs=cfg.output_partition_spec, + check_rep=False, + ) + # The kernel code `ssd_kernels.py` uses q/k/v notations, which corresponds to b/c/x. + x_bar = x * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + o = sharded_ssd(c, b, x_bar, loga_bar) + + o = o + d * x + return BaseSSDRecurrence.Output(data=o, states=None) + + +class LinearScanSSDRecurrence(BaseSSDRecurrence): + """A layer that computes the Mamba2's SSD recurrence with a Jax-based linear scan.""" + + def forward( + self, + x: Tensor, + *, + log_a: Tensor, + b: Tensor, + c: Tensor, + delta: Tensor, + d: Tensor, + time_step: Optional[Tensor] = None, + ) -> BaseSSDRecurrence.Output: + """Computes the Mamba2's SSD recurrence with a Jax-based linear scan. + + Args: + x: [batch_size, num_heads, seq_len, head_dim] + log_a: [1, num_heads, 1] + b: [batch_size, num_groups, seq_len, state_dim] + c: [batch_size, num_groups, seq_len, state_dim] + delta: [batch_size, num_heads, seq_len] + d: [1, num_heads, 1, 1] + time_step: [batch_size] or None + + Returns: + An BaseSSDRecurrence.Output instance, where .data is the same shape as x and .states is + the hidden states of shape [batch_size, num_heads, seq_len, state_dim, head_dim] for + the given time step if `time_step` is not None, otherwise the full hidden states of + shape [batch_size, num_heads, seq_len, state_dim, head_dim] is returned. + """ + # Same procedure as the pallas version above. + x_bar = x * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + + if time_step is None: + # Return the full hidden states. + o, states = ssd_linear_scan_w_hidden_states(c, b, x_bar, loga_bar) + else: + # Return the hidden states at the given time step. 
+ o, states = ssd_linear_scan_w_timestep(c, b, x_bar, loga_bar, time_step) + + o = o + d * x + return BaseSSDRecurrence.Output(data=o, states=states) + + +class Mamba2MixerLayer(BaseLayer): + """A layer that computes the Mamba2 recurrence over its input.""" + + @config_class + class Config(BaseLayer.Config): + """Configures a Mamba2MixerLayer.""" + + # `d_model` increases as models get larger. + input_dim: Required[int] = REQUIRED + # `d_state` typically in {64, 128} + state_dim: Required[int] = REQUIRED + # num_heads = input_dim // head_dim, head_dim is typically 128. + num_heads: Required[int] = REQUIRED + + # `G` in the paper, typically 8 + num_groups: Required[int] = REQUIRED + + # See sec 8.2 for the parameterization. More details (e.g., conv + # for bc projection) can be found in the following link: + # https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/modules/mamba2.py # pylint: disable=C0301 + + xz_proj: MultiLinear.Config = MultiLinear.default_config().set( + bias=False, + param_partition_spec=(None, None, "model"), + ) + + bc_proj: MultiLinear.Config = MultiLinear.default_config().set( + bias=False, + param_partition_spec=(None, None, "model"), + ) + # A causal convolution. The window defaults to 4, the same as mamba1. + x_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + b_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + c_conv: Conv1D.Config = Conv1D.default_config().set( + window=4, + bias=True, + param_partition_spec=(None, None, "model"), + ) + + # `dt_bias` is separately created and initialized. + dt_proj: Linear.Config = Linear.default_config().set( + bias=False, param_partition_spec=(None, "model") + ) + pre_out_proj_norm: InstantiableConfig = GroupNorm.default_config().set( + norm_type=NormType.RMSNORM, + norm_axes=-1, + ) + out_proj: Linear.Config = Linear.default_config().set( + bias=False, + param_partition_spec=("model", None), + ) + + expansion_factor: float = 2.0 + cache_dtype: Optional[jnp.dtype] = None + bc_norm: Optional[InstantiableConfig] = RMSNorm.default_config() + norm_eps: float = 1e-5 + norm_dtype: Optional[jnp.dtype] = None + + # The recurrence implementation to use for full-sequence inputs. + ssd_recurrence: BaseSSDRecurrence = PallasSSDRecurrence.default_config() + # The recurrence implementation to use for inference. + inference_mamba_recurrence: BaseSSDRecurrence = ( + LinearScanSSDRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + ) + ) + + class Mamba2Output(NamedTuple): + """Defines the output of the Mamba2MixerLayer.""" + + data: Tensor # [batch, num_heads, target_length, head_dim] + ssd_state: Tensor # [batch, num_heads, state_dim, head_dim] + + class SSDParameters(NamedTuple): + """Defines the parameters of the SSD recurrence.""" + + log_a: Tensor # [1, num_heads, 1] + b: Tensor # [batch_size, num_groups, seq_len, state_dim] + c: Tensor # [batch_size, num_groups, seq_len, state_dim] + delta: Tensor # [batch_size, num_heads, seq_len] + d: Tensor # [1, num_heads, 1, 1] + + # Cache used for internal inference, whereas Mamba2Output is external output. + class Mamba2Cache(NamedTuple): + """Defines the cache of the Mamba2MixerLayer for inference.""" + + # Naming is a bit different from Mamba1: conv_input -> conv_state. 
+ x_conv_state: Tensor # [batch_size, seq_len, inner_dim] + b_conv_state: Tensor # [batch_size, seq_len, state_dim * 2] + c_conv_state: Tensor # [batch_size, seq_len, state_dim * 2] + ssd_state: Tensor # [batch_size, num_heads, state_dim, head_dim] + time_step: Optional[Tensor] = None # [batch] + + @property + def inner_dim(self): + cfg = self.config + return int(cfg.input_dim * cfg.expansion_factor) + + @property + def head_dim(self): + cfg = self.config + return self.inner_dim // cfg.num_heads + + @property + def output_dim(self): + cfg = self.config + return cfg.input_dim + + @property + def group_dim(self): + cfg = self.config + return self.inner_dim // cfg.num_groups + + @property + def bc_state_dim(self): + cfg = self.config + return cfg.state_dim * cfg.num_groups + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + cfg = self.config + + self._add_child( + "xz_proj", + cfg.xz_proj.set( + input_dim=cfg.input_dim, + num_outputs=2, + output_dim=self.inner_dim, + bias=False, + ), + ) + self._add_child( + "bc_proj", + cfg.bc_proj.set( + input_dim=cfg.input_dim, + num_outputs=2, + output_dim=self.bc_state_dim, + bias=False, + ), + ) + self._add_child( + "x_conv", + cfg.x_conv.set( + padding=((cfg.x_conv.window - 1, 0),), # A causal convolution. + input_dim=self.inner_dim, + output_dim=self.inner_dim, + num_input_dim_groups=self.inner_dim, + ), + ) + self._add_child( + "b_conv", + cfg.b_conv.set( + padding=((cfg.b_conv.window - 1, 0),), # A causal convolution. + input_dim=self.bc_state_dim, + output_dim=self.bc_state_dim, + num_input_dim_groups=self.bc_state_dim, + ), + ) + self._add_child( + "c_conv", + cfg.c_conv.set( + padding=((cfg.c_conv.window - 1, 0),), # A causal convolution. + input_dim=self.bc_state_dim, + output_dim=self.bc_state_dim, + num_input_dim_groups=self.bc_state_dim, + ), + ) + + # b/c norm is analoguous to q/k norm in standard attention. + if cfg.bc_norm: + self._add_child( + "b_norm", + cfg.bc_norm.clone().set( + input_dim=cfg.state_dim, eps=cfg.norm_eps, forward_dtype=cfg.norm_dtype + ), + ) + self._add_child( + "c_norm", + cfg.bc_norm.clone().set( + input_dim=cfg.state_dim, eps=cfg.norm_eps, forward_dtype=cfg.norm_dtype + ), + ) + + self._add_child( + "dt_proj", + cfg.dt_proj.set( + input_dim=cfg.input_dim, + output_dim=cfg.num_heads, + bias=False, + ), + ) + self._add_child( + "pre_out_proj_norm", + cfg.pre_out_proj_norm.set( + input_dim=self.inner_dim, num_groups=cfg.num_groups, eps=cfg.norm_eps + ), + ) + self._add_child( + "out_proj", + cfg.out_proj.set( + input_dim=self.inner_dim, + output_dim=cfg.input_dim, + bias=False, + ), + ) + + self._add_child("recurrence", cfg.ssd_recurrence) + self._add_child( + "inference_recurrence", + cfg.inference_mamba_recurrence.set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + ), + ) + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + """Creates parameter specs. + + Returns: + A dict mapping `llog_a`, `dt_bias` and `d` to their respective ParameterSpecs. + """ + cfg = self.config + params = dict( + llog_a=ParameterSpec( + # Initialize with a shape that avoids expansion later. 
+ shape=(1, cfg.num_heads, 1), + mesh_axes=(None, "model", None), + initializer=SSDLLogAInitializer.default_config().instantiate(), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + dt_bias=ParameterSpec( + shape=(cfg.num_heads,), + mesh_axes=("model",), + initializer=SSDdtBiasInitializer.default_config().instantiate(), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + d=ParameterSpec( + # Initialize with a shape that avoids expansion later. + shape=(1, cfg.num_heads, 1, 1), + mesh_axes=(None, "model", None, None), + initializer=constant_initializer(1.0), + dtype=cfg.dtype, + weight_decay_scale=0.0, + ), + ) + return params + + def _project_input(self, inputs: Tensor) -> tuple[Tensor, Tensor]: + """Projects inputs into tensors with dimension inner_dim. + + Args: + inputs: [batch_size, seq_len, input_dim] + + Returns: + x, z of the same size [batch_size, seq_len, inner_dim] + """ + xz = self.xz_proj(inputs) + x, z = jnp.split(xz, 2, axis=-2) # [batch_size, seq_len, 1, inner_dim] + return jnp.squeeze(x, axis=2), jnp.squeeze(z, axis=2) + + def _ssm_parameters( + self, + inputs: Tensor, + b_input: Optional[Tensor] = None, + c_input: Optional[Tensor] = None, + ) -> SSDParameters: + """Computes input-dependent SSD parameters. + + Args: + inputs: [batch_size, seq_len, inner_dim] + b_input: [batch_size, seq_len, bc_state_dim]. If b_input and c_input + are given, no need to compute bc_proj. + c_input: [batch_size, seq_len, bc_state_dim]. If b_input and c_input + are given, no need to compute bc_proj. + + Exposing the computation of `b` and `c` is useful to keep track the conv1d input for + `b_conv` and `c_conv`. During training, `b_input` and `c_input` should be None as + they represent the results after short conv. During inference, they represent the + input of short conv. + + Returns: + An instance of SSMParameters. + + Raises: + ValueError: If only one of b_input and c_input is provided. + + TODO (bailin-wang): merge b_conv and c_conv for better efficiency. + """ + cfg = self.config + if (b_input is None) != (c_input is None): + raise ValueError("Either both or none of b_input and c_input should be provided.") + + if b_input is None or c_input is None: + bc = self.bc_proj(inputs) # [batch_size, seq_len, 2, bc_state_dim] + bc = rearrange(bc, "b s n d -> b s (n d)") + b, c = jnp.split(bc, 2, axis=-1) + else: + b = b_input + c = c_input + + b = jax.nn.silu(self.b_conv(b)) + c = jax.nn.silu(self.c_conv(c)) + + b = rearrange(b, "b s (g d) -> b g s d", d=cfg.state_dim) + c = rearrange(c, "b s (g d) -> b g s d", d=cfg.state_dim) + + if "b_norm" in self.children and "c_norm" in self.children: + b = self.b_norm(b) + c = self.c_norm(c) + + # `dt` is in float32 for better precision of softplus for the delta term which later will + # be combined with float32 `log_a`. See also the following link: + # https://github.com/state-spaces/mamba/blob/6b72c122713bb769cc82c6b8e6d019c53d27d6a1/mamba_ssm/ops/triton/ssd_combined.py#L603. 
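+        # For reference, the overall discretization (completed inside the recurrence layer) is
+        # roughly: delta = softplus(dt_proj(inputs) + dt_bias), loga_bar = log_a * delta,
+        # x_bar = x * delta. Only delta and log_a are produced here.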
+ dt = self.dt_proj(inputs) + jnp.expand_dims( + _at_least_float32(self.parameters["dt_bias"]), axis=(0, 1) + ) + delta = jax.nn.softplus(dt) # [batch_size, seq_len, num_heads] + delta = rearrange(delta, "b s h -> b h s") # [batch_size, num_heads, seq_len] + + log_a = -jnp.exp( + _at_least_float32(self.parameters["llog_a"]) + ) # a = exp(-exp(llog_a)), log_a = -exp(llog_a * delta) + + return Mamba2MixerLayer.SSDParameters( + log_a=log_a, b=b, c=c, delta=delta, d=self.parameters["d"] + ) + + def _output_from_states(self, inputs: Tensor, *, z: Tensor) -> Tensor: + """Projects recurrence output back to input dimension. + + Args: + inputs: [batch_size, num_heads, seq_len, head_dim] + z: [batch_size, num_heads, seq_len, head_dim] + + Returns: + A tensor of shape [batch_size, seq_len, input_dim] + + Note that the num_heads/num_groups dim is contracted in the output. + """ + cfg = self.config + y = inputs * jax.nn.silu(z) + y_for_gnorm = rearrange(y, "b nh l d -> b l (nh d)", nh=cfg.num_heads) + y_for_proj = self.pre_out_proj_norm(y_for_gnorm) + return self.out_proj(y_for_proj) + + def forward(self, query: Tensor) -> Mamba2Output: + """Computes the Mamba2 recurrence over the provided inputs. + + Args: + query: [batch_size, input_length, input_dim] + + Returns: + A Mamba2Output instance where .data is the same shape as `inputs`. + """ + _, output = self._forward_for_mode(mode=ForwardMode.FORWARD, query=query) + return output + + def _forward_for_mode( + self, + *, + mode: ForwardMode, + query: Tensor, + cache: Optional[Mamba2Cache] = None, + ) -> tuple[Optional[Nested[Tensor]], Tensor]: + """Computes MambaMixerLayer outputs. + + Args: + mode: {FORWARD, INIT_STATES, EXTEND_STEP} + query: A Tensor of shape [batch_size, seq_len, input_dim] + cache: Optional NestedTensor as produced by `prefill_states`. + + Returns: + An optional cache, depending on `mode`. + A Mamba2Output instance, where .data is of the same shape as `inputs`. + + Raises: + ValueError: If `mode` is unsupported. + """ + self.vlog(3, "mamba2.input=%s", query.sum()) + if mode == ForwardMode.FORWARD: + mamba_cache, mamba_output = self._full_sequence_forward( + query, recurrence=self.recurrence + ) + elif mode == ForwardMode.INIT_STATES: + assert cache is not None + mamba_cache, mamba_output = self.prefill_states( + time_step=cache, + query=query, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cache is not None + mamba_cache, mamba_output = self.extend_step(cache, query) + else: + raise ValueError(f"Unrecognized mode {mode}.") + self.vlog(3, "mamba2.output=%s", mamba_output.data.sum()) + return dict(mamba_layer=mamba_cache), mamba_output + + def _full_sequence_forward( + self, inputs: Tensor, *, recurrence: BaseSSDRecurrence + ) -> tuple[Optional[Mamba2Cache], Mamba2Output]: + """Computes the Mamba2 layer output from a full sequence of inputs. + + Args: + inputs: A tensor of shape [batch_size, seq_len, input_dim]. + recurrence: A BaseMambaRecurrence to use for computing the recurrence. + + Returns: + An optional Mamba2Cache instance. Currently, it is always None. + A Mamba2Output instance. 
+ """ + cfg = self.config + + x, z = self._project_input(inputs) + x_conv = jax.nn.silu(self.x_conv(x)) + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + log_a, b, c, delta, d = self._ssm_parameters(inputs) + recurrence_output = recurrence(x_conv_w_head, log_a=log_a, b=b, c=c, delta=delta, d=d) + output = self._output_from_states(recurrence_output.data, z=z_w_head) + + ssd_state = recurrence_output.states + if ssd_state is not None: + ssd_state = ssd_state.astype(cfg.cache_dtype) + + mamba_cache = None + mamba_output = Mamba2MixerLayer.Mamba2Output(data=output, ssd_state=ssd_state) + return mamba_cache, mamba_output + + # pylint: disable=unused-argument + def init_states(self, *, target_batch_size: int, target_max_len: int) -> Mamba2Cache: + """Initializes cache for autoregressive cached decoding. + + Args: + batch_size: The batch size of the target to be decoded. + target_max_len: The maximum length of the target to be decoded. + + Returns: + A Mamba2Cache instance. + """ + cfg = self.config + dtype = cfg.cache_dtype or cfg.dtype + cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=jnp.zeros( + (target_batch_size, cfg.x_conv.window, self.inner_dim), dtype=dtype + ), + b_conv_state=jnp.zeros( + (target_batch_size, cfg.b_conv.window, self.bc_state_dim), dtype=dtype + ), + c_conv_state=jnp.zeros( + (target_batch_size, cfg.c_conv.window, self.bc_state_dim), dtype=dtype + ), + ssd_state=jnp.zeros( + (target_batch_size, cfg.num_heads, cfg.state_dim, self.head_dim), dtype=dtype + ), + time_step=jnp.zeros(target_batch_size, dtype=jnp.int32), + ) + return cache + + def prefill_states( + self, + *, + time_step: Tensor, + query: Tensor, + ) -> tuple[Mamba2Cache, Mamba2Output]: + """Initializes cache for autoregressive cached decoding. It refines the mamba state + returned from `_full_sequence_forward` to the state at `time_step` for the + incremental decoding later. + + Args: + time_step: A Tensor of shape [batch_size]. Each value is an index into the length + dimension indicating where decoding will start from. + query: Tensor of shape [batch_size, target_length, target_dim] corresponding to input + vector up to `time_step` indices. For batch index `i`, only + `inputs[i, :time_step[i], ...]` will affect subsequent decoding. + + Returns: + A Mamba2Cache instance containing updated convolution state, ssm state and time_step. + A Mamba2Output instance where .data is the same shape as query. + """ + cfg = self.config + cache_dtype = cfg.cache_dtype or cfg.dtype + + x, z = self._project_input(query) + x_conv = jax.nn.silu(self.x_conv(x)) + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + # Run `bc_proj` outside of `_ssm_parameters` so that we can keep track of the conv1d input. 
+ bc_input = self.bc_proj(query) # [batch_size, seq_len, 2, bc_state_dim] + bc_input = rearrange(bc_input, "b s n d -> b s (n d)") + b_input, c_input = jnp.split(bc_input, 2, axis=-1) + log_a, b, c, delta, d = self._ssm_parameters(query, b_input=b_input, c_input=c_input) + + recurrence_output = self.inference_recurrence( + x_conv_w_head, log_a=log_a, b=b, c=c, delta=delta, d=d, time_step=time_step + ) + output = self._output_from_states(recurrence_output.data, z=z_w_head) + mamba_output = Mamba2MixerLayer.Mamba2Output( + data=output, ssd_state=recurrence_output.states.astype(cache_dtype) + ) + + # Collect and refine conv states and ssd states. + x_conv_state = x + b_conv_state = b_input + c_conv_state = c_input + + # For the full sequence, always in float32, will be down-cast based on cache_dtype. + cont_ssd_state = recurrence_output.states.astype(cache_dtype) + + batch_size = query.shape[0] + batch_range = jnp.arange(batch_size) + + # Pad conv input so we can take the last window timesteps that precede time_step. + x_time_step_range = time_step[:, None] + jnp.arange(cfg.x_conv.window)[None, :] + padded_x_conv_state = jnp.pad( + x_conv_state, ((0, 0), (cfg.x_conv.window, 0), (0, 0)) + ) # [batch_size, target_length+window, input_dim] + cont_x_conv_state = padded_x_conv_state[batch_range[:, None], x_time_step_range] + + b_time_step_range = time_step[:, None] + jnp.arange(cfg.b_conv.window) + padded_b_conv_state = jnp.pad(b_conv_state, ((0, 0), (cfg.b_conv.window, 0), (0, 0))) + cont_b_conv_state = padded_b_conv_state[batch_range[:, None], b_time_step_range] + + c_time_step_range = time_step[:, None] + jnp.arange(cfg.c_conv.window) + padded_c_conv_state = jnp.pad(c_conv_state, ((0, 0), (cfg.c_conv.window, 0), (0, 0))) + cont_c_conv_state = padded_c_conv_state[batch_range[:, None], c_time_step_range] + + init_cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=cont_x_conv_state.astype(cache_dtype), + b_conv_state=cont_b_conv_state.astype(cache_dtype), + c_conv_state=cont_c_conv_state.astype(cache_dtype), + ssd_state=cont_ssd_state.astype(cache_dtype), + time_step=time_step, + ) + return init_cache, mamba_output + + def _single_step_conv_update( + self, + inputs: Tensor, + *, + conv_state: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor]: + """Updates cache of convolutional inputs and returns updated state. + + Args: + inputs: [batch_size, inner_dim] + conv_state: [batch_size, width, inner_dim] + weight: [width, 1, inner_dim] + bias: [inner_dim] + + Returns: + A tensor of shape [batch_size, inner_dim]. + A tensor of shape [batch_size, width, inner_dim], representing the new conv state. + """ + new_conv_state = jnp.roll(conv_state, shift=-1, axis=1) + new_conv_state = new_conv_state.at[:, -1].set(inputs) + + conv_output = jnp.sum( + new_conv_state * jnp.squeeze(_at_least_float32(weight), axis=1), axis=1 + ).astype( + inputs.dtype + ) # [batch_size, inner_dim] + if bias is not None: + conv_output = conv_output + bias + return conv_output, new_conv_state + + def _single_step_ssm_update( + self, + x: Tensor, + *, + ssm_state: Tensor, + log_a: Tensor, + b: Tensor, + c: Tensor, + d: Tensor, + delta: Tensor, + ) -> tuple[Tensor, Tensor]: + """Moves the SSM state forward by a single step. 
+ + Args: + x: [batch_size, num_heads, 1, head_dim] + ssm_state: [batch_size, num_heads, state_dim, head_dim] + log_a: [1, num_heads, 1], always float32 + b: [batch_size, num_groups, 1, state_dim] + c: [batch_size, num_groups, 1, state_dim] + delta: [batch_size, num_heads, 1], always float32 + d: [1, head_dim, 1, 1] + + Returns: + A tensor of shape [batch_size, num_heads, 1, head_dim] for the new output. + A tensor of shape [batch_size, num_heads, state_dim, head_dim] for the updated state. + """ + cfg = self.config + num_head_per_group = cfg.num_heads // cfg.num_groups + + orig_dtype = x.dtype + acc_dtype = cfg.cache_dtype or cfg.dtype + + # x: [batch_size, num_heads, head_dim] + # b and c: [batch_size, num_groups, state_dim] + # d: [batch_size, num_heads] + x, b, c, d = map(lambda x: jnp.squeeze(x, axis=2), (x, b, c, d)) + + # [batch_size, num_heads, state_dim] + b = repeat(b, "b ng d -> b (ng ngh) d", ngh=num_head_per_group) + c = repeat(c, "b ng d -> b (ng ngh) d", ngh=num_head_per_group) + + # [batch_size, num_heads, head_dim] + x_bar = x * delta + # [batch_size, num_heads, 1] + loga_bar = log_a * delta + # [batch_size, num_heads, 1] + a = jnp.exp(loga_bar) + # [batch_size, num_heads, state_dim, head_dim] + a = jnp.expand_dims(a, axis=-1) + + new_ssm_state = a * ssm_state + jnp.einsum("...i,...j->...ij", b, x_bar) + output = jnp.einsum("...ij,...i->...j", new_ssm_state, c) + d * x + + output = jnp.expand_dims(output.astype(orig_dtype), axis=2) + new_ssm_state = new_ssm_state.astype(acc_dtype) + return output, new_ssm_state + + def extend_step( + self, + cache: Mamba2Cache, + query: Tensor, + ) -> tuple[Mamba2Cache, Mamba2Output]: + """Computes the next state given the query of the current step. This function is used + in autoregressive decoding. + + Args: + cached_states: A Nested[Tensor] containing previous state of shape and index. + query: Tensor of shape [batch_size, 1, inner_dim] + + Returns: + A Mamba2Cache instance containing the convolution state, ssm state and time_step. + A Mamba2Output instance, where .data is the same shape as query. + """ + time_step: Tensor = cache.time_step + assert time_step.ndim == 1 + cfg = self.config + + x, z = self._project_input(query) + x_conv, new_x_conv_state = self._single_step_conv_update( + jnp.squeeze(x, axis=1), + conv_state=cache.x_conv_state, + weight=self.parameters["x_conv"]["weight"], + bias=self.parameters["x_conv"]["bias"], + ) + x_conv = jnp.expand_dims(jax.nn.silu(x_conv), axis=1) # [batch_size, 1, inner_dim] + x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) + z_w_head = rearrange(z, "b s (h d) -> b h s d", d=self.head_dim) + + # Obtain ssm parameters. 
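+        # Single-step counterpart of `_ssm_parameters`: bc_proj, a one-step conv update, silu,
+        # optional b/c norm, then softplus(dt_proj(query) + dt_bias) for delta.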
+ bc = self.bc_proj(query) # [batch_size, seq_len, 2, bc_state_dim] + bc = rearrange(bc, "b s n d -> b s (n d)") + b, c = jnp.split(bc, 2, axis=-1) + + b_conv, new_b_conv_state = self._single_step_conv_update( + jnp.squeeze(b, axis=1), + conv_state=cache.b_conv_state, + weight=self.parameters["b_conv"]["weight"], + bias=self.parameters["b_conv"]["bias"], + ) + b = jnp.expand_dims(jax.nn.silu(b_conv), axis=1) # [batch_size, 1, bc_inner_dim] + + c_conv, new_c_conv_state = self._single_step_conv_update( + jnp.squeeze(c, axis=1), + conv_state=cache.c_conv_state, + weight=self.parameters["c_conv"]["weight"], + bias=self.parameters["c_conv"]["bias"], + ) + c = jnp.expand_dims(jax.nn.silu(c_conv), axis=1) # [batch_size, 1, bc_inner_dim] + + b = rearrange(b, "b s (g d) -> b g s d", d=cfg.state_dim) + c = rearrange(c, "b s (g d) -> b g s d", d=cfg.state_dim) + + if cfg.bc_norm: + b = self.b_norm(b) + c = self.c_norm(c) + + dt = self.dt_proj(query) + jnp.expand_dims( + _at_least_float32(self.parameters["dt_bias"]), axis=(0, 1) + ) + delta = jax.nn.softplus(dt) # [batch_size, 1, num_heads] + delta = rearrange(delta, "b s h -> b h s") # [batch_size, num_heads, 1] + + log_a = -jnp.exp( + _at_least_float32(self.parameters["llog_a"]) + ) # a = exp(-exp(llog_a)), log_a = -exp(llog_a) + d = self.parameters["d"] + + y, new_ssd_state = self._single_step_ssm_update( + x_conv_w_head, + ssm_state=cache.ssd_state, + log_a=log_a, + b=b, + c=c, + d=d, + delta=delta, + ) + output = self._output_from_states(y, z=z_w_head) + + new_cache = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=new_x_conv_state, + b_conv_state=new_b_conv_state, + c_conv_state=new_c_conv_state, + ssd_state=new_ssd_state, + time_step=time_step + 1, + ) + mamba2output = Mamba2MixerLayer.Mamba2Output( + data=output, + ssd_state=new_ssd_state, + ) + return new_cache, mamba2output + + +class JambaMamba2Block(JambaMambaBlock): + """A JambaMamba2Block along with RMN norm and a feed-forward layer.""" + + @config_class + class Config(JambaMambaBlock.Config): + """Configures a JambaMamba2Block.""" + + num_heads: Required[int] = REQUIRED + num_groups: Required[int] = REQUIRED + + @classmethod + def default_config(cls) -> Config: + cfg = super().default_config() + cfg.mamba_layer = Mamba2MixerLayer.default_config() + return cfg + + def __init__(self, cfg: Config, *, parent: Module): + cfg.mamba_layer = cfg.mamba_layer.set(num_heads=cfg.num_heads, num_groups=cfg.num_groups) + super().__init__(cfg, parent=parent) + + +def set_double_shard_weights_config_mamba2( + cfg: Union[JambaMamba2Block.Config, Sequence[JambaMamba2Block.Config]], + *, + batch_axis_names: Union[str, Sequence[str]] = ("data", "expert", "fsdp"), + fsdp_axis_names: Union[str, Sequence[str]] = "fsdp", + tp_axis_names: Union[str, Sequence[str]] = "model", + seq_axis_names: Union[str, Sequence[str]] = "seq", +): + """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes. + + Args: + cfg: (A sequence of) Transformer layer config to apply sharding spec to. + batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors. + fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors. + tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors. + seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors. + """ + + def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): + # Shard weights. 
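+        # linear1 maps input_dim -> hidden_dim and linear2 maps hidden_dim -> input_dim, hence
+        # the mirrored (fsdp, tp) and (tp, fsdp) specs below.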
+ ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) + ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) + # Encourage the right activation sharding. + ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + + def set_mamba2_partition_specs(mamba_layer: Mamba2MixerLayer.Config): + mamba_layer.xz_proj.param_partition_spec = (fsdp_axis_names, None, tp_axis_names) + mamba_layer.bc_proj.param_partition_spec = (fsdp_axis_names, None, tp_axis_names) + mamba_layer.b_conv.param_partition_spec = (None, None, tp_axis_names) + mamba_layer.c_conv.param_partition_spec = (None, None, tp_axis_names) + mamba_layer.dt_proj.param_partition_spec = (fsdp_axis_names, tp_axis_names) + mamba_layer.out_proj.param_partition_spec = (tp_axis_names, fsdp_axis_names) + + mamba_layer.dt_proj.output_partition_spec = ( + batch_axis_names, + seq_axis_names, + tp_axis_names, + ) + mamba_layer.out_proj.output_partition_spec = ( + batch_axis_names, + seq_axis_names, + tp_axis_names, + ) + + if not isinstance(cfg, Sequence): + cfg = [cfg] + + for layer_cfg in cfg: + set_mamba2_partition_specs(layer_cfg.mamba_layer) + if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config): + set_ffn_partition_specs(layer_cfg.feed_forward) diff --git a/axlearn/common/ssm_kernels/ssd_kernels.py b/axlearn/common/ssm_kernels/ssd_kernels.py new file mode 100644 index 000000000..d05df3ee4 --- /dev/null +++ b/axlearn/common/ssm_kernels/ssd_kernels.py @@ -0,0 +1,781 @@ +# Copyright © 2024 Apple Inc. + +""" Pallas kernels for Mamba2 + +High-level idea: this kernel implements a two-level chunking algorithm to +balance memory consumption and running speed. Intuitively, we store chunk-level +hidden states to avoid recomputation, and subchunk-level states are recomputed based +on the chunk-level states. + + +Notations: + nb: number of chunks + ns: number of subchunks + bl: subchunk size + dkn: number of tiles in the dk dim + dvn: number of tiles in the dv dim + dk: state_dim (corresponds to dim of qk heads) + dv: head_dim (corresponds to dim of v heads) + +q/k/v is used as it's more intuitive than b/c/x of SSD in the orginal implementation, +see section 7.2 https://arxiv.org/pdf/2405.21060. Accordingly, dk/dv is used instead +of state_dim/head_dim. This notation is also used in linear attention models. +However, state_dim/head_dim is used in the model file to be consistent with Mamba1 +and the original implementation. + +""" + +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from einops import rearrange, repeat +from jax import lax +from jax._src.lax.control_flow import for_loop +from jax.experimental import pallas as pl + +from axlearn.common.utils import Tensor + + +def _matmul_fp32(lhs: Tensor, rhs: Tensor) -> Tensor: + """A wrapper around jax.lax.dot to conduct float32 matmul""" + return jax.lax.dot(lhs, rhs, precision="float32", preferred_element_type=jnp.float32) + + +@jax.custom_vjp +def _ssd(q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Tensor) -> Tensor: + """A differentiable function that computes the output of SSD. 
+ + Args: + q: [bs, num_heads, seq_len, dk] + k: [bs, num_heads, seq_len, dk] + v: [bs, num_heads, seq_len, dv] + log_alpha: [bs, num_heads, seq_len] + h0: [bs, num_heads, dk, dv] + + Returns: + o: [bs, num_heads, seq_len, dv] + """ + ( + o, + _, + ) = _ssd_forward(q, k, v, log_alpha, h0) + return o + + +def _ssd_forward_kernel( + q_ref: Tensor, + k_ref: Tensor, + v_ref: Tensor, + cum_log_alpha_ref: Tensor, + initial_state_ref: Tensor, + gamma_ref: Tensor, + mutable_ch_ref: Tensor, + mutable_final_state_ref: Tensor, + mutable_o_ref: Tensor, +): + """Forward kernel for SSD. + + Args: + q_ref: tensor reference of shape [ns, bl, singleton_dim] + k_ref: tensor reference of shape [ns, bl, singleton_dim] + v_ref: tensor reference of shape [ns, bl, singleton_dim] + cum_log_alpha_ref: tensor reference of shape [ns, bl] + initial_state_ref: tensor reference of shape [singleton_dim, singleton_dim] + gamma_ref: tensor reference of shape [ns, bl, singleton_dim] + + Output via mutable tensors: + mutable_ch_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + mutable_final_state_ref: tensor reference of shape [singleton_dim, singleton_dim] + mutable_o_ref: tensor reference of shape [ns, bl, singleton_dim] + + Note on intial_state and final_state: + * initial_state is at seq-level and not updated during the forward pass + * final_state is used to pass chunk-level states across different chunks + - it will be initialized to initial_state at the beginning of each chunk + - it will be updated after processing each chunk + - in the end, it will return as the seq-level final state + """ + subchunk_dim, subchunk_size = cum_log_alpha_ref.shape[0], cum_log_alpha_ref.shape[1] + casual_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0) + + # In our grid definition, axis 4 is the chunk index. + @pl.when(pl.program_id(axis=4) == 0) + def init_carry(): + mutable_final_state_ref[:, :] = initial_state_ref[:, :] + + def _ssd_forward_chunk_loop_body(t: int, h_carry_ref: Tensor): + subchunk_idx = t + prev_state = h_carry_ref[:, :] + + q_block = q_ref[subchunk_idx, :].astype(jnp.float32) + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + + # Notation mapping wrt. the paper: lambda -> Lambda, gamma -> gamma, beta -> Gamma. + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [bl, 1] + beta_block = ( + jnp.expand_dims(gamma_block, axis=0) - lambda_block + ) # [bl, singleton_dim] after broadcasting + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * casual_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + q_tilde_block = q_block * lambda_block + k_tilde_block = k_block * beta_block + + o_block_inter = _matmul_fp32(q_tilde_block, prev_state) + intra_att = _matmul_fp32(q_block, k_block.T) + attn_mask = casual_mask * ssd_mask_block + o_block_intra = _matmul_fp32((intra_att * attn_mask), v_block) + o_block = o_block_inter + o_block_intra + + cur_state = prev_state * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + k_tilde_block.T, v_block + ) # [d_k, d_v] + h_carry_ref[:, :] = cur_state + mutable_o_ref[subchunk_idx, :] = o_block.astype(mutable_o_ref.dtype) + + # Obtain final state from previous chunk. 
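+    # `mutable_final_state_ref` carries the running state across chunks; it is copied into
+    # `mutable_ch_ref` before the subchunk loop so the backward pass can restart from the state
+    # at the beginning of this chunk.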
+ h_carry = mutable_final_state_ref[:, :] + mutable_ch_ref[:, :] = mutable_final_state_ref[:, :] + final_state = for_loop.for_loop( + subchunk_dim, + _ssd_forward_chunk_loop_body, + h_carry, + ) + mutable_final_state_ref[:, :] = final_state + + +@jax.jit +def _ssd_forward( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, initial_state: Tensor +) -> Tuple: + """Forward pass for SSD. + + Args: + q, k: [bs, num_heads, seq_len, dk] + v: [bs, num_heads, seq_len, dv] + log_alpha: [bs, num_heads, seq_len] + initial_state: [singleton_dim, singleton_dim] + + Returns: + o: [bs, num_heads, seq_len, dv] + residuals: Tuple of tensors to be used in the backward + """ + bs, num_qk_heads, seq_len, k_head_dim = q.shape + _, num_v_heads, _, v_head_dim = v.shape + # TODO (bailin-wang): the following defaults works best for v5p, but they may not be optimal + # for others tpu types. We may need to expose them as arguments in the future. + singleton_dim = 128 + chunk_size, subchunk_size = 512, 64 + acc_dtype, orig_dtype = jnp.float32, q.dtype + + assert seq_len % chunk_size == 0 and chunk_size % subchunk_size == 0 + + assert num_v_heads % num_qk_heads == 0 + num_heads = num_v_heads + num_head_per_group = num_v_heads // num_qk_heads + + assert k_head_dim % singleton_dim == 0 + assert v_head_dim % singleton_dim == 0 + num_k_tiles = k_head_dim // singleton_dim + num_v_tiles = v_head_dim // singleton_dim + + # Add two extra dims for chunk-wise computation. + chunk_dim = seq_len // chunk_size + subchunk_dim = chunk_size // subchunk_size + + grid = (bs, num_heads, num_k_tiles, num_v_tiles, chunk_dim) + + # q/k/v tensors are kept in bf16 and converted later to fp32 in VMEM. + log_alpha = log_alpha.astype(jnp.float32) + initial_state = initial_state.astype(jnp.float32) + + # None is effectively 1, but the dim will be squeezed out. + qk_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + qk_spec = pl.BlockSpec( + lambda b, h, k, v, m: (b, lax.div(h, num_head_per_group), m, 0, k), qk_tiling + ) + v_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + v_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, v), v_tiling) + + alpha_tiling = (None, None, None, subchunk_dim, subchunk_size) + alpha_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, 0), alpha_tiling) + + # Initial hidden states. + is_tiling = (None, None, singleton_dim, singleton_dim) + is_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, v), is_tiling) + + # Chunk-wise states (not subchunk-wise states). + ch_tiling = (None, None, None, singleton_dim, singleton_dim) + ch_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, k, v), ch_tiling) + + # Chunk-wise final states help pass states from the previous chunk to the next. + fs_spec = is_spec + + ch_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, chunk_dim, k_head_dim, v_head_dim), dtype=acc_dtype + ) + fs_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, k_head_dim, v_head_dim), dtype=jnp.float32 + ) + + # Pre-compute the cumulative sum of log_alpha. + log_alpha = rearrange( + log_alpha, "b h (nb ns bl) -> b h nb ns bl", nb=chunk_dim, ns=subchunk_dim + ) + cum_log_alpha = jnp.cumsum(log_alpha, axis=-1) + + q = rearrange(q, "b h (nb bl) dk -> b h nb bl dk", bl=subchunk_size) + k = rearrange(k, "b h (nb bl) dk -> b h nb bl dk", bl=subchunk_size) + v = rearrange(v, "b h (nb bl) dv -> b h nb bl dv", bl=subchunk_size) + + # Pallas kernels operate on tiles of size at least [8, 128]. 
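+    # `gamma` below is the cumulative log-decay at the last position of each subchunk; it is
+    # repeated along a 128-wide dim only to satisfy the minimum tile width.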
+ gamma = cum_log_alpha[:, :, :, :, subchunk_size - 1 :] # [b, h, nb, ns, 1] + gamma_expanded = jnp.repeat(gamma, singleton_dim, axis=-1) # [b, h, nb, ns, singleton_dim] + gamma_tiling = (None, None, None, subchunk_dim, singleton_dim) + gamma_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, m, 0, 0), gamma_tiling) + + o_tiling = (None, None, None, subchunk_dim, subchunk_size, singleton_dim) + o_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, m, 0, v), o_tiling) + o_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, num_k_tiles, chunk_dim * subchunk_dim, subchunk_size, v_head_dim), + dtype=orig_dtype, + ) + + chunk_states, final_state, o = pl.pallas_call( + _ssd_forward_kernel, + in_specs=(qk_spec, qk_spec, v_spec, alpha_spec, is_spec, gamma_spec), + out_specs=(ch_spec, fs_spec, o_spec), + out_shape=(ch_shape, fs_shape, o_shape), + grid=grid, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary") + ) + ), + )(q, k, v, cum_log_alpha, initial_state, gamma_expanded) + + o = jnp.sum(o, axis=2) # sum over dkn dim + o = rearrange(o, "b h nb bl dv -> b h (nb bl) dv") + + # Input tensors q/k/v stored in the residual list for backward pass are reshaped, and + # cum_log_alpha and gamma are upcasted to float32. + final_state = final_state.astype(orig_dtype) + return o, (q, k, v, cum_log_alpha, gamma_expanded, chunk_states, final_state) + + +def _ssd_backward_kernel( + q_ref: Tensor, + k_ref: Tensor, + v_ref: Tensor, + cum_log_alpha_ref: Tensor, + gamma_ref: Tensor, + ch_ref: Tensor, + mutable_do_ref: Tensor, + mutable_dq_ref: Tensor, + mutable_dk_ref: Tensor, + mutable_dv_ref: Tensor, + mutable_dh_carry_ref: Tensor, +): + """Backward kernel for SSD. + + Args: + q_ref: tensor reference of shape [ns, bl, singleton_dim] + k_ref: tensor reference of shape [ns, bl, singleton_dim] + v_ref: tensor reference of shape [ns, bl, singleton_dim] + cum_log_alpha_ref: tensor reference of shape [ns, bl] + gamma_ref: tensor reference of shape [ns, bl, singleton_dim] + ch_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + + Output via mutable tensors: + mutable_do_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dq_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dk_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dv_ref: tensor reference of shape [ns, bl, singleton_dim] + mutable_dh_carry_ref: tensor reference of shape [ns, singleton_dim, singleton_dim] + + Note: similar to final_state in the forward pass, dh_carry is used to pass gradients wrt. + hidden states across different chunks. It will be initalized to zero at the last chunk. + The final gradient wrt. hidden states will be returned as the gradient wrt. initial_state. + """ + subchunk_dim, subchunk_size = cum_log_alpha_ref.shape[0], cum_log_alpha_ref.shape[1] + causal_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0).astype(jnp.float32) + + # In our grid definition, axis 4 is the chunk index. 
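+    # The backward index maps use `chunk_dim - 1 - m`, so chunks are visited in reverse order;
+    # program_id 0 therefore corresponds to the last chunk, where the gradient carry is zero.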
+ @pl.when(pl.program_id(axis=4) == 0) + def init_carry(): + mutable_dh_carry_ref[:, :] = jnp.zeros_like(mutable_dh_carry_ref, dtype=jnp.float32) + + def _ssd_backward_dq_chunk_loop_body(t: int, h_carry_ref: Tensor): + subchunk_idx = t + h_block = h_carry_ref[:, :] # final states from previous chunk + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + do_block = mutable_do_ref[subchunk_idx, :].astype(jnp.float32) + + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [nb, 1] + beta_block = gamma_block - lambda_block # [nb, d_k] + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * causal_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + k_tilde_block = k_block * beta_block + + attn_mask = causal_mask * ssd_mask_block + d_intra_att = _matmul_fp32(do_block, v_block.T) * attn_mask + + dq_tilde_block = _matmul_fp32(do_block, h_block.T) + dq_block_1 = dq_tilde_block * lambda_block + dq_block_2 = _matmul_fp32(d_intra_att, k_block) + dq_block = dq_block_1 + dq_block_2 + mutable_dq_ref[subchunk_idx, :] = dq_block + + next_h_block = h_block * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + k_tilde_block.T, v_block + ) + h_carry_ref[:, :] = next_h_block + + def _ssd_backward_dkv_chunk_loop_body(t: int, dh_carry_ref: Tensor): + subchunk_idx = t + dh_block = dh_carry_ref[:, :] + q_block = q_ref[subchunk_idx, :].astype(jnp.float32) + k_block = k_ref[subchunk_idx, :].astype(jnp.float32) + v_block = v_ref[subchunk_idx, :].astype(jnp.float32) + do_block = mutable_do_ref[subchunk_idx, :].astype(jnp.float32) + causal_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0).astype(jnp.float32) + + lambda_block = cum_log_alpha_ref[subchunk_idx, :] + gamma_block = gamma_ref[subchunk_idx] + + lambda_block = jnp.expand_dims(lambda_block, axis=-1) # [nb, 1] + beta_block = gamma_block - lambda_block # [nb, d_k] + ssd_mask_block = lambda_block - jnp.transpose(lambda_block, [1, 0]) + ssd_mask_block = ssd_mask_block * causal_mask + + lambda_block = jnp.exp(lambda_block) + beta_block = jnp.exp(beta_block) + gamma_block = jnp.exp(gamma_block) + ssd_mask_block = jnp.exp(ssd_mask_block) + + q_tilde_block = q_block * lambda_block + k_tilde_block = k_block * beta_block + + intra_att = _matmul_fp32(q_block, k_block.T) + attn_mask = causal_mask * ssd_mask_block + d_intra_att = _matmul_fp32(do_block, v_block.T) * attn_mask + + dk_block_1 = _matmul_fp32(d_intra_att.T, q_block) + dk_tilde_block = _matmul_fp32(v_block, dh_block.T) + dk_block_2 = dk_tilde_block * beta_block + dk_block = dk_block_1 + dk_block_2 + mutable_dk_ref[subchunk_idx, :] = dk_block + + dv_block_1 = _matmul_fp32((intra_att * attn_mask).T, do_block) + dv_block_2 = _matmul_fp32(k_tilde_block, dh_block) + dv_block = dv_block_1 + dv_block_2 + mutable_dv_ref[subchunk_idx, :] = dv_block + + prev_dh_block = dh_block * jnp.expand_dims(gamma_block, axis=-1) + _matmul_fp32( + q_tilde_block.T, do_block + ) + dh_carry_ref[:, :] = prev_dh_block + + h_carry = ch_ref[:, :] + _ = for_loop.for_loop(subchunk_dim, _ssd_backward_dq_chunk_loop_body, h_carry) + + dh_carry = mutable_dh_carry_ref[:, :] + dinitial_state = for_loop.for_loop( + subchunk_dim, _ssd_backward_dkv_chunk_loop_body, dh_carry, reverse=True + ) + mutable_dh_carry_ref[:, :] = 
dinitial_state + + +@jax.jit +def _ssd_backward(residuals: Tuple, do: Tensor) -> Tuple: + """Backward pass for SSD. + + Args: + residuals: Tuple of tensors returned from the forward pass + do: [bs, num_heads, seq_len, dv] + + Returns: + dq: [bs, num_heads, seq_len, dk] + dk: [bs, num_heads, seq_len, dk] + dv: [bs, num_heads, seq_len, dv] + dlog_alpha: [bs, num_heads, seq_len] + dinitial_state: [bs, num_heads, dk, dv] + """ + q, k, v, cum_log_alpha, gamma_expanded, chunk_states, final_state = residuals + + # `final_state` preserves the original dtype (e.g., bfloat16). + orig_dtype = final_state.dtype + + singleton_dim = 128 + bs, num_heads, chunk_dim, subchunk_dim, subchunk_size = cum_log_alpha.shape + k_dim, v_dim = q.shape[-1], v.shape[-1] + num_k_tiles, num_v_tiles = k_dim // singleton_dim, v_dim // singleton_dim + num_qk_heads = q.shape[1] + num_head_per_group = num_heads // num_qk_heads + + grid = (bs, num_heads, num_k_tiles, num_v_tiles, chunk_dim) + + qk_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + qk_spec = pl.BlockSpec( + lambda b, h, k, v, m: (b, lax.div(h, num_head_per_group), chunk_dim - 1 - m, 0, k), + qk_tiling, + ) + v_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + v_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, v), v_tiling) + + alpha_tiling = (None, None, None, subchunk_dim, subchunk_size) + alpha_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, 0), alpha_tiling) + gamma_tiling = (None, None, None, subchunk_dim, singleton_dim) + gamma_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, 0), gamma_tiling) + + ch_tiling = (None, None, None, singleton_dim, singleton_dim) + ch_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, k, v), ch_tiling) + + do_tiling = (None, None, subchunk_dim, subchunk_size, singleton_dim) + do_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, chunk_dim - 1 - m, 0, v), do_tiling) + + dqk_tiling = (None, None, None, None, subchunk_dim, subchunk_size, singleton_dim) + dqk_spec = pl.BlockSpec( + lambda b, h, k, v, m: ( + b, + lax.div(h, num_head_per_group), + lax.rem(h, num_head_per_group), + v, + chunk_dim - 1 - m, + 0, + k, + ), + dqk_tiling, + ) + dqk_shape = jax.ShapeDtypeStruct( + shape=( + bs, + num_qk_heads, + num_head_per_group, + num_v_tiles, + chunk_dim * subchunk_dim, + subchunk_size, + k_dim, + ), + dtype=jnp.float32, + ) + + dv_tiling = (None, None, None, subchunk_dim, subchunk_size, singleton_dim) + dv_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, chunk_dim - 1 - m, 0, v), dv_tiling) + dv_shape = jax.ShapeDtypeStruct( + shape=(bs, num_heads, num_k_tiles, chunk_dim * subchunk_dim, subchunk_size, v_dim), + dtype=jnp.float32, + ) + + dh_carry_tiling = (None, None, singleton_dim, singleton_dim) + dh_carry_spec = pl.BlockSpec(lambda b, h, k, v, m: (b, h, k, v), dh_carry_tiling) + dh_carry_shape = jax.ShapeDtypeStruct(shape=(bs, num_heads, k_dim, v_dim), dtype=jnp.float32) + + do = rearrange(do, "b h (nb bl) dv -> b h nb bl dv", bl=subchunk_size) + + dq, dk, dv, dinitial_state = pl.pallas_call( + _ssd_backward_kernel, + in_specs=(qk_spec, qk_spec, v_spec, alpha_spec, gamma_spec, ch_spec, do_spec), + out_specs=(dqk_spec, dqk_spec, dv_spec, dh_carry_spec), + out_shape=(dqk_shape, dqk_shape, dv_shape, dh_carry_shape), + grid=grid, + compiler_params=dict( + mosaic=dict( + dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary") + ) + ), + )(q, k, v, cum_log_alpha, gamma_expanded, chunk_states, do) 
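+    # dq/dk come back with separate (group, head-in-group, v-tile) axes that are reduced below;
+    # dlog_alpha is then recovered from sum(q * dq - k * dk) via a reverse cumulative sum.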
+ + # Sum over dvn dim. + dq = jnp.sum(dq, axis=3) + dk = jnp.sum(dk, axis=3) + dq = rearrange(dq, "b ng nhg nb bl dk -> b ng nhg (nb bl) dk") + dk = rearrange(dk, "b ng nhg nb bl dk -> b ng nhg (nb bl) dk") + + # Compute dlog_alpha via `q * dq - k * dk`. + dq_ = rearrange(dq, "b ng nhg l dk -> b (ng nhg) l dk") + dk_ = rearrange(dk, "b ng nhg l dk -> b (ng nhg) l dk") + + q_ = repeat(q, "b ng nb bl dk -> b (ng nhg) nb bl dk", nhg=num_head_per_group) + k_ = repeat(k, "b ng nb bl dk -> b (ng nhg) nb bl dk", nhg=num_head_per_group) + q_ = rearrange(q_, "b h nb bl dk -> b h (nb bl) dk") + k_ = rearrange(k_, "b h nb bl dk -> b h (nb bl) dk") + + dlog_alpha_ = jnp.sum(dq_ * q_ - dk_ * k_, axis=-1) + dlog_alpha = lax.cumsum(dlog_alpha_, axis=2, reverse=True) + + # Sum over dkn dim. + dv = jnp.sum(dv, axis=2) + dv = rearrange(dv, "b h nb bl dv -> b h (nb bl) dv") + + # Sum over nhg dim + dq = jnp.sum(dq, axis=2) + dk = jnp.sum(dk, axis=2) + # `dlog_alpha` is always in float32, `dv` is also in float32. + dq, dk = dq.astype(orig_dtype), dk.astype(orig_dtype) + + dinitial_state = dinitial_state.astype(orig_dtype) + return dq, dk, dv, dlog_alpha, dinitial_state + + +_ssd.defvjp(_ssd_forward, _ssd_backward) + + +@jax.jit +@jax.named_call # `named_call` ensures the name is used in tracing, which is useful for profiling. +def ssd(q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Optional[Tensor] = None) -> Tensor: + """Differentiable function that computes the output of SSD. + + Args: + q: [batch_size, num_groups, seq_len, dk] + k: [batch_size, num_groups, seq_len, dk] + v: [batch_size, num_groups, seq_len, dv] + log_alpha: [batch_size, num_heads, seq_len] + h0: [batch_size, num_heads, dk, dv] + + Returns: + output: [batch_size, num_heads, seq_len, dv] + + The notion of groups is similar to the group in multi-group attention (or more preciesly + multi-value attention) -- one group of q/k corresponds to multiple v heads. + """ + + bs, ng, _, dk = q.shape + bs, nh, _, dv = v.shape + assert nh % ng == 0 + assert v.dtype == jnp.float32 + assert log_alpha.dtype == jnp.float32 + + if h0 is None: + h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32) + + output = _ssd(q, k, v, log_alpha, h0) + return output + + +def ssd_linear_scan( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Union[Tensor, None] = None +) -> Tensor: + """LinearScan based reference implementations for testing SSD kernels. + + Args: + q, k: [batch_size, num_groups, seq_len, dk] + v: [batch_size, num_groups, seq_len, dv] + log_alpha: [batch_size, num_heads, seq_len] + h0: [batch_size, num_heads, dk, dv] or None + + Returns: + output: [batch_size, num_heads, seq_len, dv] + """ + bs, ng, _, dk = q.shape + bs, nh, _, dv = v.shape + assert nh % ng == 0 + + # The linearscan kernel assumes that nh == ng, so we need to repeat q/k. + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + # ITt's more convenient for vmap to have internal states of size [dv, dk] + if h0 is None: + h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) + else: + # to be consistent with pallas api, h0 is in dk x dv as input + h0 = rearrange(h0, "b h dk dv -> b h dv dk") + + # All inputs are upcasted to float32, making this function a good reference funciton to + # test pallas kernel's numerical precision in the case of bf16 inputs. 
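+    # When the inputs are bfloat16, compute in float32 and cast the outputs back at the end so
+    # that the reference matches the kernel's input/output dtypes.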
+ dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + q_t, k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t, preferred_element_type=jnp.float32) + o_t = jnp.einsum("ij,j->i", h_next, q_t, preferred_element_type=jnp.float32) + return h_next, o_t.astype(q_t.dtype) + + def single_head_scan(q_head, k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (q_head, k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + + # Note: if dk > 128 (e.g., 256), somehow jax jvp would fail; a work-around + # is to add another dim to ensure that minor dk is always 128. + q = rearrange(q, "b h l (dkn dks) -> dkn b h l dks", dks=128) + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, 0, None, None, 0), out_axes=(0, 0)) + final_state, output = batched_scan(q, k, v, log_alpha, h0) + final_state = rearrange(final_state, "dkn b h dv dks -> b h dv (dkn dks)") + output = jnp.sum(output, axis=0) + + final_state = rearrange(final_state, "b h dv dk -> b h dk dv") + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + final_state = final_state.astype(jnp.bfloat16) + + return output, final_state + + +def ssd_linear_scan_w_hidden_states( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, h0: Union[Tensor, None] = None +) -> Tensor: + """LinearScan based reference implementations for testing SSD kernels. + + This version additionally returns the hidden states of all tokens. 
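+    Unlike `ssd_linear_scan`, the second return value is the per-step hidden states of shape
+    [batch_size, num_heads, seq_len, dk, dv] rather than only the final state.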
+ + Args: + q: [batch_size, num_groups, seqlen, dk] + k: [batch_size, num_groups, seqlen, dk] + v: [batch_size, num_groups, seqlen, dv] + log_alpha: [batch_size, num_heads, seqlen] + h0: [batch_size, num_heads, dk, dv] or None + + Returns: + output: [batch_size, num_heads, seq_len, dv] + """ + bs, ng, _, dk = q.shape + bs, nh, _, dv = v.shape + assert nh % ng == 0 + + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + if h0 is None: + h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) + else: + # to be consistent with pallas api, h0 is in dk x dv as input + h0 = rearrange(h0, "b h dk dv -> b h dv dk") + + dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t) + return h_next, h_next + + def single_head_scan(k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0), out_axes=(0, 0)) + + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, None, None, 0), out_axes=(0, 0)) + final_state, hidden_states = batched_scan(k, v, log_alpha, h0) + assert final_state is not None + + hidden_states = rearrange(hidden_states, "dkn b h l dv dks -> b h l (dkn dks) dv") + output = jnp.einsum( + "b h l s, b h l s d -> b h l d", q, hidden_states, preferred_element_type=jnp.float32 + ) + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + return output, hidden_states + + +def ssd_linear_scan_w_timestep( + q: Tensor, k: Tensor, v: Tensor, log_alpha: Tensor, timestep: Tensor, h0=None +) -> Tensor: + """LinearScan that takes timestep as input and masks useless k/v based on timestep. + + This function is used during inference where decoding might start from different timesteps. 
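+    Keys/values at positions `>= timestep` are zeroed out and their `log_alpha` is set to 0
+    (i.e., a decay of 1), so those positions pass the recurrent state through unchanged.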
+ + Args: + q: [batch_size, num_groups, seqlen, dk] + k: [batch_size, num_groups, seqlen, dk] + v: [batch_size, num_groups, seqlen, dv] + log_alpha: [batch_size, num_heads, seqlen] + h0: [batch_size, num_heads, dk, dv] or None + timestep: [batch_size, seqlen] or None + + Returns: + output: [batch_size, num_heads, seq_len, dv] + + """ + bs, ng, l, dk = q.shape + bs, nh, l, dv = v.shape + assert nh % ng == 0 + + num_head_per_group = nh // ng + q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) + + timestep_mask = jnp.arange(l)[None, :] >= timestep[:, None] + k = jnp.where(timestep_mask[:, None, :, None], 0.0, k) + v = jnp.where(timestep_mask[:, None, :, None], 0.0, v) + log_alpha = jnp.where(timestep_mask[:, None, :], 0.0, log_alpha) + + if h0 is None: + h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) + else: + # to be consistent with pallas api, h0 is in dk x dv as input + h0 = rearrange(h0, "b h dk dv -> b h dv dk") + + dtype = q.dtype + if dtype == jnp.bfloat16: + q, k, v, h0 = map(lambda x: x.astype(jnp.float32), (q, k, v, h0)) + + def scan_body_fn(h_prev, current_inputs): + acc_dtype = h_prev.dtype + q_t, k_t, v_t, log_a_t = current_inputs + a_t = jnp.exp(log_a_t).astype(acc_dtype) + h_next = a_t * h_prev + jnp.einsum("i,j->ij", v_t, k_t, preferred_element_type=jnp.float32) + o_t = jnp.einsum("ij,j->i", h_next, q_t, preferred_element_type=jnp.float32) + return h_next, o_t.astype(q_t.dtype) + + def single_head_scan(q_head, k_head, v_head, alpha_head, h0_head): + return jax.lax.scan(scan_body_fn, h0_head, (q_head, k_head, v_head, alpha_head)) + + multi_head_scan = jax.vmap(single_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + batched_scan = jax.vmap(multi_head_scan, in_axes=(0, 0, 0, 0, 0), out_axes=(0, 0)) + + q = rearrange(q, "b h l (dkn dks) -> dkn b h l dks", dks=128) + k = rearrange(k, "b h l (dkn dks) -> dkn b h l dks", dks=128) + h0 = rearrange(h0, "b h dv (dkn dks) -> dkn b h dv dks", dks=128) + + batched_scan = jax.vmap(batched_scan, in_axes=(0, 0, None, None, 0), out_axes=(0, 0)) + final_state, output = batched_scan(q, k, v, log_alpha, h0) + final_state = rearrange(final_state, "dkn b h dv dks -> b h dv (dkn dks)") + output = jnp.sum(output, axis=0) + + final_state = rearrange(final_state, "b h dv dk -> b h dk dv") + + if dtype == jnp.bfloat16: + output = output.astype(jnp.bfloat16) + + return output, final_state diff --git a/axlearn/common/ssm_kernels/ssd_kernels_test.py b/axlearn/common/ssm_kernels/ssd_kernels_test.py new file mode 100644 index 000000000..2cc5c05ad --- /dev/null +++ b/axlearn/common/ssm_kernels/ssd_kernels_test.py @@ -0,0 +1,389 @@ +# Copyright © 2024 Apple Inc. + +"""Tests SSD Pallas kernels.""" +from typing import Union + +import jax +import jax.nn as jnn +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from absl.testing import parameterized +from einops import rearrange, repeat +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, PartitionSpec +from torch.nn import functional as F + +from axlearn.common.ssm_kernels.ssd_kernels import _ssd_backward, _ssd_forward, ssd, ssd_linear_scan +from axlearn.common.test_utils import TestCase, assert_allclose + +if jax.default_backend() != "tpu": + pytest.skip(reason="Incompatible hardware", allow_module_level=True) + + +def _ssd_reference(q, k, v, log_alpha, h0): + """Reference implementation of SSD for comparison. 
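+
+    It wraps `ssd_linear_scan` and keeps only the output (dropping the final state), so it can
+    be differentiated directly with `jax.vjp` as a reference for the Pallas backward pass.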
+
+    Args:
+        q/k: [batch_size, num_groups, seq_len, dk]
+        v: [batch_size, num_heads, seq_len, dv]
+        log_alpha: [batch_size, num_heads, seq_len]
+        h0: [batch_size, num_heads, dk, dv]
+
+    Returns:
+        o: [batch_size, num_heads, seq_len, dv]
+    """
+    return ssd_linear_scan(q, k, v, log_alpha, h0)[0]
+
+
+def _ssd_naive_reference(q, k, v, log_alpha, h0=None):
+    """For-loop reference implementation of SSD.
+
+    Note that this implementation somehow has worse
+    numerical stability than the vmap version above.
+
+    Args:
+        q/k: [batch_size, num_groups, seq_len, dk]
+        v: [batch_size, num_heads, seq_len, dv]
+        log_alpha: [batch_size, num_heads, seq_len]
+        h0: [batch_size, num_heads, dk, dv]
+
+    Returns:
+        o: [batch_size, num_heads, seq_len, dv]
+        h: [batch_size, num_heads, dk, dv]
+    """
+    bs, ng, l, dk = q.shape
+    bs, nh, l, dv = v.shape
+    assert nh % ng == 0
+
+    num_head_per_group = nh // ng
+    q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group)
+    k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group)
+
+    if h0 is None:
+        h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32)
+
+    o_list = []
+    h = h0
+    for t in range(l):
+        q_t = q[:, :, t]
+        k_t = k[:, :, t]
+        v_t = v[:, :, t]
+        alpha_t = jnp.exp(log_alpha[:, :, t, None, None])
+
+        h = alpha_t * h + jnp.einsum(
+            "...i,...j->...ij", k_t, v_t, preferred_element_type=jnp.float32
+        )
+        o_t = jnp.einsum("...ij,...i->...j", h, q_t, preferred_element_type=jnp.float32)
+        o_list.append(o_t)
+    o = jnp.stack(o_list, axis=2)
+    return o, h
+
+
+# disable some pylint checks to allow copied code to pass checks
+
+# pylint: disable=line-too-long
+# pylint: disable=invalid-name
+# pylint: disable=unused-variable
+
+
+def segsum(x):
+    """More stable segment sum calculation. Helper function copied from
+    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py.
+    """
+    T = x.size(-1)
+    x = repeat(x, "... d -> ... d e", e=T)
+    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
+    x = x.masked_fill(~mask, 0)
+    x_segsum = torch.cumsum(x, dim=-2)
+    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
+    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+    return x_segsum
+
+
+def ssd_chunk_tri(X, A, B, C, chunk_size=16, initial_states=None):
+    """Reference implementation of SSD with chunked computation, copied from
+    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py.
+
+    Arguments:
+        X: (batch, length, n_heads, d_head)
+        A: (batch, length, n_heads)
+        B: (batch, length, n_heads, d_state)
+        C: (batch, length, n_heads, d_state)
+    Return:
+        Y: (batch, length, n_heads, d_head)
+
+    X, A, B, C correspond to V, \alpha, K, Q in linear attention
+    (H_t = \alpha H_{t-1} + K_t^\top V_t, O_t = Q_t H_t).
+    """
+    assert X.dtype == A.dtype == B.dtype == C.dtype
+    assert X.shape[1] % chunk_size == 0
+
+    # Rearrange into blocks/chunks
+    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=chunk_size) for x in (X, A, B, C)]
+
+    A = rearrange(A, "b c l h -> b h c l")
+    A_cumsum = torch.cumsum(A, dim=-1)
+
+    # 1. Compute the output for each intra-chunk (diagonal blocks)
+    L = torch.exp(segsum(A))
+    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
+
+    # 2. 
Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +@jax.jit +def _ssd_reference_vjp( + q: jax.Array, + k: jax.Array, + v: jax.Array, + alpha: jax.Array, + h0: Union[jax.Array, None], + do: jax.Array, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + o, vjp = jax.vjp(_ssd_reference, q, k, v, alpha, h0) + return o, vjp(do) + + +def _generate_ssd_inputs(shape, dtype, seed, paramn="gla", zero_h0=True): + """ + Args: + shape: [bs, ng, nh, l, dk, dv] + dtype: float32, bfloat16 + seed: random seed + paramn: "mamba" or "gla" + zero_h0: whether to generate zero initial hidden state + + Returns: + q, k, v, log_alpha, h0, do + """ + bs, ng, nh, l, dk, dv = shape + rng = jax.random.PRNGKey(seed) + q_key, k_key, v_key, alpha_key, h_key, dh_key = jax.random.split(rng, 6) + + if paramn == "mamba": + q = jax.random.uniform(q_key, (bs, ng, l, dk), dtype=dtype) + k = jax.random.uniform(k_key, (bs, ng, l, dk), dtype=dtype) + v = jax.random.uniform(v_key, (bs, nh, l, dv), dtype=jnp.float32) + + log_alpha = -jnp.exp(jax.random.uniform(alpha_key, (bs, nh, l), dtype=jnp.float32)) + dt = jax.random.normal(alpha_key, (bs, nh, l), dtype=jnp.float32) + dt = jnp.log(1.0 + jnp.exp(dt - 4)) + + log_alpha = dt * log_alpha + v = v * dt[..., None] + elif paramn == "gla": + q = jax.random.normal(q_key, (bs, ng, l, dk), dtype=dtype) + k = jax.random.normal(k_key, (bs, ng, l, dk), dtype=dtype) + v = jax.random.normal(v_key, (bs, nh, l, dv), dtype=dtype) + + # shortconv (skipped) and non-linear activation + q = jnn.silu(q) + k = jnn.silu(k) + v = jnn.silu(v) + + # l2 norm (help reduces the range of dq/dk -> better precision for bfloat16) + q = q / jnp.linalg.norm(q, axis=-1, keepdims=True) + k = k / jnp.linalg.norm(k, axis=-1, keepdims=True) + + log_alpha = ( + jnn.log_sigmoid(jax.random.normal(alpha_key, (bs, nh, l), dtype=jnp.float32)) / 16.0 + ) + else: + raise ValueError(f"Unsupported param: {paramn}") + + if zero_h0: + h0 = jnp.zeros((bs, nh, dk, dv), dtype=jnp.float32) + else: + h0 = jax.random.normal(h_key, (bs, nh, dk, dv), dtype=jnp.float32) + + do = jax.random.normal(dh_key, (bs, nh, l, dv), dtype=dtype) + + # log_alpha is always in float32 + log_alpha = log_alpha.astype(jnp.float32) + return q, k, v, log_alpha, h0, do + + +class SSDPallasKernelTest(TestCase): + @parameterized.product( + batch_size=[2, 4], + num_heads=[4, 8], + seq_len=[1024, 2048], + dk=[128, 256], + dv=[128, 256], + seed=[0, 1], + ) + def 
test_ssd_forward( + self, batch_size: int, num_heads: int, seq_len: int, dk: int, dv: int, seed: int + ) -> None: + """Test SSD forward pass against Tri's torch reference implementation.""" + # Set the device to CPU + device = "cpu" + + # Set the random seed for reproducibility + np.random.seed(seed) + + # Generate random input data + x = np.random.rand(batch_size, seq_len, num_heads, dk).astype(np.float32) + dt = np.random.rand(batch_size, seq_len, num_heads).astype(np.float32) + dt = np.log(1.0 + np.exp(dt - 4)) + A = -np.exp(np.random.rand(batch_size, seq_len, num_heads).astype(np.float32)) + B = np.random.rand(batch_size, seq_len, num_heads, dv).astype(np.float32) + C = np.random.rand(batch_size, seq_len, num_heads, dv).astype(np.float32) + + # Compute intermediate variables + x_bar = x * dt[..., None] + A_bar = A * dt + + # Convert numpy arrays to torch tensors + x_torch = torch.tensor(x, dtype=torch.float32) + dt_torch = torch.tensor(dt, dtype=torch.float32) + A_torch = torch.tensor(A, dtype=torch.float32) + B_torch = torch.tensor(B, dtype=torch.float32) + C_torch = torch.tensor(C, dtype=torch.float32) + x_bar_torch = torch.tensor(x_bar, dtype=torch.float32) + A_bar_torch = torch.tensor(A_bar, dtype=torch.float32) + + # Compute the torch reference output + y_torch, _ = ssd_chunk_tri(x_bar_torch, A_bar_torch, B_torch, C_torch) + + # Convert numpy arrays to jax arrays + x_jax = jnp.array(x, dtype=jnp.float32) + dt_jax = jnp.array(dt, dtype=jnp.float32) + A_jax = jnp.array(A, dtype=jnp.float32) + B_jax = jnp.array(B, dtype=jnp.float32) + C_jax = jnp.array(C, dtype=jnp.float32) + x_bar_jax = jnp.array(x_bar, dtype=jnp.float32) + A_bar_jax = jnp.array(A_bar, dtype=jnp.float32) + + # Reshape jax arrays for comparison + x_jax = rearrange(x_jax, "b t h d -> b h t d") + dt_jax = rearrange(dt_jax, "b t h -> b h t") + A_jax = rearrange(A_jax, "b t h -> b h t") + B_jax = rearrange(B_jax, "b t h n -> b h t n") + C_jax = rearrange(C_jax, "b t h n -> b h t n") + x_bar_jax = rearrange(x_bar_jax, "b t h d -> b h t d") + A_bar_jax = rearrange(A_bar_jax, "b t h -> b h t") + + # Compute the jax output + y_jax = ssd(C_jax, B_jax, x_bar_jax, A_bar_jax, h0=None) + y_jax = rearrange(y_jax, "b h t d -> b t h d") + + assert_allclose(y_torch.numpy(), np.asarray(y_jax), atol=1e-3, rtol=1e-3) + + @parameterized.product( + batch_size=[2, 4], + num_heads=[4, 8], + seq_len=[1024, 2048], + dk=[128, 256], + dv=[128, 256], + dtype=["float32", "bfloat16"], + seed=[0, 1], + ) + def test_forward_and_backward(self, batch_size, num_heads, seq_len, dk, dv, dtype, seed): + try: + self.ssd_forward_and_backward(batch_size, num_heads, seq_len, dk, dv, dtype, seed) + except Exception as e: + # breakpoint() # uncomment for debugging failed conditions + raise e + + def ssd_forward_and_backward(self, batch_size, num_heads, seq_len, dk, dv, dtype, seed): + num_groups = num_heads + shape = (batch_size, num_groups, num_heads, seq_len, dk, dv) + q, k, v, log_alpha, h0, do = _generate_ssd_inputs(shape, dtype, seed) + if dtype == "float32": + tol = 1e-3 + elif dtype == "bfloat16": + tol = 1e-2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + o_pallas, residuals = _ssd_forward(q, k, v, log_alpha, h0) + final_state_pallas = residuals[-1] + o_ref, final_state_ref = ssd_linear_scan(q, k, v, log_alpha, h0) + + assert_allclose(o_pallas, o_ref, atol=tol, rtol=tol) + assert_allclose(final_state_pallas, final_state_ref, atol=tol, rtol=tol) + + dq_pallas, dk_pallas, dv_pallas, dlog_alpha_pallas, dh0_pallas = _ssd_backward( + 
residuals, do + ) + _, ssd_reference_grad_ = jax.vjp(_ssd_reference, q, k, v, log_alpha, h0) + dq_ref, dk_ref, dv_ref, dlog_alpha_ref, dh0_ref = ssd_reference_grad_(do) + + assert_allclose(dq_pallas, dq_ref, atol=tol, rtol=tol) + assert_allclose(dk_pallas, dk_ref, atol=tol, rtol=tol) + assert_allclose(dv_pallas, dv_ref, atol=tol, rtol=tol) + assert_allclose(dlog_alpha_pallas, dlog_alpha_ref, atol=tol, rtol=tol) + assert_allclose(dh0_pallas, dh0_ref, atol=tol, rtol=tol) + + +class ShardSSDPallasKernelTest(TestCase): + # this test only works for four devices + @pytest.mark.skipif(jax.device_count() != 4, reason="Requires 4 devices") + def test_sharded_ssd_wo_sp(self): + batch, ngroups, nheads, seqlen, k_head_dim, v_head_dim = 8, 4, 4, 1024, 256, 128 + dtype = "float32" + q, k, v, log_alpha, _, _ = _generate_ssd_inputs( + (batch, ngroups, nheads, seqlen, k_head_dim, v_head_dim), dtype, 0 + ) + + o_ref, _ = ssd_linear_scan(q, k, v, log_alpha) + + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + + def get_sharded_ssd(mesh): + """ + Note: current version assumes that h0 is None, for which you don't + need to provide partition spec. + """ + sharded_ssd = shard_map( + ssd, + mesh=mesh, + in_specs=( + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + PartitionSpec( + ("data", "expert", "fsdp"), + ("seq", "model"), + None, + None, + ), + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None), + ), + out_specs=PartitionSpec(("data", "expert", "fsdp"), "model", "seq", None), + check_rep=False, + ) + return sharded_ssd + + sharded_ssd = get_sharded_ssd(mesh) + o_pallas = sharded_ssd(q, k, v, log_alpha) + + assert_allclose(o_pallas, o_ref, atol=1e-3, rtol=1e-3) diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 722f13d94..7412cf74d 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -14,15 +14,19 @@ # Licensed under the Apache License, Version 2.0 (the "License"). 
-"""Tests Mamba and Jamba implementations.""" +"""Tests Mamba/Mamba2 and Jamba implementations.""" import math from typing import Optional import jax import jax.numpy as jnp import numpy as np +import pytest import torch from absl.testing import parameterized +from jax._src.mesh import ResourceEnv, thread_resources +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig @@ -32,17 +36,28 @@ from axlearn.common.ssm import ( AssociativeScanMambaRecurrence, BlockResidualMode, + JambaMamba2Block, JambaMambaBlock, LinearScanMambaRecurrence, + Mamba2MixerLayer, MambaBlock, MambaMixerLayer, + PallasSSDRecurrence, RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, ) +from axlearn.common.ssm_kernels.ssd_kernels import ssd from axlearn.common.test_utils import TestCase, assert_allclose from axlearn.common.utils import Nested, Tensor, cast_floats +try: + from mamba_ssm.modules.mamba2_simple import Mamba2Simple # pytype: disable=import-error + + MAMBA_INSTALLED = True +except ModuleNotFoundError: + MAMBA_INSTALLED = False + # The following PyTorch Mamba implementations are adapted from: # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/mamba/modeling_mamba.py # and @@ -931,3 +946,589 @@ def test_prefill(self, dtype: jnp.dtype): cfg.layer.self_attention.attention.num_heads = num_heads cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None) _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class Mamba2RecurrenceTest(TestCase): + """Test the correctness of the Mamba2 recurrence for decoding.""" + + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + def test_ssd_parameterization(self): + batch_size, num_heads, seq_len, state_dim, head_dim = 2, 4, 1024, 128, 256 + key = jax.random.PRNGKey(0) + dtype = jnp.float32 + + # note that construct random params requires that log_a <= 0 and delta > 0. 
+ x = jax.random.normal(key, (batch_size, num_heads, seq_len, head_dim), dtype=dtype) + llog_a = jax.random.uniform(key, (1, num_heads, 1), dtype=dtype) + log_a = -jnp.exp(llog_a) + b = jax.random.normal(key, (batch_size, num_heads, seq_len, state_dim), dtype=dtype) + c = jax.random.normal(key, (batch_size, num_heads, seq_len, state_dim), dtype=dtype) + delta = jax.nn.softplus( + jax.random.uniform(key, (batch_size, num_heads, seq_len), dtype=dtype) - 4.0 + ) + d = jax.random.normal(key, (1, num_heads, 1, 1), dtype=dtype) + + mamba2_dim_to_partition_spec = { + "bhtd": PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None, None), + "bht": PartitionSpec(("data", "expert", "fsdp"), ("seq", "model"), None), + } + output_partition_spec = PartitionSpec(("data", "expert", "fsdp"), "model", "seq", None) + + cfg = PallasSSDRecurrence.default_config().set( + name="test", + mamba2_dim_to_partition_spec=mamba2_dim_to_partition_spec, + output_partition_spec=output_partition_spec, + ) + layer = cfg.instantiate(parent=None) + o_module, _ = F( + layer, + inputs=dict(x=x, log_a=log_a, b=b, c=c, delta=delta, d=d), + state=None, + is_training=False, + prng_key=key, + ) + + # alternative input to the kernel; delta by default is applied to x to get x_bar, here we can + # also apply it to b to get b_bar first. + b_bar = b * jnp.expand_dims(delta, axis=-1) + loga_bar = log_a * delta + o_alternative = ssd(c, b_bar, x, loga_bar) + d * x + assert_allclose(o_module.data, o_alternative, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class Mamba2MixerLayerTest(TestCase): + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + inference_mode=(True, False), + ) + def test_extend_step(self, dtype: jnp.dtype, inference_mode: bool): + batch_size = 2 + input_dim = 512 + state_dim = 128 + num_heads = 2 + seq_len = 1024 + num_groups = 2 + expansion_factor = 1 + output_dim = input_dim + cache_dtype = dtype + + cfg = Mamba2MixerLayer.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + expansion_factor=expansion_factor, + dtype=dtype, + cache_dtype=cache_dtype, + ) + + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + if inference_mode: + # inference recurrence can return the ssd states for testing + layer.recurrence = layer.inference_recurrence + + inputs_data = jax.random.uniform( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + mamba2_cache = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + self.assertEqual(mamba2_cache.x_conv_state.dtype, cache_dtype) + self.assertEqual(mamba2_cache.b_conv_state.dtype, cache_dtype) + 
self.assertEqual(mamba2_cache.c_conv_state.dtype, cache_dtype) + self.assertEqual(mamba2_cache.ssd_state.dtype, cache_dtype) + self.assertEqual(forward_outputs.data.dtype, dtype) + + inputs = dict(cache=mamba2_cache) + decoder_output = jnp.zeros(shape=[seq_len, batch_size, output_dim], dtype=dtype) + for t in range(seq_len): + inputs["query"] = inputs_data[:, t : t + 1, :] + (mamba2_cache, mamba2output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cache"] = mamba2_cache + decoder_output = decoder_output.at[t].set(jnp.squeeze(mamba2output.data, axis=1)) + + decoder_output_transposed = jnp.transpose(decoder_output, [1, 0, 2]) + + if dtype == jnp.float32: + final_state_diff_tol = 1e-2 + output_tol = 1e-1 + else: + final_state_diff_tol = 1e-1 + output_tol = 2e0 + + if inference_mode: + forward_final_state = forward_outputs.ssd_state[:, :, -1] + final_state_diff = jnp.abs((forward_final_state - mamba2_cache.ssd_state)).max() + self.assertTrue(final_state_diff < final_state_diff_tol) + + # ssm output diff will get a bit amplified by the ffn layer + assert_allclose( + decoder_output_transposed, forward_outputs.data, atol=output_tol, rtol=output_tol + ) + + @parameterized.product(dtype=(jnp.float32, jnp.bfloat16)) + def test_prefill_states(self, dtype: jnp.dtype): + batch_size = 2 + input_dim = 512 + state_dim = 256 + num_heads = 4 + seq_len = 1024 + num_groups = 2 + expansion_factor = 2 + cache_dtype = jnp.float32 + + cfg = Mamba2MixerLayer.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + expansion_factor=expansion_factor, + dtype=dtype, + cache_dtype=cache_dtype, + ) + + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + # full forward pass as reference + inputs_data = jax.random.uniform( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + # prefill stage + time_step = jnp.arange(batch_size) + (initial_state, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=dict(time_step=time_step, query=inputs_data), + method="prefill_states", + ) + self.assertTrue(initial_state.x_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.b_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.c_conv_state.dtype, cache_dtype) + self.assertTrue(initial_state.ssd_state.dtype, cache_dtype) + self.assertTrue(initial_output.data.dtype, dtype) + + time_step_mask = (jnp.arange(seq_len) < time_step[:, None]).astype(dtype) + decoder_output = initial_output.data * time_step_mask[..., None] + + inputs = dict(cache=initial_state) + while jnp.any(time_step < seq_len): + inputs["query"] = jnp.take_along_axis( + inputs_data, time_step[:, None, None], axis=1, mode="clip" + ) + (updated_state, outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(4), + inputs=inputs, + method="extend_step", + ) + inputs["cache"] = updated_state + + # [batch_size, 1, output_dim] + cur_outputs = outputs.data + + # [batch_size, seq_len, 1] + oh_indices = jax.nn.one_hot(time_step, seq_len, dtype=dtype)[..., 
None] + decoder_output = decoder_output + cur_outputs * oh_indices + + time_step = time_step + 1 + + assert_allclose(decoder_output, forward_outputs.data, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "tpu" or jax.device_count() != 4, + reason="Test requires four chips, e.g., one v5p gcp instance.", +) +class JambaMamba2BlockTest(TestCase): + @classmethod + def setup_class(cls): + devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + input_dim=[1024, 2048], + state_dim=[128, 256], + num_heads=[2, 4], + num_groups=[2, 4], + dtype=[jnp.float32, jnp.bfloat16], + ) + def forward( + self, input_dim: int, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype + ): + mamba2block_cfg = JambaMamba2Block.default_config().set( + name="test", + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + mamba2block_cfg.feed_forward = mamba2block_cfg.feed_forward.set(hidden_dim=2 * input_dim) + mamba2block = mamba2block_cfg.instantiate(parent=None) + mamba2block_params = mamba2block.initialize_parameters_recursively( + prng_key=jax.random.PRNGKey(0) + ) + + batch_size, tgt_len = 2, 1024 + x = jax.random.uniform(jax.random.PRNGKey(1), [batch_size, tgt_len, input_dim], dtype=dtype) + + outputs, _ = F( + mamba2block, + inputs=(x,), + state=mamba2block_params, + is_training=True, + prng_key=jax.random.PRNGKey(2), + ) + + self.assertEqual(outputs.data.shape, x.shape) + self.assertEqual(outputs.data.dtype, x.dtype) + + @parameterized.product( + batch_size=[2, 4], + input_dim=[1024, 2048], + seq_len=[1024, 2048], + state_dim=[128, 256], + num_heads=[2, 4], + num_groups=[2, 4], + dtype=[jnp.float32, jnp.bfloat16], + ) + def extend_step( + self, + batch_size: int, + input_dim: int, + seq_len: int, + state_dim: int, + num_heads: int, + num_groups: int, + dtype: jnp.dtype, + ): + cfg = JambaMamba2Block.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + cfg.feed_forward = cfg.feed_forward.set(hidden_dim=2 * input_dim) + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + inputs_data = jax.random.normal( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(data=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + init_state = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + self.assertEqual(init_state["mamba_block"].x_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].b_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].c_conv_state.dtype, dtype) + self.assertEqual(init_state["mamba_block"].ssd_state.dtype, dtype) + + inputs = dict(cached_states=init_state) + decoder_output = jnp.zeros(shape=[seq_len, batch_size, input_dim]) + for t in range(seq_len): + inputs["data"] = inputs_data[:, t : t + 1, :] + extend_step_output, _ = F( + layer, + state=layer_params, 
+ is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cached_states"] = extend_step_output[0] + decoder_output = decoder_output.at[t].set( + jnp.squeeze(extend_step_output[1].data, axis=1) + ) + + decoder_output_transposed = jnp.transpose(decoder_output, [1, 0, 2]) + assert_allclose(decoder_output_transposed, forward_outputs.data, atol=1e-1, rtol=1e-1) + + @parameterized.product( + batch_size=[2], + input_dim=[1024], + state_dim=[256], + num_heads=[2], + seq_len=[1024], + num_groups=[2], + dtype=[jnp.float32, jnp.bfloat16], + ) + def test_prefill_states( + self, + batch_size: int, + input_dim: int, + seq_len: int, + state_dim: int, + num_heads: int, + num_groups: int, + dtype: jnp.dtype, + ): + cfg = JambaMamba2Block.default_config().set( + input_dim=input_dim, + state_dim=state_dim, + num_heads=num_heads, + num_groups=num_groups, + dtype=dtype, + ) + cfg.feed_forward = cfg.feed_forward.set(hidden_dim=2 * input_dim) + layer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + + inputs_data = jax.random.normal( + jax.random.PRNGKey(1), [batch_size, seq_len, input_dim], dtype=dtype + ) + inputs = dict(data=inputs_data) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + + time_step = jnp.arange(batch_size) + (initial_state, initial_output), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=dict(time_step=time_step, data=inputs_data), + method="prefill_states", + ) + + time_step_mask = (jnp.arange(seq_len) < time_step[:, None]).astype(dtype) + decoder_output = initial_output.data * time_step_mask[..., None] + + inputs = dict(cached_states=initial_state) + for _ in range(seq_len): + inputs["data"] = jnp.take_along_axis( + inputs_data, time_step[:, None, None], axis=1, mode="clip" + ) + (updated_state, outputs), _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(3), + inputs=inputs, + method="extend_step", + ) + inputs["cached_states"] = updated_state + + # [batch_size, 1, output_dim] + cur_outputs = outputs.data + + # [batch_size, seq_len, 1] + oh_indices = jax.nn.one_hot(time_step, seq_len, dtype=dtype)[..., None] + decoder_output = decoder_output + cur_outputs * oh_indices + + time_step = time_step + 1 + + assert_allclose(decoder_output, forward_outputs.data, atol=1e-1, rtol=1e-1) + + +@pytest.mark.skipif( + jax.default_backend() != "gpu" or not MAMBA_INSTALLED, + reason="Test requires mamba_ssm to be installed on a GPU machine", +) +class GPUMamba2MixerLayerTest(TestCase): + @classmethod + def setup_class(cls): + num_devices = jax.device_count() + devices = mesh_utils.create_device_mesh((1, 1, 1, 1, num_devices)) + global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) + new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + thread_resources.env = new_env + + @classmethod + def teardown_class(cls): + init_env = ResourceEnv(physical_mesh=(), loops=()) + thread_resources.env = init_env + + @parameterized.product( + batch_size=[2, 4], + seq_len=[512, 1024], + expansion_factor=[1, 2], + ) + def test_forward(self, batch_size: int, seq_len: int, expansion_factor: int): + if self.mamba_ssm is None: + self.skipTest("mamba_ssm needs to be installed on a GPU machine for testing") + + d_model, 
d_state, expansion_factor = 512, 128, 2 + head_dim, num_groups = 128, 4 + d_inner = expansion_factor * d_model + num_heads = d_inner // head_dim + + def _j2t(param): + """Convert jax array to torch tensor.""" + return torch.from_numpy(np.array(param)) + + inputs_data = jax.random.normal(jax.random.PRNGKey(1), [batch_size, seq_len, d_model]) + inputs_torch = _j2t(inputs_data) + + # pylint: disable=undefined-variable + ref_model = Mamba2Simple( + d_model=d_model, + d_state=d_state, + headdim=head_dim, + ngroups=num_groups, + expand=expansion_factor, + use_mem_eff_path=False, + ) + + jax_model = ( + Mamba2MixerLayer.default_config() + .set( + input_dim=d_model, + state_dim=d_state, + num_groups=num_groups, + num_heads=num_heads, + expansion_factor=expansion_factor, + bc_norm=None, + dtype=jnp.float32, + cache_dtype=jnp.float32, + ) + .set(name="test") + .instantiate(parent=None) + ) + jax_params = jax_model.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + jax_params = cast_floats(jax_params, to_dtype=jnp.float32) + + # use linearscan kernel which is already tested against pallas kernel. + jax_model.recurrence = jax_model.inference_recurrence + + # copying the weights from the jax model to the ref model + inputs = dict(query=inputs_data) + forward_outputs, _ = F( + jax_model, + state=jax_params, + is_training=True, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + ) + jax_output_np = np.array(forward_outputs.data) + + # in_proj <-> [z, x, B, C, dt] + xz_w = _j2t(jax_params["xz_proj"]["weight"]) # [d_model, 2, d_inner] + bc_w = _j2t(jax_params["bc_proj"]["weight"]) # [d_model, 2, dk] + dt_w = _j2t(jax_params["dt_proj"]["weight"]) # [d_model, num_heads] + zxBCdt_w = torch.cat([xz_w[:, 1], xz_w[:, 0], bc_w[:, 0], bc_w[:, 1], dt_w], dim=1) + ref_model.in_proj.weight.data.copy_(zxBCdt_w.T) + + # conv1d <-> [x_conv, b_conv, c_conv] + x_conv_w = _j2t(jax_params["x_conv"]["weight"]) + x_conv_bias = _j2t(jax_params["x_conv"]["bias"]) + b_conv_w = _j2t(jax_params["b_conv"]["weight"]) + b_conv_bias = _j2t(jax_params["b_conv"]["bias"]) + c_conv_w = _j2t(jax_params["c_conv"]["weight"]) + c_conv_bias = _j2t(jax_params["c_conv"]["bias"]) + xbc_conv_w = torch.cat([x_conv_w, b_conv_w, c_conv_w], dim=2) + xbc_conv_bias = torch.cat([x_conv_bias, b_conv_bias, c_conv_bias], dim=0) + ref_model.conv1d.weight.data.copy_(xbc_conv_w.T) + ref_model.conv1d.bias.data.copy_(xbc_conv_bias) + + # out_proj <-> out_proj + out_w = _j2t(jax_params["out_proj"]["weight"]) + ref_model.out_proj.weight.data.copy_(out_w.T) + + # A_log <-> llog_a + a_w = _j2t(jax_params["llog_a"]) # [1, num_heads, 1] + ref_model.A_log.data.copy_(a_w[0, :, 0]) + + # dt_bias <-> dt_bias + dt_bias = _j2t(jax_params["dt_bias"]) + ref_model.dt_bias.data.copy_(dt_bias) + + # D <-> d + d = _j2t(jax_params["d"]) # [1, 1, num_heads, 1] + ref_model.D.data.copy_(d[0, 0, :, 0]) + + # norm <-> pre_out_proj_norm + norm_scale = _j2t(jax_params["pre_out_proj_norm"]["scale"]) + ref_model.norm.weight.data.copy_(norm_scale) + + device = "cuda:0" + ref_model = ref_model.to(device) + inputs_torch = inputs_torch.to(device) + torch_output = ref_model(inputs_torch) + torch_output_np = torch_output.cpu().detach().numpy() + + assert_allclose(torch_output_np, jax_output_np, atol=1e-2, rtol=1e-2) From 0bb4ac305776bc03d74ccbfd0e26f10f8000f7de Mon Sep 17 00:00:00 2001 From: bailin_wang Date: Wed, 4 Dec 2024 14:14:55 +0800 Subject: [PATCH 2/3] merge --- axlearn/common/ssm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/axlearn/common/ssm.py b/axlearn/common/ssm.py index c99db1880..9211da274 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -53,7 +53,8 @@ ssd_linear_scan_w_hidden_states, ssd_linear_scan_w_timestep, ) -from axlearn.common.utils import Nested, Tensor, with_sharding_constraint +from axlearn.common.utils import Nested, Tensor, TensorSpec, with_sharding_constraint + class MambaDtProjInitializer(Initializer): """Initializes the weight and bias of a Linear layer as described in the Mamba paper.""" From 678e3928fc96b214098ffbb4e34296eed005397a Mon Sep 17 00:00:00 2001 From: bailin_wang Date: Thu, 12 Dec 2024 10:43:16 +0800 Subject: [PATCH 3/3] unify init and prefill --- axlearn/common/ssm.py | 72 +++++++++---------- axlearn/common/ssm_kernels/ssd_kernels.py | 5 +- .../common/ssm_kernels/ssd_kernels_test.py | 6 +- axlearn/common/ssm_test.py | 33 +++++---- 4 files changed, 57 insertions(+), 59 deletions(-) diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 9211da274..d9ef14287 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -44,7 +44,8 @@ ) from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class -from axlearn.common.layers import Conv1D, GroupNorm, Linear, MultiLinear, NormType, RMSNorm +from axlearn.common.convolution import Conv1D +from axlearn.common.layers import GroupNorm, Linear, MultiLinear, NormType, RMSNorm from axlearn.common.module import Module from axlearn.common.param_init import FanAxes, Initializer, Shape, constant_initializer, uniform from axlearn.common.ssm_kernels.mamba_kernels import compute_mamba_scan @@ -78,7 +79,8 @@ class Config(Initializer.Config): # Clamp dt projection's bias to at least this value. dt_init_floor: float = 1e-4 # One of 'random' or 'constant'. - # If 'constant', the projection matrix is initialized to a constant; otherwise, random. # pylint: disable=C0301 + # If 'constant', the projection matrix is initialized to a constant; + # otherwise, use random initialization. mode: str = "random" def initialize( @@ -1620,7 +1622,7 @@ def forward( Unlike the Mamba recurrence, discretizations of parameters are not explicitly computed. More specifically, \bar a (i.e., discretized a) is computed outside the kernel whereas - \bar b is computed implicitly via adding the delta term to the input + \bar b is computed implicitly via multiplying the delta term to the input x -- \bar x = x * delta. See the following line from the official repo for details - https://github.com/state-spaces/mamba/blob/8ffd905c91d207f5c0cc84fc2a2fb748655094f0/mamba_ssm/modules/ssd_minimal.py#L103 @@ -2077,7 +2079,7 @@ def _forward_for_mode( Args: mode: {FORWARD, INIT_STATES, EXTEND_STEP} query: A Tensor of shape [batch_size, seq_len, input_dim] - cache: Optional NestedTensor as produced by `prefill_states`. + cache: Optional NestedTensor as produced by `init_states`. Returns: An optional cache, depending on `mode`. 
@@ -2093,7 +2095,7 @@ def _forward_for_mode( ) elif mode == ForwardMode.INIT_STATES: assert cache is not None - mamba_cache, mamba_output = self.prefill_states( + mamba_cache, mamba_output = self.init_states( time_step=cache, query=query, ) @@ -2137,49 +2139,19 @@ def _full_sequence_forward( mamba_output = Mamba2MixerLayer.Mamba2Output(data=output, ssd_state=ssd_state) return mamba_cache, mamba_output - # pylint: disable=unused-argument - def init_states(self, *, target_batch_size: int, target_max_len: int) -> Mamba2Cache: - """Initializes cache for autoregressive cached decoding. - - Args: - batch_size: The batch size of the target to be decoded. - target_max_len: The maximum length of the target to be decoded. - - Returns: - A Mamba2Cache instance. - """ - cfg = self.config - dtype = cfg.cache_dtype or cfg.dtype - cache = Mamba2MixerLayer.Mamba2Cache( - x_conv_state=jnp.zeros( - (target_batch_size, cfg.x_conv.window, self.inner_dim), dtype=dtype - ), - b_conv_state=jnp.zeros( - (target_batch_size, cfg.b_conv.window, self.bc_state_dim), dtype=dtype - ), - c_conv_state=jnp.zeros( - (target_batch_size, cfg.c_conv.window, self.bc_state_dim), dtype=dtype - ), - ssd_state=jnp.zeros( - (target_batch_size, cfg.num_heads, cfg.state_dim, self.head_dim), dtype=dtype - ), - time_step=jnp.zeros(target_batch_size, dtype=jnp.int32), - ) - return cache - - def prefill_states( + def init_states( self, *, time_step: Tensor, - query: Tensor, - ) -> tuple[Mamba2Cache, Mamba2Output]: + query: Union[Tensor, TensorSpec], + ) -> tuple[Mamba2Cache, Union[None, Mamba2Output]]: """Initializes cache for autoregressive cached decoding. It refines the mamba state returned from `_full_sequence_forward` to the state at `time_step` for the incremental decoding later. Args: - time_step: A Tensor of shape [batch_size]. Each value is an index into the length - dimension indicating where decoding will start from. + time_step: An optional Tensor of shape [batch_size]. Each value is an index into the + length dimension indicating where decoding will start from. query: Tensor of shape [batch_size, target_length, target_dim] corresponding to input vector up to `time_step` indices. For batch index `i`, only `inputs[i, :time_step[i], ...]` will affect subsequent decoding. 
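A minimal usage sketch of the unified `init_states` (not part of the patch), assuming `layer`,
`layer_params`, `batch_size`, `seq_len`, `time_step`, and `inputs_data` as defined in the updated
tests below:

    # Pure initialization: no inputs are consumed; an all-zero Mamba2Cache is returned.
    cache, _ = layer.init_states(time_step=None, query=TensorSpec([batch_size, seq_len]))

    # Prefill: run the layer on `inputs_data` and refine the cache to the state at `time_step`.
    (cache, output), _ = F(
        layer,
        state=layer_params,
        is_training=False,
        prng_key=jax.random.PRNGKey(0),
        inputs=dict(time_step=time_step, query=inputs_data),
        method="init_states",
    )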
@@ -2191,6 +2163,26 @@ def prefill_states( cfg = self.config cache_dtype = cfg.cache_dtype or cfg.dtype + if time_step is None: + target_batch_size = query.shape[0] + init_state = Mamba2MixerLayer.Mamba2Cache( + x_conv_state=jnp.zeros( + (target_batch_size, cfg.x_conv.window, self.inner_dim), dtype=cache_dtype + ), + b_conv_state=jnp.zeros( + (target_batch_size, cfg.b_conv.window, self.bc_state_dim), dtype=cache_dtype + ), + c_conv_state=jnp.zeros( + (target_batch_size, cfg.c_conv.window, self.bc_state_dim), dtype=cache_dtype + ), + ssd_state=jnp.zeros( + (target_batch_size, cfg.num_heads, cfg.state_dim, self.head_dim), + dtype=cache_dtype, + ), + time_step=jnp.zeros(target_batch_size, dtype=jnp.int32), + ) + return init_state, None + x, z = self._project_input(query) x_conv = jax.nn.silu(self.x_conv(x)) x_conv_w_head = rearrange(x_conv, "b s (h d) -> b h s d", d=self.head_dim) diff --git a/axlearn/common/ssm_kernels/ssd_kernels.py b/axlearn/common/ssm_kernels/ssd_kernels.py index d05df3ee4..e1f8f7d7f 100644 --- a/axlearn/common/ssm_kernels/ssd_kernels.py +++ b/axlearn/common/ssm_kernels/ssd_kernels.py @@ -358,7 +358,6 @@ def _ssd_backward_dkv_chunk_loop_body(t: int, dh_carry_ref: Tensor): k_block = k_ref[subchunk_idx, :].astype(jnp.float32) v_block = v_ref[subchunk_idx, :].astype(jnp.float32) do_block = mutable_do_ref[subchunk_idx, :].astype(jnp.float32) - causal_mask = jnp.tril(jnp.ones((subchunk_size, subchunk_size)), k=0).astype(jnp.float32) lambda_block = cum_log_alpha_ref[subchunk_idx, :] gamma_block = gamma_ref[subchunk_idx] @@ -596,14 +595,14 @@ def ssd_linear_scan( q = repeat(q, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) k = repeat(k, "b ng l dk -> b (ng nhg) l dk", nhg=num_head_per_group) - # ITt's more convenient for vmap to have internal states of size [dv, dk] + # It's more convenient for vmap to have internal states of size [dv, dk] if h0 is None: h0 = jnp.zeros((bs, nh, dv, dk), dtype=jnp.float32) else: # to be consistent with pallas api, h0 is in dk x dv as input h0 = rearrange(h0, "b h dk dv -> b h dv dk") - # All inputs are upcasted to float32, making this function a good reference funciton to + # All inputs are upcasted to float32, making this function a good reference function to # test pallas kernel's numerical precision in the case of bf16 inputs. 
dtype = q.dtype if dtype == jnp.bfloat16: diff --git a/axlearn/common/ssm_kernels/ssd_kernels_test.py b/axlearn/common/ssm_kernels/ssd_kernels_test.py index 2cc5c05ad..8a7a84366 100644 --- a/axlearn/common/ssm_kernels/ssd_kernels_test.py +++ b/axlearn/common/ssm_kernels/ssd_kernels_test.py @@ -345,8 +345,10 @@ def ssd_forward_and_backward(self, batch_size, num_heads, seq_len, dk, dv, dtype class ShardSSDPallasKernelTest(TestCase): - # this test only works for four devices - @pytest.mark.skipif(jax.device_count() != 4, reason="Requires 4 devices") + # This test only works for four devices + @pytest.mark.skipif( + jax.default_backend() != "gpu" or jax.device_count() != 4, reason="Requires 4 GPU devices" + ) def test_sharded_ssd_wo_sp(self): batch, ngroups, nheads, seqlen, k_head_dim, v_head_dim = 8, 4, 4, 1024, 256, 128 dtype = "float32" diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index a8e34af27..74c014666 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -15,7 +15,6 @@ """Tests Mamba/Mamba2 and Jamba implementations.""" - import math from typing import Optional @@ -31,7 +30,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig -from axlearn.common.attention import make_causal_biases +from axlearn.common.attention_bias import make_causal_biases from axlearn.common.config import InstantiableConfig from axlearn.common.module import functional as F from axlearn.common.ssm import ( @@ -969,12 +968,12 @@ class Mamba2RecurrenceTest(TestCase): def setup_class(cls): devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) - new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + new_env = ResourceEnv(physical_mesh=global_mesh) thread_resources.env = new_env @classmethod def teardown_class(cls): - init_env = ResourceEnv(physical_mesh=(), loops=()) + init_env = ResourceEnv(physical_mesh=()) thread_resources.env = init_env def test_ssd_parameterization(self): @@ -1030,12 +1029,12 @@ class Mamba2MixerLayerTest(TestCase): def setup_class(cls): devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) - new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + new_env = ResourceEnv(physical_mesh=global_mesh) thread_resources.env = new_env @classmethod def teardown_class(cls): - init_env = ResourceEnv(physical_mesh=(), loops=()) + init_env = ResourceEnv(physical_mesh=()) thread_resources.env = init_env @parameterized.product( @@ -1083,7 +1082,10 @@ def test_extend_step(self, dtype: jnp.dtype, inference_mode: bool): inputs=inputs, ) - mamba2_cache = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + mamba2_cache, _ = layer.init_states( + time_step=None, + query=TensorSpec([batch_size, seq_len]), + ) self.assertEqual(mamba2_cache.x_conv_state.dtype, cache_dtype) self.assertEqual(mamba2_cache.b_conv_state.dtype, cache_dtype) self.assertEqual(mamba2_cache.c_conv_state.dtype, cache_dtype) @@ -1170,7 +1172,7 @@ def test_prefill_states(self, dtype: jnp.dtype): is_training=False, prng_key=jax.random.PRNGKey(3), inputs=dict(time_step=time_step, query=inputs_data), - method="prefill_states", + method="init_states", ) self.assertTrue(initial_state.x_conv_state.dtype, cache_dtype) self.assertTrue(initial_state.b_conv_state.dtype, cache_dtype) @@ -1217,12 +1219,12 @@ class JambaMamba2BlockTest(TestCase): def 
setup_class(cls): devices = mesh_utils.create_device_mesh((2, 1, 1, 1, 2)) global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) - new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + new_env = ResourceEnv(physical_mesh=global_mesh) thread_resources.env = new_env @classmethod def teardown_class(cls): - init_env = ResourceEnv(physical_mesh=(), loops=()) + init_env = ResourceEnv(physical_mesh=()) thread_resources.env = init_env @parameterized.product( @@ -1306,7 +1308,10 @@ def extend_step( inputs=inputs, ) - init_state = layer.init_states(target_batch_size=batch_size, target_max_len=seq_len) + init_state = layer.init_states( + time_step=None, + data=TensorSpec([batch_size, seq_len]), + ) self.assertEqual(init_state["mamba_block"].x_conv_state.dtype, dtype) self.assertEqual(init_state["mamba_block"].b_conv_state.dtype, dtype) self.assertEqual(init_state["mamba_block"].c_conv_state.dtype, dtype) @@ -1382,7 +1387,7 @@ def test_prefill_states( is_training=False, prng_key=jax.random.PRNGKey(3), inputs=dict(time_step=time_step, data=inputs_data), - method="prefill_states", + method="init_states", ) time_step_mask = (jnp.arange(seq_len) < time_step[:, None]).astype(dtype) @@ -1425,12 +1430,12 @@ def setup_class(cls): num_devices = jax.device_count() devices = mesh_utils.create_device_mesh((1, 1, 1, 1, num_devices)) global_mesh = Mesh(devices, axis_names=("data", "expert", "fsdp", "seq", "model")) - new_env = ResourceEnv(physical_mesh=global_mesh, loops=()) + new_env = ResourceEnv(physical_mesh=global_mesh) thread_resources.env = new_env @classmethod def teardown_class(cls): - init_env = ResourceEnv(physical_mesh=(), loops=()) + init_env = ResourceEnv(physical_mesh=()) thread_resources.env = init_env @parameterized.product(