# [Bugfix] Fix for Spec model TP + Chunked Prefill #10232
Docs: the feature compatibility matrix updates the speculative decoding (SD) row, flipping one cell (per the PR title, the chunked-prefill column) from ✗ to ✅:

```diff
@@ -118,7 +118,7 @@ Feature x Feature
      -
      -
    * - :ref:`SD <spec_decode>`
      - ✗
      - ✅
      - ✅
-     - ✗
+     - ✅
```
Scheduler: `_schedule_chunked_prefill` now hoists `scheduled_seq_groups` and `num_prefill_groups` into locals so it can detect prompt-only batches and zero out `num_lookahead_slots` for them, steering those batches onto the `no_spec` path in `spec_decode_worker.py`:

```diff
@@ -1201,15 +1201,25 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
         # Put prefills first due to Attention backend ordering assumption.
+        scheduled_seq_groups = (prefills.seq_groups +
+                                running_scheduled.prefill_seq_groups +
+                                swapped_in.prefill_seq_groups +
+                                running_scheduled.decode_seq_groups +
+                                swapped_in.decode_seq_groups)
+        num_prefill_groups = (len(prefills.seq_groups) +
+                              len(swapped_in.prefill_seq_groups) +
+                              len(running_scheduled.prefill_seq_groups))
+        # If all prompts, then we set num_lookahead_slots to 0
+        # this allows us to go through the `no_spec` path in
+        # `spec_decode_worker.py`
+        all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
+        num_lookahead_slots = (0 if
+                               (all_prefills
+                                and not self.scheduler_config.is_multi_step)
+                               else running_scheduled.num_lookahead_slots)
         return SchedulerOutputs(
-            scheduled_seq_groups=(prefills.seq_groups +
-                                  running_scheduled.prefill_seq_groups +
-                                  swapped_in.prefill_seq_groups +
-                                  running_scheduled.decode_seq_groups +
-                                  swapped_in.decode_seq_groups),
-            num_prefill_groups=(len(prefills.seq_groups) +
-                                len(swapped_in.prefill_seq_groups) +
-                                len(running_scheduled.prefill_seq_groups)),
+            scheduled_seq_groups=scheduled_seq_groups,
+            num_prefill_groups=num_prefill_groups,
             num_batched_tokens=budget.num_batched_tokens +
             budget.num_cached_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
@@ -1218,7 +1228,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
             swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
-            num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            num_lookahead_slots=num_lookahead_slots,
             running_queue_size=len(self.running),
             preempted=(len(running_scheduled.preempted) +
                        len(running_scheduled.swapped_out)),
```

Review thread on the `num_lookahead_slots` change (resolved):

- "@varun-sundar-rabindranath could you also review this part to see if this will break multi-step scheduling with chunked prefill?"
- "Thanks for the tag. I believe it will affect performance. @andoorve, is there a way to make this update apply only if spec decode is enabled? I believe that would be safer."
- "Hi @varun-sundar-rabindranath, I think that should be possible, thanks for the feedback! Let me see how we can do that."
- "@varun-sundar-rabindranath @comaniac Can you check whether this condition makes sense?"
- "Hey @andoorve - the condition looks good 👍"
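As a sanity check on the new gating, here is a minimal, self-contained sketch of the decision the scheduler now makes. It is not vLLM code: the helper name and its scalar arguments are hypothetical stand-ins for the fields used in the diff.

```python
def lookahead_slots_for_batch(num_scheduled: int, num_prefill: int,
                              is_multi_step: bool,
                              scheduled_lookahead: int) -> int:
    """Mirror of the new scheduler logic: a prompt-only batch gets zero
    lookahead slots so spec_decode_worker takes its `no_spec` path."""
    all_prefills = (num_scheduled == num_prefill)
    return 0 if (all_prefills and not is_multi_step) else scheduled_lookahead


# Prompt-only batch, single-step: no lookahead slots are reserved.
assert lookahead_slots_for_batch(4, 4, False, 5) == 0
# Mixed prefill+decode batch: the scheduled lookahead count is kept.
assert lookahead_slots_for_batch(4, 2, False, 5) == 5
# Multi-step scheduling keeps its slots even for prompt-only batches,
# per the carve-out requested in the review thread above.
assert lookahead_slots_for_batch(4, 4, True, 5) == 5
```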
Speculative decode worker: `execute_model` now computes the batch-composition flags (`all_prompt`, `atleast_one_prompt`, `all_zero_spec_tokens`) in a single pass over the sequence-group metadata, and asserts the new scheduler invariant that prompt-only runs carry zero lookahead slots:

```diff
@@ -408,7 +408,20 @@ def execute_model(
         disable_all_speculation = self._should_disable_all_speculation(
             execute_model_req)
         num_lookahead_slots = execute_model_req.num_lookahead_slots
 
+        all_prompt = True
+        atleast_one_prompt = False
+        all_zero_spec_tokens = True
+        for sgm in execute_model_req.seq_group_metadata_list:
+            all_prompt = all_prompt and sgm.is_prompt
+            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
+            all_zero_spec_tokens = all_zero_spec_tokens and (
+                sgm.num_speculative_tokens == 0)
+
+        if all_prompt and execute_model_req.seq_group_metadata_list:
+            assert num_lookahead_slots == 0, (
+                "Prompt only runs should have num_lookahead_slots equal to 0. "
+                "This should never happen, please file a bug at "
+                "https://github.com/vllm-project/vllm/issues")
         # Speculative decoding is disabled in the following cases:
         # 1. Prefill phase: Speculative decoding is not
         #    used during the prefill phase.
```
The `no_spec` computation then reuses those flags; the prompt-only disjunct is gone because prompt-only batches now arrive with `num_lookahead_slots == 0`:

```diff
@@ -419,11 +432,8 @@ def execute_model(
         # In any of these cases, the proposer and scorer workers
         # are called normally.
         # We expect `num_speculative_tokens` to be None for prefills.
-        no_spec = all(
-            sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
-        ) or num_lookahead_slots == 0 or disable_all_speculation or all(
-            sgm.num_speculative_tokens == 0
-            for sgm in execute_model_req.seq_group_metadata_list)
+        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
+                   or all_zero_spec_tokens)
 
         # Broadcast how many lookahead slots are scheduled for this step, and
         # whether all speculation is disabled, to all non-driver workers.
```
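For readers double-checking the refactor, here is a small sketch showing that the fused single-pass loop computes the same values as the generator expressions it replaced. `FakeSeqGroupMetadata` is a hypothetical stand-in for vLLM's `SequenceGroupMetadata`, reduced to the two fields the loop reads.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeSeqGroupMetadata:
    """Illustrative stand-in for SequenceGroupMetadata."""
    is_prompt: bool
    num_speculative_tokens: Optional[int]


def flags_single_pass(
        sgms: List[FakeSeqGroupMetadata]) -> Tuple[bool, bool, bool]:
    # The fused loop from the diff: one walk over the batch instead of
    # several all()/any() passes.
    all_prompt, atleast_one_prompt, all_zero = True, False, True
    for sgm in sgms:
        all_prompt = all_prompt and sgm.is_prompt
        atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
        all_zero = all_zero and (sgm.num_speculative_tokens == 0)
    return all_prompt, atleast_one_prompt, all_zero


batch = [FakeSeqGroupMetadata(True, None), FakeSeqGroupMetadata(False, 0)]
all_prompt, atleast_one, all_zero = flags_single_pass(batch)
# Matches the expressions removed in the diff above:
assert all_prompt == all(s.is_prompt for s in batch)
assert all_zero == all(s.num_speculative_tokens == 0 for s in batch)
assert atleast_one == any(s.is_prompt for s in batch)
```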
The broadcast to non-driver workers gains a `run_spec_proposer_for_prefill` flag:

```diff
@@ -442,6 +452,15 @@ def execute_model(
             num_lookahead_slots=num_lookahead_slots,
             no_spec=no_spec,
             disable_all_speculation=disable_all_speculation,
+            # When both chunked prefill and speculative decoding are enabled
+            # it is possible that the same batch contains both prefill
+            # and decodes. If that happens in the scorer we run the batch
+            # as one single forward pass. However, in the proposer we
+            # run them as 2 different batches - one for prefill and
+            # the other for decodes. The variable indicates to the non-driver
+            # worker that there are prefills as part of the speculative batch
+            # and hence it needs to run an extra prefill forward pass.
+            run_spec_proposer_for_prefill=atleast_one_prompt,
         )
         broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
```

Review thread:

- "It would be great if we got a sanity check from @NickLucche or someone!"
- "On it, sorry for the late ack."
On non-driver ranks, `_run_non_driver_rank` acts on that flag by running the proposer's extra prefill pass after the scorer:

```diff
@@ -653,6 +672,8 @@ def _run_non_driver_rank(self) -> bool:
 
         if not data["no_spec"]:
             self.scorer_worker.execute_model()
+            if data["run_spec_proposer_for_prefill"]:
+                self.proposer_worker.execute_model()
 
         return True
```
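To make the driver/non-driver handshake concrete, here is a toy, single-process sketch. `StubWorker` and `non_driver_step` are illustrative stand-ins, and a plain dict replaces the cross-rank `broadcast_tensor_dict()` transfer; it mirrors only the snippet shown above, not the full method.

```python
class StubWorker:
    """Counts forward passes in place of a real scorer/proposer worker."""

    def __init__(self) -> None:
        self.calls = 0

    def execute_model(self) -> None:
        self.calls += 1


def non_driver_step(data: dict, scorer: StubWorker,
                    proposer: StubWorker) -> bool:
    # The non-driver rank must issue the same forward passes as the
    # driver to stay in lockstep.
    if not data["no_spec"]:
        scorer.execute_model()
        # New in this PR: if the speculative batch also contains prefills,
        # the proposer runs one extra pass for the prefill sub-batch.
        if data["run_spec_proposer_for_prefill"]:
            proposer.execute_model()
    return True


scorer, proposer = StubWorker(), StubWorker()
# Mixed prefill+decode speculative step: the scorer handles everything in
# one forward pass; the proposer needs the extra prefill pass.
non_driver_step({"no_spec": False, "run_spec_proposer_for_prefill": True},
                scorer, proposer)
assert (scorer.calls, proposer.calls) == (1, 1)
# Decode-only speculative step: no extra proposer pass is needed.
non_driver_step({"no_spec": False, "run_spec_proposer_for_prefill": False},
                scorer, proposer)
assert (scorer.calls, proposer.calls) == (2, 1)
```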
One final review nit (on a leftover debug print elsewhere in the PR):

- "nit: can we remove this print?"
- "Thanks for the catch!"