From 465dee3be9d09f4abc0b8290950e7b9b81d3e1ed Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Sun, 13 Oct 2024 11:59:51 +0000 Subject: [PATCH 1/9] add version 1 of FalconMamba --- vllm/model_executor/layers/layernorm.py | 9 +- vllm/model_executor/models/falcon_mamba.py | 480 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 3 files changed, 487 insertions(+), 3 deletions(-) create mode 100644 vllm/model_executor/models/falcon_mamba.py diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index d55f86056d17c..a994157278244 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -19,6 +19,7 @@ def __init__( hidden_size: int, eps: float = 1e-6, var_hidden_size: Optional[int] = None, + is_learnable: bool = True ) -> None: super().__init__() @@ -26,9 +27,11 @@ def __init__( self.variance_epsilon = eps self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) - - self.weight = nn.Parameter(torch.ones(hidden_size)) - + if is_learnable: + self.register_parameter("weight", nn.Parameter(torch.ones(hidden_size))) + else: + self.register_buffer('weight', torch.ones(hidden_size), persistent = False) + def forward_native( self, x: torch.Tensor, diff --git a/vllm/model_executor/models/falcon_mamba.py b/vllm/model_executor/models/falcon_mamba.py new file mode 100644 index 0000000000000..cd3c50f3ad41f --- /dev/null +++ b/vllm/model_executor/models/falcon_mamba.py @@ -0,0 +1,480 @@ +# coding=utf-8 +"""PyTorch FalconMAMBA model.""" +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import FalconMambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree) +from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class FalconMambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + +# Adapted from 
transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaMixer +class FalconMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: FalconMambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.rms_eps = config.mixer_rms_eps + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=config.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=config.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=config.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=self.rms_eps, + is_learnable=False) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=self.rms_eps, + is_learnable=False) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=self.rms_eps, + is_learnable=False) + + def forward(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, + ssm_state: torch.Tensor): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + ) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class FalconMambaDecoderLayer(nn.Module): + + def __init__(self, + config: FalconMambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = FalconMambaMixer(config, layer_idx) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + return hidden_states, residual + + +class FalconMambaModel(nn.Module): + + def __init__( + self, + config: FalconMambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + FalconMambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class FalconMambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: FalconMambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + 
super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.backbone = FalconMambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, self.config.num_hidden_layers, + max_batch_size, *self._get_mamba_cache_shape()) + + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel - 1, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." 
in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8caaab9974666..2b717258bf83a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,6 +53,7 @@ # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconMambaForCausalLM": ("falcon_mamba", "FalconMambaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), From e1a1a026143c46ab6baf81928d6640b7600a46ec Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Sun, 13 Oct 2024 12:04:33 +0000 Subject: [PATCH 2/9] code formatting --- vllm/model_executor/layers/layernorm.py | 21 +++++++++-------- vllm/model_executor/models/falcon_mamba.py | 27 +++++++++++----------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a994157278244..1306b2aa381fa 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -14,13 +14,11 @@ class RMSNorm(CustomOp): Refer to https://arxiv.org/abs/1910.07467 """ - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - is_learnable: bool = True - ) -> None: + def __init__(self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + is_learnable: bool = True) -> None: super().__init__() self.hidden_size = hidden_size @@ -28,10 +26,13 @@ def __init__( self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) if is_learnable: - self.register_parameter("weight", nn.Parameter(torch.ones(hidden_size))) + self.register_parameter("weight", + nn.Parameter(torch.ones(hidden_size))) else: - self.register_buffer('weight', torch.ones(hidden_size), persistent = False) - + self.register_buffer('weight', + torch.ones(hidden_size), + persistent=False) + def forward_native( self, x: torch.Tensor, diff --git a/vllm/model_executor/models/falcon_mamba.py b/vllm/model_executor/models/falcon_mamba.py index cd3c50f3ad41f..19877dd23f6e0 100644 --- a/vllm/model_executor/models/falcon_mamba.py +++ b/vllm/model_executor/models/falcon_mamba.py @@ -44,6 +44,7 @@ class FalconMambaCacheParams: conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() + # Adapted from transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaMixer class FalconMambaMixer(nn.Module): """ @@ -118,14 +119,14 @@ def __init__(self, config: FalconMambaConfig, layer_idx): self.activation = config.hidden_act self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=self.rms_eps, - 
is_learnable=False) + eps=self.rms_eps, + is_learnable=False) self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=self.rms_eps, - is_learnable=False) + eps=self.rms_eps, + is_learnable=False) self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=self.rms_eps, - is_learnable=False) + eps=self.rms_eps, + is_learnable=False) def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, conv_state: torch.Tensor, @@ -178,7 +179,7 @@ def forward(self, hidden_states: torch.Tensor, time_step = self.dt_layernorm(time_step.contiguous()) B = self.b_layernorm(B.contiguous()) C = self.c_layernorm(C.contiguous()) - + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = (self.dt_proj.bias.float() if hasattr( @@ -281,9 +282,9 @@ def __init__( for i in range(config.num_hidden_layers): decoder_layers.append( FalconMambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) self.layers = nn.ModuleList(decoder_layers) self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -355,9 +356,9 @@ def __init__( self.config = config self.scheduler_config = scheduler_config self.backbone = FalconMambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size From 41932049f56db965c35db1db57712296012e40f9 Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Sun, 13 Oct 2024 12:06:13 +0000 Subject: [PATCH 3/9] fix mypy --- vllm/model_executor/models/falcon_mamba.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/falcon_mamba.py b/vllm/model_executor/models/falcon_mamba.py index 19877dd23f6e0..74bc527b70621 100644 --- a/vllm/model_executor/models/falcon_mamba.py +++ b/vllm/model_executor/models/falcon_mamba.py @@ -14,8 +14,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -24,6 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, @@ -45,7 +45,8 @@ class FalconMambaCacheParams: ssm_state: torch.Tensor = torch.Tensor() -# Adapted from transformers.models.falcon_mamba.modeling_falcon_mamba.FalconMambaMixer +# Adapted from transformers.models.falcon_mamba. 
+# modeling_falcon_mamba.FalconMambaMixer class FalconMambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute From 402758b6a7ea82c479f8fafa9ea9b02c22662a20 Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Sun, 13 Oct 2024 12:22:02 +0000 Subject: [PATCH 4/9] add FalconMamba in supported_models --- docs/source/models/supported_models.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bf86a72e20b57..cbb7f10bc5f60 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -11,7 +11,7 @@ Text-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^ Text Generation ---------------- +--------------- .. list-table:: :widths: 25 25 50 5 5 @@ -87,6 +87,11 @@ Text Generation - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - ✅︎ + * - :code:`FalconMambaForCausalLM` + - FalconMamba + - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. + - ✅︎ + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -156,7 +161,7 @@ Text Generation - Mamba - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - ✅︎ - - + - * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. From a80adf53b283e11d7f16fefff733f1a7d18c804f Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Sun, 13 Oct 2024 12:58:45 +0000 Subject: [PATCH 5/9] add test file --- .../language/test_falcon_mamba.py | 296 ++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 tests/models/decoder_only/language/test_falcon_mamba.py diff --git a/tests/models/decoder_only/language/test_falcon_mamba.py b/tests/models/decoder_only/language/test_falcon_mamba.py new file mode 100644 index 0000000000000..30aab0a7079a1 --- /dev/null +++ b/tests/models/decoder_only/language/test_falcon_mamba.py @@ -0,0 +1,296 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +Run `pytest tests/models/decoder_only/language/test_falcon_mamba.py`. +""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vllm.sampling_params import SamplingParams +from vllm.worker.model_runner import _get_graph_batch_size + +from ...utils import check_outputs_equal + +MODELS = ["tiiuae/falcon-mamba-tiny-dev"] + + +# Use lower-level interfaces to create this greedy generator, as Falconmamba +# will choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy +# is used. 
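+# The helper below mirrors the hf_runner greedy interface by returning a
+# (token_ids, generated_text) tuple per prompt, so its outputs can still be
+# compared one-to-one against the vLLM runner further down, e.g.:
+#
+#     hf_outputs = generate_greedy(model, example_prompts, max_tokens)
+#     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)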
+def generate_greedy(model_name, example_prompts, max_tokens): + # Create a text generation pipeline + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"].to(model.device) + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_batching( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + for_loop_outputs = [] + with vllm_runner(model, dtype=dtype) as vllm_model: + for prompt in example_prompts: + for_loop_outputs.append( + vllm_model.generate_greedy([prompt], max_tokens)[0]) + + batched_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=for_loop_outputs, + outputs_1_lst=batched_outputs, + name_0="for_loop_vllm", + name_1="batched_vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int) -> None: + # Tests chunked prefill in conjunction with n>1. In this case, prefill is + # populated with decoding tokens and we test that it doesn't fail. 
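+    # With the runner settings used below (enable_chunked_prefill=True,
+    # max_num_batched_tokens=30, max_num_seqs=10), a single scheduler step
+    # can mix partially prefilled sequences with n>1 decode sequences.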
+ # This test might fail if cache is not allocated correctly for n > 1 + # decoding steps inside a chunked prefill forward pass (where we have both + # prefill and decode together ) + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=30, + max_num_seqs=10 # forces prefill chunks with decoding + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int, + chunked_prefill_token_size: int) -> None: + """ + Checks exact match decode between huggingface model and vllm runner with + chunked prefill. + """ + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size + + non_chunked = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy(example_prompts, + max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_parallel_sampling( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype) as vllm_model: + for_loop_outputs = [] + for _ in range(10): + for_loop_outputs.append( + # using example_prompts index 1 instead of 0 since with 0 the + # logprobs get really close and the test doesn't pass + vllm_model.generate_greedy([example_prompts[1]], max_tokens) + [0]) + sampling_params = SamplingParams(n=10, + temperature=0.001, + seed=0, + max_tokens=max_tokens) + n_lt_1_outputs = vllm_model.generate([example_prompts[1]], + sampling_params) + token_ids, texts = n_lt_1_outputs[0] + n_lt_1_outputs = [(token_id, text) + for token_id, text in zip(token_ids, texts)] + + check_outputs_equal( + outputs_0_lst=n_lt_1_outputs, + outputs_1_lst=for_loop_outputs, + name_0="vllm_n_lt_1_outputs", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_falcon_mamba_cache_cg_padding( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is for verifying that mamba cache is padded to CG captured + # batch size. If it's not, a torch RuntimeError will be raised because + # tensor dimensions aren't compatible + while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + example_prompts.append(example_prompts[0]) + + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + except RuntimeError: + pytest.fail( + "Couldn't run batch size which is not equal to a Cuda Graph " + "captured batch size. 
" + "Could be related to mamba cache not padded correctly") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [20]) +def test_models_preemption_recompute( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # Tests that outputs are identical with and w/o preemtions (recompute) + assert dtype == "float" + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = True + preempt_vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + + vllm_model.model.llm_engine.scheduler[ + 0].ENABLE_ARTIFICIAL_PREEMPT = False + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=preempt_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="vllm_preepmtions", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba inner state management doesn't + # collapse in case where the number of incoming requests and + # finished_requests_ids is larger than the maximum Mamba block capacity. + # This could generally happen due to the fact that Mamba does support + # statelessness mechanism where it can cleanup new incoming requests in + # a single step. + try: + with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 100, 10) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily ") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_state_cleanup( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is for verifying that the Mamba state is cleaned up between + # steps, If its not cleaned, an error would be expected. + try: + with vllm_runner(model, dtype=dtype) as vllm_model: + for _ in range(10): + vllm_model.generate_greedy([example_prompts[0]] * 100, 1) + except ValueError: + pytest.fail("Mamba inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) From f66774a2f7aa9137f888fa9747bd4ab02779d162 Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Wed, 16 Oct 2024 06:09:58 +0000 Subject: [PATCH 6/9] adress comments --- docs/source/models/supported_models.rst | 4 +- .../language/test_falcon_mamba.py | 296 ------------------ .../decoder_only/language/test_mamba.py | 2 +- vllm/model_executor/models/falcon_mamba.py | 59 +--- 4 files changed, 11 insertions(+), 350 deletions(-) delete mode 100644 tests/models/decoder_only/language/test_falcon_mamba.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index cbb7f10bc5f60..586dafc5e027a 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -11,7 +11,7 @@ Text-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^ Text Generation ---------------- +--------------- .. list-table:: :widths: 25 25 50 5 5 @@ -161,7 +161,7 @@ Text Generation - Mamba - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - ✅︎ - - + - * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. diff --git a/tests/models/decoder_only/language/test_falcon_mamba.py b/tests/models/decoder_only/language/test_falcon_mamba.py deleted file mode 100644 index 30aab0a7079a1..0000000000000 --- a/tests/models/decoder_only/language/test_falcon_mamba.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. - -Run `pytest tests/models/decoder_only/language/test_falcon_mamba.py`. -""" -import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer - -from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size - -from ...utils import check_outputs_equal - -MODELS = ["tiiuae/falcon-mamba-tiny-dev"] - - -# Use lower-level interfaces to create this greedy generator, as Falconmamba -# will choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy -# is used. 
-def generate_greedy(model_name, example_prompts, max_tokens): - # Create a text generation pipeline - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) - - # Generate texts from the prompts - outputs = [] - for prompt in example_prompts: - # Tokenize the input prompt with truncation - inputs = tokenizer(prompt, return_tensors="pt", truncation=True) - input_ids = inputs["input_ids"].to(model.device) - - # Generate text using the model's generate method directly - generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) - generated_text = tokenizer.decode(generated_ids[0], - skip_special_tokens=True) - - outputs.append((generated_ids[0].tolist(), generated_text)) - - return outputs - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - hf_outputs = generate_greedy(model, example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_batching( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # To pass the small model tests, we need full precision. - for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: - for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=for_loop_outputs, - outputs_1_lst=batched_outputs, - name_0="for_loop_vllm", - name_1="batched_vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # Tests chunked prefill in conjunction with n>1. In this case, prefill is - # populated with decoding tokens and we test that it doesn't fail. 
- # This test might fail if cache is not allocated correctly for n > 1 - # decoding steps inside a chunked prefill forward pass (where we have both - # prefill and decode together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, - chunked_prefill_token_size: int) -> None: - """ - Checks exact match decode between huggingface model and vllm runner with - chunked prefill. - """ - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - non_chunked = generate_greedy(model, example_prompts, max_tokens) - - with vllm_runner(model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) - - check_outputs_equal( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - # using example_prompts index 1 instead of 0 since with 0 the - # logprobs get really close and the test doesn't pass - vllm_model.generate_greedy([example_prompts[1]], max_tokens) - [0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate([example_prompts[1]], - sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_falcon_mamba_cache_cg_padding( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): - example_prompts.append(example_prompts[0]) - - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) - except RuntimeError: - pytest.fail( - "Couldn't run batch size which is not equal to a Cuda Graph " - "captured batch size. 
" - "Could be related to mamba cache not padded correctly") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum Mamba block capacity. - # This could generally happen due to the fact that Mamba does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. - try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 100, 10) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_state_cleanup( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba state is cleaned up between - # steps, If its not cleaned, an error would be expected. - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - for _ in range(10): - vllm_model.generate_greedy([example_prompts[0]] * 100, 1) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. 
- model_runner.model) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index c27bf6a60a4f4..2dc231c595ffa 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -10,7 +10,7 @@ from ...utils import check_outputs_equal -MODELS = ["state-spaces/mamba-130m-hf"] +MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"] # Use lower-level interfaces to create this greedy generator, as mamba will diff --git a/vllm/model_executor/models/falcon_mamba.py b/vllm/model_executor/models/falcon_mamba.py index 74bc527b70621..68d79e658a72c 100644 --- a/vllm/model_executor/models/falcon_mamba.py +++ b/vllm/model_executor/models/falcon_mamba.py @@ -1,6 +1,5 @@ # coding=utf-8 """PyTorch FalconMAMBA model.""" -from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -38,13 +37,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -@dataclass -class FalconMambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - - # Adapted from transformers.models.falcon_mamba. # modeling_falcon_mamba.FalconMambaMixer class FalconMambaMixer(nn.Module): @@ -321,18 +313,10 @@ def forward( class FalconMambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - } + packed_modules_mapping = {} # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", "embed_tokens", "lm_head", ] @@ -440,43 +424,16 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "A_log" in name: name = name.replace("A_log", "A") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue - if ".self_attn." in name: - name = name.replace(".self_attn", "") - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From ff57d441f82c837002425a25b2628256e7ac7d9f Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Fri, 18 Oct 2024 10:26:46 +0000 Subject: [PATCH 7/9] integrate FalconMamba inside mamba.py --- vllm/model_executor/layers/layernorm.py | 19 +- vllm/model_executor/models/falcon_mamba.py | 439 --------------------- vllm/model_executor/models/mamba.py | 57 ++- vllm/model_executor/models/registry.py | 2 +- 4 files changed, 53 insertions(+), 464 deletions(-) delete mode 100644 vllm/model_executor/models/falcon_mamba.py diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 1306b2aa381fa..847a0df8beba4 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -14,24 +14,19 @@ class RMSNorm(CustomOp): Refer to https://arxiv.org/abs/1910.07467 """ - def __init__(self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - is_learnable: bool = True) -> None: + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size self.variance_epsilon = eps self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) - if is_learnable: - self.register_parameter("weight", - nn.Parameter(torch.ones(hidden_size))) - else: - self.register_buffer('weight', - torch.ones(hidden_size), - persistent=False) + self.weight = nn.Parameter(torch.ones(hidden_size)) def forward_native( self, diff --git a/vllm/model_executor/models/falcon_mamba.py b/vllm/model_executor/models/falcon_mamba.py deleted file mode 100644 index 68d79e658a72c..0000000000000 --- a/vllm/model_executor/models/falcon_mamba.py +++ /dev/null @@ -1,439 +0,0 @@ -# coding=utf-8 -"""PyTorch FalconMAMBA model.""" -from typing import Iterable, List, Optional, Tuple - -import torch -from torch import nn -from transformers import FalconMambaConfig - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree) -from vllm.model_executor.models.mamba_cache import 
MambaCacheManager -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -# Adapted from transformers.models.falcon_mamba. -# modeling_falcon_mamba.FalconMambaMixer -class FalconMambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: FalconMambaConfig, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = config.intermediate_size - self.time_step_rank = int(config.time_step_rank) - self.rms_eps = config.mixer_rms_eps - - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=config.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=config.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) - a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) - set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=config.use_bias, - input_is_parallel=True, - ) - self.activation = config.hidden_act - - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=self.rms_eps, - is_learnable=False) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=self.rms_eps, - is_learnable=False) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=self.rms_eps, - is_learnable=False) - - def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - ) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = selective_state_update( - ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - ) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - return contextualized_states - - -class FalconMambaDecoderLayer(nn.Module): - - def __init__(self, - config: FalconMambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.mixer = FalconMambaMixer(config, layer_idx) - - self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.norm(hidden_states) - else: - hidden_states, residual = self.norm(hidden_states, residual) - - hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, - ssm_state) - return hidden_states, residual - - -class FalconMambaModel(nn.Module): - - def __init__( - self, - config: FalconMambaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embeddings = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - decoder_layers = [] - for i in range(config.num_hidden_layers): - decoder_layers.append( - FalconMambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) - self.layers = nn.ModuleList(decoder_layers) - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) - residual = None - - for i in range(len(self.layers)): - layer = self.layers[i] - current_ssm_state = ssm_state[i] - current_conv_state = conv_state[i] - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) - hidden_states, _ = self.norm_f(hidden_states, residual) - - return hidden_states - - -class FalconMambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): - packed_modules_mapping = {} - - # LoRA specific attributes - supported_lora_modules = [ - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embeddings": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: FalconMambaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None, - ) -> None: - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.backbone = 
FalconMambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) - - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Optional[MambaCacheManager] = None - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = Sampler() - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs): - if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) - self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, self.config.num_hidden_layers, - max_batch_size, *self._get_mamba_cache_shape()) - - mamba_cache_tensors = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) - - hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_tensors[0], - mamba_cache_tensors[1]) - - return hidden_states - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def _get_mamba_cache_shape( - self) -> Tuple[Tuple[int, int], Tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel - 1, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "A_log" in name: - name = name.replace("A_log", "A") - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1112a2181135a..e61797b1047ff 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -24,7 +24,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, @@ -67,7 +67,7 @@ def __init__(self, config: MambaConfig, layer_idx): self.conv_kernel_size = config.conv_kernel self.intermediate_size = config.intermediate_size self.time_step_rank = int(config.time_step_rank) - + self.is_falcon_mamba = config.model_type == "falcon_mamba" self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, @@ -117,6 +117,13 @@ def __init__(self, config: MambaConfig, layer_idx): input_is_parallel=True, ) self.activation = config.hidden_act + if self.is_falcon_mamba: + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.mixer_rms_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.mixer_rms_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.mixer_rms_eps) def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, conv_state: torch.Tensor, @@ -165,8 +172,12 @@ def forward(self, hidden_states: torch.Tensor, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1, ) - - # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + # Note that Jamba and FalconMamba normalize B, C, and time_step here + # but Mamba doesn't.
+ if self.is_falcon_mamba: + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -250,9 +261,11 @@ def __init__(self, super().__init__() self.layer_idx = layer_idx self.config = config + self.is_falcon_mamba = config.model_type == "falcon_mamba" self.mixer = MambaMixer(config, layer_idx) - self.feed_forward = MambaMLP(config, quant_config=quant_config) + if not self.is_falcon_mamba: + self.feed_forward = MambaMLP(config, quant_config=quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -274,10 +287,11 @@ def forward( hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, ssm_state) - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) + if not self.is_falcon_mamba: + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -388,8 +402,18 @@ def __init__( self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - self.lm_head = self.backbone.embeddings + if config.tie_word_embeddings: + self.lm_head = self.backbone.embeddings + else: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) # Used to track and store by the Mamba cache between steps. self.mamba_cache: Optional[MambaCacheManager] = None @@ -456,6 +480,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def non_learnable_rms(weights: torch.Tensor): + """ + Args: + weights (torch.Tensor): set RMSNorm weights to a non learnable + torch.ones Tensor. + """ + return torch.ones(weights.shape[0]) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -476,7 +508,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ".self_attn." 
in name: name = name.replace(".self_attn", "") - + # if name in ("dt_layernorm", "b_layernorm", "c_layernorm"): + # loaded_weight = self.non_learnable_rms(loaded_weight) for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2b717258bf83a..72e08c0f6eb2a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -53,7 +53,7 @@ # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), - "FalconMambaForCausalLM": ("falcon_mamba", "FalconMambaForCausalLM"), + "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), From e0e65e9dff247355496774db3eee60ea979ecc82 Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Fri, 18 Oct 2024 11:14:51 +0000 Subject: [PATCH 8/9] fix issues --- vllm/model_executor/models/mamba.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 4a8f22ea92111..a8afbaa623ccf 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -226,11 +226,6 @@ def __init__(self, self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" self.mixer = MambaMixer(config, layer_idx) - - if not self.is_falcon_mamba: - self.feed_forward = MambaMLP(config, quant_config=quant_config) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -247,12 +242,8 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, attn_metadata, mamba_cache_params) - if not self.is_falcon_mamba: - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) + hidden_states = self.mixer(hidden_states, attn_metadata, + mamba_cache_params) return hidden_states, residual From 7738d2df993a7de02b60486527c0ac3c4237d7e5 Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Fri, 18 Oct 2024 11:15:54 +0000 Subject: [PATCH 9/9] remove unneeded method --- vllm/model_executor/models/mamba.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index a8afbaa623ccf..9f4f391a6682e 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -412,14 +412,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def non_learnable_rms(weights: torch.Tensor): - """ - Args: - weights (torch.Tensor): set RMSNorm weights to a non learnable - torch.ones Tensor. - """ - return torch.ones(weights.shape[0]) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights:
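The net effect of the FalconMamba-specific branch added to MambaMixer.forward is easiest to see in isolation. The standalone PyTorch sketch below mirrors that step: the x_proj output is split into dt, B and C, and each slice is RMS-normalized with mixer_rms_eps before dt_proj and the selective scan. It is an illustration only; the helper rms_norm, the tensor sizes and the eps value are stand-ins, not vLLM code.

# Standalone sketch of FalconMamba's dt/B/C normalization (illustrative only).
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm with the scale fixed to ones, i.e. no learnable weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

# Made-up sizes standing in for config.time_step_rank / config.state_size.
time_step_rank, ssm_state_size, num_tokens = 16, 8, 4
# Stand-in for self.x_proj(hidden_states)[0]
ssm_parameters = torch.randn(num_tokens, time_step_rank + 2 * ssm_state_size)

time_step, B, C = torch.split(
    ssm_parameters,
    [time_step_rank, ssm_state_size, ssm_state_size],
    dim=-1,
)

# FalconMamba (like Jamba) normalizes these slices; plain Mamba does not.
mixer_rms_eps = 1e-6
time_step = rms_norm(time_step, eps=mixer_rms_eps)
B = rms_norm(B, eps=mixer_rms_eps)
C = rms_norm(C, eps=mixer_rms_eps)

Because the scale is fixed to ones, FalconMamba checkpoints ship no parameters for these norms, which is presumably why load_weights needs no special handling for dt_layernorm, b_layernorm and c_layernorm.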