diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b5fa83b437ac4..1cb261d1773e9 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -87,6 +87,11 @@ Text Generation - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - ✅︎ + * - :code:`FalconMambaForCausalLM` + - FalconMamba + - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. + - ✅︎ + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index c27bf6a60a4f4..2dc231c595ffa 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -10,7 +10,7 @@ from ...utils import check_outputs_equal -MODELS = ["state-spaces/mamba-130m-hf"] +MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"] # Use lower-level interfaces to create this greedy generator, as mamba will diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 10fae84dab723..30b43f375dd5c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -27,7 +27,6 @@ def __init__( self.variance_epsilon = eps self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) - self.weight = nn.Parameter(torch.ones(hidden_size)) def forward_native( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 7f2efb9895f25..9f4f391a6682e 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -22,7 +22,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, @@ -59,7 +59,7 @@ def __init__(self, config: MambaConfig, layer_idx): self.conv_kernel_size = config.conv_kernel self.intermediate_size = config.intermediate_size self.time_step_rank = int(config.time_step_rank) - + self.is_falcon_mamba = config.model_type == "falcon_mamba" self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, @@ -109,6 +109,13 @@ def __init__(self, config: MambaConfig, layer_idx): input_is_parallel=True, ) self.activation = config.hidden_act + if self.is_falcon_mamba: + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.mixer_rms_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.mixer_rms_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.mixer_rms_eps) def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, @@ -158,8 +165,12 @@ def forward(self, hidden_states: torch.Tensor, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1, ) - - # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + # Note that Jamba and FalconMamba normalize B, C, and time_step here + # but Mamba doesn't. 
+ if self.is_falcon_mamba: + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -213,11 +224,9 @@ def __init__(self, super().__init__() self.layer_idx = layer_idx self.config = config + self.is_falcon_mamba = config.model_type == "falcon_mamba" self.mixer = MambaMixer(config, layer_idx) - self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) def forward( self, @@ -319,8 +328,18 @@ def __init__( self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - self.lm_head = self.backbone.embeddings + if config.tie_word_embeddings: + self.lm_head = self.backbone.embeddings + else: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) # Used to track and store by the Mamba cache between steps. self.mamba_cache: Optional[MambaCacheManager] = None @@ -398,7 +417,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") - # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f442ce0f63e3e..2a04ece24c8bd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,6 +53,7 @@ # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),