[Model] Factoring MambaMixer out of Jamba (#8993)
Signed-off-by: mzusman <[email protected]>
Showing 3 changed files with 245 additions and 374 deletions.
@@ -0,0 +1,217 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)
    """

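    # Schematically, the selective SSM computed in forward_cuda follows the
    # Mamba recurrence (using the simplified discretization applied by the
    # scan kernels, with ∆_t passed through softplus):
    #     h_t = exp(∆_t ⊙ A) ⊙ h_{t-1} + (∆_t ⊙ B_t) · x_t
    #     y_t = C_t · h_t + D ⊙ x_t
    # The scan output is then gated (SiLU on the `gate` branch) and mapped
    # back to the hidden size by `out_proj`.
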
    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 time_step_rank: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 use_rms_norm: bool,
                 rms_norm_eps: float = 1e-5,
                 activation="silu"):
        super().__init__()
        self.time_step_rank = time_step_rank
        self.ssm_state_size = ssm_state_size
        self.use_rms_norm = use_rms_norm
        self.activation = activation

        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=intermediate_size,
            bias=use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                  [intermediate_size] * 2,
                                                  bias=use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            intermediate_size,
            time_step_rank + ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(time_step_rank,
                                            intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                intermediate_size // tp_size,
                ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
        )

        self.dt_layernorm = RMSNorm(time_step_rank,
                                    eps=rms_norm_eps) if use_rms_norm else None

        self.b_layernorm = RMSNorm(ssm_state_size,
                                   eps=rms_norm_eps) if use_rms_norm else None

        self.c_layernorm = RMSNorm(ssm_state_size,
                                   eps=rms_norm_eps) if use_rms_norm else None

    def forward_native(self, hidden_states: torch.Tensor,
                       attn_metadata: AttentionMetadata,
                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
        pass

    def forward_cuda(self, hidden_states: torch.Tensor,
                     attn_metadata: AttentionMetadata,
                     mamba_cache_params: MambaCacheParams):

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

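        # Two execution paths: when query_start_loc / context_lens_tensor are
        # populated, whole (possibly resumed) sequences run through the varlen
        # prefill kernels; otherwise this is a single-token decode step that
        # updates the cached conv / ssm state in place.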
        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        if self.use_rms_norm:
            assert self.dt_layernorm is not None
            assert self.b_layernorm is not None
            assert self.c_layernorm is not None
            time_step = self.dt_layernorm(time_step.contiguous())
            B = self.b_layernorm(B.contiguous())
            C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)

        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(
            -2, -1))[0]
        return contextualized_states
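
For orientation, here is a minimal sketch of how a model such as Jamba could construct the factored-out layer. The import path and the config attribute names (mamba_d_state, mamba_d_conv, mamba_expand, mamba_dt_rank, mamba_conv_bias, mamba_proj_bias, hidden_act) are assumptions made for this example, not values taken from this diff.

# Sketch only: the import path and the config field names below are assumed
# for illustration and may differ from the actual Jamba implementation.
from torch import nn

from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer


class MambaDecoderLayer(nn.Module):
    """Hypothetical wrapper showing how MambaMixer is parameterized."""

    def __init__(self, config) -> None:
        super().__init__()
        self.mamba = MambaMixer(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=int(config.mamba_expand * config.hidden_size),
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
        )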