Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] FalconMamba Support #9325

Merged
merged 10 commits into from
Oct 21, 2024
Merged
7 changes: 6 additions & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ Text Generation
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
-
- ✅︎
* - :code:`FalconMambaForCausalLM`
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
Expand Down Expand Up @@ -156,7 +161,7 @@ Text Generation
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
-
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spurious whitespace

* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
Expand Down
2 changes: 1 addition & 1 deletion tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ...utils import check_outputs_equal

MODELS = ["state-spaces/mamba-130m-hf"]
MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]


# Use lower-level interfaces to create this greedy generator, as mamba will
Expand Down
20 changes: 12 additions & 8 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,24 @@ class RMSNorm(CustomOp):
Refer to https://arxiv.org/abs/1910.07467
"""

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
def __init__(self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
is_learnable: bool = True) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RMSNorm weights are non learnable for FalconMamba model.
The idea is to add support for non learnable RMSNorm weights, so we can benefit from the same forward types of this class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this a bit more? It seems like this might have been done to work around some issues that popped up during weight loading. Is that right?

And am I right that the weights will always be 1.0 for Falcon Mamba, i.e. we could skip the application of the weights for dt_layernorm, b_layernorm, and c_layernorm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The idea is to register weights as parameters when they are learnable and register them as buffers whenever they are not so that they will not be included in the state_dict of the model.
    the same logic is applied here https://pytorch.org/docs/stable/_modules/torch/nn/modules/normalization.html#RMSNorm (pytorch implementation of RMSNorm)

2.Yes , you are right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation -- I think it would be better to handle this in FalconMambaForCausalLM.load_weights, since it's a special case that only applies to FalconMamba currently.

In load_weights, could you add a condition to check if dt_layernorm, b_layernorm, or c_layernorm is in the name? If this is the case, we can set the weight loader to a function that explicitly sets all of the elements to 1.0, which will make things explicitly clear.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review

I managed to integrate FalconMamba inside mamba.py.

for rmsnorm , i reveretd the changes , but i think there is no need to handle dt_layernorm, b_layernorm, or c_layernorm inside load_weights since they have been initialised as nn.parameters(torch.ones(hidden_size)) inside RMSNorm initial implementation which is compatible with FalconMamba dt,b,c rmsnorms.

super().__init__()

self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: best not to introduce whitespace-only changes to files

self.weight = nn.Parameter(torch.ones(hidden_size))
if is_learnable:
self.register_parameter("weight",
nn.Parameter(torch.ones(hidden_size)))
else:
self.register_buffer('weight',
torch.ones(hidden_size),
persistent=False)

def forward_native(
self,
Expand Down
Loading