[Feature] [Spec decode]: Combine chunked prefill with speculative decoding #9291
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
Thanks for the pr. Left some comments. PTAL
Force-pushed: 49b03ab to 8b88b8a
@NickLucche Thanks for the great work. We understand this is WIP, so just a small note while you are working on this piece: we tried this PR with tensor parallelism and found that it throws the following exception when tensor parallelism is enabled:
Here is the exception:
The following command works normally
Thanks again, and we appreciate your work and the vLLM community.
Thanks for testing that, will look right into it!
Update on |
Thanks for the pr. Left a few comments. PTAL.
Force-pushed: 0819d12 to 3e5b882
Thanks for the pr! One comment about leaving out the unified kernel changes in this pr. Please check with @LiuXiaoxuanPKU and @comaniac on this. Otherwise LGTM.
Thanks for reviewing this!
Thanks for the pr!! LGTM.
LGTM. Good job! Only nits.
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: NickLucche <[email protected]>
Force-pushed: cd9bd2a to ca2691e
Mmm, apologies for the automatic review requests sent to so many people; I had to sign the commits and force-push.
@NickLucche I think you need to remove test_spec_decode_xfail_chunked_prefill from spec_decode/e2e/test_compatibility.py since it's no longer applicable. Could you also please sync your branch to head? It seems like some of the failures, e.g. in buildkite/ci-aws/pr/decoder-only-multi-modal-models-test, might already be fixed in head.
Force-pushed: 756d33f to fb66563
…odes Signed-off-by: NickLucche <[email protected]>
Thanks @NickLucche for the awesome work, and to @sroy745 @LiuXiaoxuanPKU @comaniac for the reviews
Hi @NickLucche, thanks for the PR! I tried with TP on the latest
Hi @andoorve / @arashsadrieh

python3 -m vllm.entrypoints.openai.api_server --model "meta-llama/Meta-Llama-3-70B-Instruct" --tensor-parallel-size 4 --disable-log-requests --enable-chunked-prefill --max_num_batched_tokens 2048 --speculative_model turboderp/Qwama-0.5B-Instruct --num_speculative_tokens 1 --speculative_draft_tensor_parallel_size 1 --disable-custom-all-reduce --swap_space 16 --speculative_disable_mqa_scorer

What is the command you are using? One difference, I think, is that in our evals the speculative model ran with tp=1 and the target model ran with tp=4. Can you try that and see if it works for you?
Hey @NickLucche @sroy745, this is what I'm using. I think this is the difference, as I'm running with TP > 1 on the draft model as well. Unfortunately the Llama 8B draft model that I want to use is relatively large for TP=1.
I will add a check to verify that spec decode + chunked prefill is only enabled with a tp=1 draft model, and then continue with the investigation. This is not breaking any existing cases, so I will add the check first and then debug.
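As a rough illustration of the kind of guard described in this comment, here is a minimal sketch. The function name, its parameters, and the error message are hypothetical and are not taken from the vLLM codebase; the only thing it encodes is the constraint discussed in this thread (spec decode plus chunked prefill only with a tp=1 draft model).

```python
from typing import Optional


def check_spec_decode_chunked_prefill(
    enable_chunked_prefill: bool,
    speculative_model: Optional[str],
    draft_tensor_parallel_size: int,
) -> None:
    """Hypothetical config guard, not vLLM's actual implementation.

    Rejects the combination reported as failing in this thread:
    speculative decoding + chunked prefill with a draft model at TP > 1.
    """
    spec_decode_enabled = speculative_model is not None
    if (enable_chunked_prefill and spec_decode_enabled
            and draft_tensor_parallel_size != 1):
        raise ValueError(
            "Speculative decoding with chunked prefill currently requires "
            "a draft model with tensor parallel size 1, got "
            f"{draft_tensor_parallel_size}.")


# The configuration that worked in the evals: draft tp=1 passes silently.
check_spec_decode_chunked_prefill(True, "turboderp/Qwama-0.5B-Instruct", 1)
```

Failing fast at configuration time gives users a clear message instead of a runtime exception inside the worker.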
…oding (vllm-project#9291) Signed-off-by: NickLucche <[email protected]> Signed-off-by: Isotr0py <[email protected]>
Hey, this PR implements #5016.
The main idea is to make use of the current Speculative Decoder workflow and integrate it with mixed prefill-decode batches.
In particular, we can run the batched prefills and decodes together through the scorer (with the usual prefill|decode layout supported by the backend), while the proposer can sync its KV cache on prefills only.
The current attention kernel implementation still doesn't make full use of the prefill|decode layout, but once the MQA integration is finalized we can get an easy speedup by running the batch in a single forward pass.
The current implementation on main is already (to some extent) prefill-aware, so I was able to reuse a good chunk of the logic, and the changes are purposely not drastic.
On the other hand, one could prioritize optimizations more, and I am open to any suggestions on how best to implement the approach, even at the cost of rewriting more parts and making the PR more invasive (i.e. breaking some of the interfaces to avoid duplication).
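To make the workflow above more concrete, here is a minimal sketch of how one scheduling step could be planned for a mixed batch. All names (SeqChunk, plan_step) are hypothetical and do not correspond to vLLM classes. It also assumes the usual speculative decoding convention that the scorer sees k proposed tokens plus one extra position per decode sequence, and that the proposer runs prefill chunks only to fill its KV cache.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class SeqChunk:
    """Hypothetical stand-in for one sequence scheduled in this step."""
    seq_id: int
    num_tokens: int   # tokens scheduled for this step
    is_prefill: bool  # True while the prompt is still being chunk-prefilled


def plan_step(batch: List[SeqChunk], k: int) -> Dict[str, object]:
    """Plan one step for a mixed batch with k speculative tokens."""
    # Keep the prefill|decode ordering: all prefill chunks first.
    prefills = [s for s in batch if s.is_prefill]
    decodes = [s for s in batch if not s.is_prefill]

    # Proposer: prefill chunks only sync the draft KV cache (no proposals);
    # decode sequences each get k proposed tokens.
    proposer_kv_sync = [s.seq_id for s in prefills]
    proposals = {s.seq_id: k for s in decodes}

    # Scorer: one batch holding both kinds of sequence. A prefill chunk
    # contributes its chunk tokens, a decode contributes k + 1 positions.
    scorer_tokens: List[Tuple[int, int]] = (
        [(s.seq_id, s.num_tokens) for s in prefills]
        + [(s.seq_id, k + 1) for s in decodes])

    return {
        "proposer_kv_sync": proposer_kv_sync,
        "proposals": proposals,
        "scorer_tokens": scorer_tokens,
    }


batch = [SeqChunk(1, 1, False), SeqChunk(0, 512, True), SeqChunk(2, 1, False)]
plan = plan_step(batch, k=3)
```

With this batch, sequence 0 is only KV-synced by the proposer, while sequences 1 and 2 each receive three proposals and are scored over four positions.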
TODO:
fix speculative methods requiring return_hidden_states (EDIT: on second thought, I believe this would be better addressed in a separate PR)
disable_logprobs_during_spec_decoding compatibility

Update:
We add support for chunked prefill and spec decoding with the workflow depicted above, unless the proposer requires the final hidden state from the target model (MLPSpeculator/Medusa): this is deferred to a second follow-up PR.
mqa_scorer is set to supersede BatchExpansion* thanks to the great work by @LiuXiaoxuanPKU, so we add support for that scorer directly in this PR!
Incidentally, this means enabling the backend with flash_attn_varlen_func to take in any "mixed prefill-decode batch" in a single kernel call (so no more decoupled prefix-decode calls), which should also boost performance in the "vanilla" chunked prefill scheduling policy (no spec).
Many thanks to @sroy745 for benchmarking the BatchExpansionTop1Scorer approach here (MQA to follow)!

Update 2:
After reviewing @sroy745's benchmarks, contrary to expectations, fusing the two separate kernel calls into a unified prefill+decode call (a single flash_attn_varlen_func call) did not yield improvements. I reverted the unifying kernel change, but I will keep the commit history here so we can come back to it and investigate further in a separate optimization PR.
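For readers unfamiliar with the single-call layout discussed in the updates, here is a small sketch of how a mixed prefill-decode batch can be described with cumulative sequence-length offsets, the kind of metadata flash_attn_varlen_func takes as cu_seqlens_q and cu_seqlens_k. The helper name and inputs are hypothetical. It only illustrates the packing idea and is not the code from this PR, which in the end kept the separate kernel calls.

```python
from itertools import accumulate
from typing import List, Tuple


def build_cu_seqlens(query_lens: List[int],
                     context_lens: List[int]) -> Tuple[List[int], List[int]]:
    """Build cumulative query/key offsets for a packed variable-length batch.

    query_lens[i]:   new tokens processed for sequence i in this step
                     (a prefill chunk, or 1 for a plain decode).
    context_lens[i]: tokens of sequence i already present in the KV cache.
    """
    if len(query_lens) != len(context_lens):
        raise ValueError("query_lens and context_lens must have equal length")
    # Each query attends to its cached context plus the new tokens.
    kv_lens = [q + c for q, c in zip(query_lens, context_lens)]
    cu_q = [0] + list(accumulate(query_lens))
    cu_k = [0] + list(accumulate(kv_lens))
    return cu_q, cu_k


# One 4-token prefill chunk with no context, followed by two decodes
# with 7 and 10 cached tokens respectively (prefill|decode ordering).
cu_q, cu_k = build_cu_seqlens([4, 1, 1], [0, 7, 10])
print(cu_q)  # [0, 4, 5, 6]
print(cu_k)  # [0, 4, 12, 23]
```

Because prefills and decodes differ only in their query and key lengths, one set of offsets can cover both, which is what makes a single kernel call possible in principle.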