Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] FalconMamba Support #9325

Merged
merged 10 commits into from
Oct 21, 2024
Merged
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ Text Generation
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
-
- ✅︎
* - :code:`FalconMambaForCausalLM`
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
Expand Down
2 changes: 1 addition & 1 deletion tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ...utils import check_outputs_equal

MODELS = ["state-spaces/mamba-130m-hf"]
MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]


# Use lower-level interfaces to create this greedy generator, as mamba will
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: best not to introduce whitespace-only changes to files

self.weight = nn.Parameter(torch.ones(hidden_size))

def forward_native(
Expand Down
38 changes: 28 additions & 10 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, config: MambaConfig, layer_idx):
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)

self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
Expand Down Expand Up @@ -109,6 +109,13 @@ def __init__(self, config: MambaConfig, layer_idx):
input_is_parallel=True,
)
self.activation = config.hidden_act
if self.is_falcon_mamba:
self.dt_layernorm = RMSNorm(self.time_step_rank,
eps=config.mixer_rms_eps)
self.b_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.mixer_rms_eps)

def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
Expand Down Expand Up @@ -158,8 +165,12 @@ def forward(self, hidden_states: torch.Tensor,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)

# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
# Note that Jamba and FalconMamba normalize B, C, and time_step here,
# but Mamba doesn't.
if self.is_falcon_mamba:
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())

discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
Expand Down Expand Up @@ -213,11 +224,9 @@ def __init__(self,
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
self.mixer = MambaMixer(config, layer_idx)

self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)

def forward(
self,
Expand Down Expand Up @@ -319,8 +328,18 @@ def __init__(
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

self.lm_head = self.backbone.embeddings
if config.tie_word_embeddings:
self.lm_head = self.backbone.embeddings
else:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)

# Used by the Mamba cache to track and store state between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
Expand Down Expand Up @@ -398,7 +417,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
Expand Down