Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Model] Improve continuous batching for Jamba and Mamba #9189

Merged

Conversation

mzusman
Copy link
Contributor

@mzusman mzusman commented Oct 9, 2024

Copy link

github-actions bot commented Oct 9, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One suggestion: To mirror the PAD_SLOT_ID that's in backends/utils.py, we could #define PAD_SLOT_ID -1 in csrc/attention_generic.cuh, so that we name the constant.

And then we can add comments letting the reader know to keep them in sync

csrc/mamba/causal_conv1d/causal_conv1d.cu Outdated Show resolved Hide resolved
csrc/mamba/causal_conv1d/causal_conv1d.cu Outdated Show resolved Hide resolved
tests/kernels/test_causal_conv1d.py Outdated Show resolved Hide resolved
vllm/model_executor/models/jamba.py Show resolved Hide resolved
vllm/model_executor/models/jamba.py Outdated Show resolved Hide resolved
vllm/model_executor/models/jamba.py Outdated Show resolved Hide resolved
vllm/model_executor/models/jamba.py Outdated Show resolved Hide resolved
@mzusman
Copy link
Contributor Author

mzusman commented Oct 10, 2024

One suggestion: To mirror the PAD_SLOT_ID that's in backends/utils.py, we could #define PAD_SLOT_ID -1 in csrc/attention_generic.cuh, so that we name the constant.

And then we can add comments letting the reader know to keep them in sync

I've taken another approach, passing the pad_slot_id through the kernel params, keeps it simpler since we don't need to keep the cpp and python variables in sync and keeps the mamba kernels as standalone entities. WDYT?

@mzusman mzusman changed the title [Kernel][Model] Continous batching for Jamba [Kernel][Model] Continuous batching for Jamba Oct 13, 2024
@mzusman
Copy link
Contributor Author

mzusman commented Oct 13, 2024

Pushed adaptations for #6484 , The PR is ready again.

@mzusman mzusman changed the title [Kernel][Model] Continuous batching for Jamba [Kernel][Model] Continuous batching for Jamba and Mamba Oct 14, 2024
@tlrmchlsmth
Copy link
Collaborator

Changing the title since Jamba and Mamba already support continuous batching -- this just makes it better

@tlrmchlsmth tlrmchlsmth changed the title [Kernel][Model] Continuous batching for Jamba and Mamba [Kernel][Model] Improve continuous batching for Jamba and Mamba Oct 15, 2024
@tlrmchlsmth
Copy link
Collaborator

One suggestion: To mirror the PAD_SLOT_ID that's in backends/utils.py, we could #define PAD_SLOT_ID -1 in csrc/attention_generic.cuh, so that we name the constant.
And then we can add comments letting the reader know to keep them in sync

I've taken another approach, passing the pad_slot_id through the kernel params, keeps it simpler since we don't need to keep the cpp and python variables in sync and keeps the mamba kernels as standalone entities. WDYT?

Sounds good to me

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried out the PR and am currently seeing the following error:

E           RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241015-011501.pkl): _C::causal_conv1d_fwd() expected at most 8 argument(s) but received 9 argument(s). Declaration: _C::causal_conv1d_fwd(Tensor($0! -> ) x, Tensor($1! -> ) weight, Tensor? bias_, Tensor($2! -> )? conv_states, Tensor? query_start_loc, Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation) -> Tensor

If you're seeing the same issue, could you LMK when it's fixed? Looks good otherwise, thanks!

csrc/mamba/causal_conv1d/causal_conv1d.cu Outdated Show resolved Hide resolved
tests/kernels/test_causal_conv1d.py Outdated Show resolved Hide resolved
tests/kernels/test_causal_conv1d.py Show resolved Hide resolved
vllm/model_executor/layers/mamba/ops/causal_conv1d.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/mamba/ops/mamba_ssm.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/mamba/ops/mamba_ssm.py Outdated Show resolved Hide resolved
@mzusman
Copy link
Contributor Author

mzusman commented Oct 15, 2024

I tried out the PR and am currently seeing the following error:

E           RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241015-011501.pkl): _C::causal_conv1d_fwd() expected at most 8 argument(s) but received 9 argument(s). Declaration: _C::causal_conv1d_fwd(Tensor($0! -> ) x, Tensor($1! -> ) weight, Tensor? bias_, Tensor($2! -> )? conv_states, Tensor? query_start_loc, Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation) -> Tensor

If you're seeing the same issue, could you LMK when it's fixed? Looks good otherwise, thanks!

@tlrmchlsmth I cannot reproduce this error, could you let me know how did you run into it?

@tlrmchlsmth
Copy link
Collaborator

@tlrmchlsmth I cannot reproduce this error, could you let me know how did you run into it?

Possibly user error, let's see if it goes through the CI ;)

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 16, 2024
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work, much nicer management of the Mamba Mixer state.

Now it's my turn to extend mamba_chunk_scan_combined to support this for Mamba2 :)

@tlrmchlsmth
Copy link
Collaborator

Merging now! from @simon-mo the readthedocs failure looks transient

@tlrmchlsmth tlrmchlsmth merged commit fb60ae9 into vllm-project:main Oct 16, 2024
87 of 88 checks passed
charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Oct 23, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants