[Kernel][Model] Improve continuous batching for Jamba and Mamba #9189
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
One suggestion: to mirror the PAD_SLOT_ID that's in backends/utils.py, we could #define PAD_SLOT_ID -1 in csrc/attention_generic.cuh, so that the constant has a name. We could then add comments letting the reader know to keep the two definitions in sync.
I've taken another approach: passing the pad_slot_id through the kernel params. This keeps things simpler, since we don't need to keep the C++ and Python variables in sync, and it keeps the Mamba kernels as standalone entities. WDYT?
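A minimal Python sketch of the parameter-passing approach described above. The names (`causal_conv1d_fwd_ref`, the slot-id layout) are illustrative stand-ins, not vLLM's actual signatures; the point is that the pad sentinel is defined once on the Python side and travels with each kernel call, so the C++ side never re-declares it.

```python
PAD_SLOT_ID = -1  # defined once, on the Python side (as in backends/utils.py)

def causal_conv1d_fwd_ref(x, slot_ids, pad_slot_id):
    """Reference loop standing in for the CUDA kernel: entries whose
    slot id equals pad_slot_id are skipped (no state is written)."""
    out = []
    for value, slot in zip(x, slot_ids):
        if slot == pad_slot_id:
            out.append(None)       # padded position: leave state untouched
        else:
            out.append(value * 2)  # real position: do the actual work
    return out

# The sentinel travels with the call, so C++ never hard-codes its own copy.
result = causal_conv1d_fwd_ref([1, 2, 3], [0, PAD_SLOT_ID, 1], PAD_SLOT_ID)
```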
Pushed adaptations for #6484. The PR is ready again.
Changing the title, since Jamba and Mamba already support continuous batching -- this just makes it better.
Sounds good to me
I tried out the PR and am currently seeing the following error:
E RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241015-011501.pkl): _C::causal_conv1d_fwd() expected at most 8 argument(s) but received 9 argument(s). Declaration: _C::causal_conv1d_fwd(Tensor($0! -> ) x, Tensor($1! -> ) weight, Tensor? bias_, Tensor($2! -> )? conv_states, Tensor? query_start_loc, Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation) -> Tensor
If you're seeing the same issue, could you LMK when it's fixed? Looks good otherwise, thanks!
@tlrmchlsmth I cannot reproduce this error, could you let me know how you ran into it?
Possibly user error, let's see if it goes through the CI ;)
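An aside for readers who hit the same `expected at most 8 argument(s) but received 9` error: one plausible cause (an assumption, not confirmed in this thread) is a stale compiled extension whose registered op schema predates the PR's extra parameter, while the updated Python wrapper already passes it. A toy illustration of that kind of declaration/call mismatch, with an illustrative `pad_slot_id` position:

```python
import inspect

# Stale compiled declaration: 8 parameters, no pad-slot argument yet.
def causal_conv1d_fwd_stale(x, weight, bias_, conv_states,
                            query_start_loc, cache_indices,
                            has_initial_state, silu_activation):
    return x

# The updated Python wrapper now passes one extra argument.
new_call_args = ["x", "weight", "bias_", "conv_states", "query_start_loc",
                 "cache_indices", "has_initial_state", "pad_slot_id",
                 "silu_activation"]

declared = len(inspect.signature(causal_conv1d_fwd_stale).parameters)
mismatch = len(new_call_args) > declared  # 9 args passed vs 8 declared
```

If `mismatch` is true for your build, a clean rebuild of the extension should resolve it.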
Thanks for the great work -- much nicer management of the Mamba Mixer state. Now it's my turn to extend mamba_chunk_scan_combined to support this for Mamba2 :)
Merging now! Per @simon-mo, the readthedocs failure looks transient.
CC @tlrmchlsmth