[Model] FalconMamba Support #9325
|
@@ -11,7 +11,7 @@ Text-only Language Models | |
^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
Text Generation | ||
--------------- | ||
--------------- | ||
|
||
.. list-table:: | ||
:widths: 25 25 50 5 5 | ||
|
@@ -87,6 +87,11 @@ Text Generation | |
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. | ||
- | ||
- ✅︎ | ||
* - :code:`FalconMambaForCausalLM` | ||
- FalconMamba | ||
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. | ||
- ✅︎ | ||
- | ||
* - :code:`GemmaForCausalLM` | ||
- Gemma | ||
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. | ||
|
@@ -156,7 +161,7 @@ Text Generation | |
- Mamba | ||
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. | ||
- ✅︎ | ||
- | ||
- | ||
Review comment: ditto, spurious whitespace |
||
* - :code:`MiniCPMForCausalLM` | ||
- MiniCPM | ||
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. | ||
|
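As a quick illustration of what the new table entry enables, here is a minimal sketch of running one of the listed FalconMamba checkpoints through vLLM's offline LLM API. The model name comes from the table above; the prompt and sampling settings are arbitrary placeholders, not values taken from this PR.

from vllm import LLM, SamplingParams

# Minimal sketch: generate with a FalconMamba checkpoint via vLLM's offline API.
# The checkpoint name is from the supported-models table; prompt and sampling
# values below are illustrative placeholders.
llm = LLM(model="tiiuae/falcon-mamba-7b-instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)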
@@ -0,0 +1,296 @@ | ||
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. | ||
|
||
Run `pytest tests/models/decoder_only/language/test_falcon_mamba.py`. | ||
""" | ||
import pytest | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from vllm.sampling_params import SamplingParams | ||
from vllm.worker.model_runner import _get_graph_batch_size | ||
|
||
from ...utils import check_outputs_equal | ||
|
||
MODELS = ["tiiuae/falcon-mamba-tiny-dev"] | ||
Review comment: Instead of adding the |
||
|
||
|
||
# Use lower-level interfaces to create this greedy generator, as FalconMamba | ||
# will choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy | ||
# is used. | ||
def generate_greedy(model_name, example_prompts, max_tokens): | ||
# Create a text generation pipeline | ||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
model = AutoModelForCausalLM.from_pretrained(model_name) | ||
|
||
# Generate texts from the prompts | ||
outputs = [] | ||
for prompt in example_prompts: | ||
# Tokenize the input prompt with truncation | ||
inputs = tokenizer(prompt, return_tensors="pt", truncation=True) | ||
input_ids = inputs["input_ids"].to(model.device) | ||
|
||
# Generate text using the model's generate method directly | ||
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) | ||
generated_text = tokenizer.decode(generated_ids[0], | ||
skip_special_tokens=True) | ||
|
||
outputs.append((generated_ids[0].tolist(), generated_text)) | ||
|
||
return outputs | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [96]) | ||
def test_models( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
hf_outputs = generate_greedy(model, example_prompts, max_tokens) | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
|
||
for i in range(len(example_prompts)): | ||
hf_output_ids, hf_output_str = hf_outputs[i] | ||
vllm_output_ids, vllm_output_str = vllm_outputs[i] | ||
assert hf_output_str == vllm_output_str, ( | ||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") | ||
assert hf_output_ids == vllm_output_ids, ( | ||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [96]) | ||
def test_batching( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# To pass the small model tests, we need full precision. | ||
for_loop_outputs = [] | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for prompt in example_prompts: | ||
for_loop_outputs.append( | ||
vllm_model.generate_greedy([prompt], max_tokens)[0]) | ||
|
||
batched_outputs = vllm_model.generate_greedy(example_prompts, | ||
max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=for_loop_outputs, | ||
outputs_1_lst=batched_outputs, | ||
name_0="for_loop_vllm", | ||
name_1="batched_vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [10]) | ||
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, | ||
model: str, dtype: str, | ||
max_tokens: int) -> None: | ||
# Tests chunked prefill in conjunction with n>1. In this case, prefill is | ||
# populated with decoding tokens and we test that it doesn't fail. | ||
# This test might fail if cache is not allocated correctly for n > 1 | ||
# decoding steps inside a chunked prefill forward pass (where we have both | ||
# prefill and decode together). | ||
sampling_params = SamplingParams(n=3, | ||
temperature=1, | ||
seed=0, | ||
max_tokens=max_tokens) | ||
with vllm_runner( | ||
model, | ||
dtype=dtype, | ||
enable_chunked_prefill=True, | ||
max_num_batched_tokens=30, | ||
max_num_seqs=10 # forces prefill chunks with decoding | ||
) as vllm_model: | ||
vllm_model.generate(example_prompts, sampling_params) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [32]) | ||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) | ||
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, | ||
max_tokens: int, | ||
chunked_prefill_token_size: int) -> None: | ||
""" | ||
Checks exact match decode between huggingface model and vllm runner with | ||
chunked prefill. | ||
""" | ||
max_num_seqs = chunked_prefill_token_size | ||
max_num_batched_tokens = chunked_prefill_token_size | ||
|
||
non_chunked = generate_greedy(model, example_prompts, max_tokens) | ||
|
||
with vllm_runner(model, | ||
dtype=dtype, | ||
enable_chunked_prefill=True, | ||
max_num_batched_tokens=max_num_batched_tokens, | ||
max_num_seqs=max_num_seqs) as vllm_model: | ||
chunked = vllm_model.generate_greedy(example_prompts, | ||
max_tokens=max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=chunked, | ||
outputs_1_lst=non_chunked, | ||
name_0="chunked", | ||
name_1="non_chunked", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [15]) | ||
def test_parallel_sampling( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for_loop_outputs = [] | ||
for _ in range(10): | ||
for_loop_outputs.append( | ||
# using example_prompts index 1 instead of 0 since with 0 the | ||
# logprobs get really close and the test doesn't pass | ||
vllm_model.generate_greedy([example_prompts[1]], max_tokens) | ||
[0]) | ||
sampling_params = SamplingParams(n=10, | ||
temperature=0.001, | ||
seed=0, | ||
max_tokens=max_tokens) | ||
n_lt_1_outputs = vllm_model.generate([example_prompts[1]], | ||
sampling_params) | ||
token_ids, texts = n_lt_1_outputs[0] | ||
n_lt_1_outputs = [(token_id, text) | ||
for token_id, text in zip(token_ids, texts)] | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=n_lt_1_outputs, | ||
outputs_1_lst=for_loop_outputs, | ||
name_0="vllm_n_lt_1_outputs", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
@pytest.mark.parametrize("max_tokens", [20]) | ||
def test_falcon_mamba_cache_cg_padding( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# This test is for verifying that mamba cache is padded to CG captured | ||
# batch size. If it's not, a torch RuntimeError will be raised because | ||
# tensor dimensions aren't compatible | ||
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): | ||
example_prompts.append(example_prompts[0]) | ||
|
||
try: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_model.generate_greedy(example_prompts, max_tokens) | ||
except RuntimeError: | ||
pytest.fail( | ||
"Couldn't run batch size which is not equal to a Cuda Graph " | ||
"captured batch size. " | ||
"Could be related to mamba cache not padded correctly") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
@pytest.mark.parametrize("max_tokens", [20]) | ||
def test_models_preemption_recompute( | ||
vllm_runner, | ||
example_prompts, | ||
model: str, | ||
dtype: str, | ||
max_tokens: int, | ||
) -> None: | ||
# Tests that outputs are identical with and w/o preemptions (recompute) | ||
assert dtype == "float" | ||
|
||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
vllm_model.model.llm_engine.scheduler[ | ||
0].ENABLE_ARTIFICIAL_PREEMPT = True | ||
preempt_vllm_outputs = vllm_model.generate_greedy( | ||
example_prompts, max_tokens) | ||
|
||
vllm_model.model.llm_engine.scheduler[ | ||
0].ENABLE_ARTIFICIAL_PREEMPT = False | ||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||
|
||
check_outputs_equal( | ||
outputs_0_lst=preempt_vllm_outputs, | ||
outputs_1_lst=vllm_outputs, | ||
name_0="vllm_preepmtions", | ||
name_1="vllm", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
example_prompts, | ||
) -> None: | ||
# This test is for verifying that the Mamba inner state management doesn't | ||
# collapse in case where the number of incoming requests and | ||
# finished_requests_ids is larger than the maximum Mamba block capacity. | ||
# This could generally happen because Mamba supports a statelessness | ||
# mechanism that lets it clean up new incoming requests within | ||
# a single step. | ||
try: | ||
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: | ||
vllm_model.generate_greedy([example_prompts[0]] * 100, 10) | ||
except ValueError: | ||
pytest.fail("Mamba inner state wasn't cleaned up properly between" | ||
"steps finished requests registered unnecessarily ") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_state_cleanup( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
example_prompts, | ||
) -> None: | ||
# This test is for verifying that the Mamba state is cleaned up between | ||
# steps. If it's not cleaned up, an error is expected. | ||
try: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
for _ in range(10): | ||
vllm_model.generate_greedy([example_prompts[0]] * 100, 1) | ||
except ValueError: | ||
pytest.fail("Mamba inner state wasn't cleaned up between states, " | ||
"could be related to finished_requests_ids") | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["float"]) | ||
def test_model_print( | ||
vllm_runner, | ||
model: str, | ||
dtype: str, | ||
) -> None: | ||
with vllm_runner(model, dtype=dtype) as vllm_model: | ||
# This test is for verifying whether the model's extra_repr | ||
# can be printed correctly. | ||
print(vllm_model.model.llm_engine.model_executor.driver_worker. | ||
model_runner.model) |
|
@@ -14,20 +14,24 @@ class RMSNorm(CustomOp): | |
Refer to https://arxiv.org/abs/1910.07467 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
eps: float = 1e-6, | ||
var_hidden_size: Optional[int] = None, | ||
) -> None: | ||
def __init__(self, | ||
hidden_size: int, | ||
eps: float = 1e-6, | ||
var_hidden_size: Optional[int] = None, | ||
is_learnable: bool = True) -> None: | ||
Review comment: RMSNorm weights are non-learnable for the FalconMamba model. |
Review comment: Could you explain this a bit more? It seems like this might have been done to work around some issues that popped up during weight loading. Is that right? And am I right that the weights will always be 1.0 for Falcon Mamba, i.e. we could skip the application of the weights for |
Reply: 2. Yes, you are right. |
Review comment: Thanks for the explanation -- I think it would be better to handle this in load_weights. Could you add a condition to check if dt_layernorm, b_layernorm, or c_layernorm is in the name? If this is the case, we can set the weight loader to a function that explicitly sets all of the elements to 1.0, which will make things explicitly clear. (A rough sketch of what this suggested handling might look like appears at the end of this section.) |
Reply: Thanks for the review. I managed to integrate FalconMamba inside mamba.py. For RMSNorm, I reverted the changes, but I think there is no need to handle dt_layernorm, b_layernorm, or c_layernorm inside load_weights, since they are initialized as nn.Parameter(torch.ones(hidden_size)) in the original RMSNorm implementation, which is compatible with the FalconMamba dt, B, and C RMSNorms. |
||
super().__init__() | ||
|
||
self.hidden_size = hidden_size | ||
self.variance_epsilon = eps | ||
self.variance_size_override = (None if var_hidden_size == hidden_size | ||
else var_hidden_size) | ||
|
||
Review comment: nit: best not to introduce whitespace-only changes to files |
||
self.weight = nn.Parameter(torch.ones(hidden_size)) | ||
if is_learnable: | ||
self.register_parameter("weight", | ||
nn.Parameter(torch.ones(hidden_size))) | ||
else: | ||
self.register_buffer('weight', | ||
torch.ones(hidden_size), | ||
persistent=False) | ||
|
||
def forward_native( | ||
self, | ||
|
Review comment: nit: remove spurious whitespace
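Returning to the review discussion about the dt/B/C layernorm weights: the following is a rough sketch of the reviewer's suggestion of pinning those weights to 1.0 inside load_weights instead of making RMSNorm non-learnable. It is illustrative only and not code from this PR; _ones_weight_loader is a hypothetical helper, the load_weights skeleton follows the usual vLLM loading pattern, and default_weight_loader is vLLM's standard fallback loader.

from typing import Iterable, Tuple

import torch
from torch import nn

from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def _ones_weight_loader(param: nn.Parameter,
                        loaded_weight: torch.Tensor) -> None:
    # Ignore the checkpoint value and make the non-learnable norm explicit.
    param.data.fill_(1.0)


def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        param = params_dict[name]
        if any(key in name for key in
               ("dt_layernorm", "b_layernorm", "c_layernorm")):
            # Explicitly set the dt/B/C norm weights to 1.0 at load time.
            _ones_weight_loader(param, loaded_weight)
        else:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

As the reply in the thread notes, this turned out to be unnecessary for the final version of the PR, because RMSNorm already initializes its weight to torch.ones(hidden_size), which matches the FalconMamba dt, B, and C norms.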