[Bugfix] Fix for Spec model TP + Chunked Prefill #10232

Merged: 19 commits, Nov 26, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion docs/source/serving/compatibility_matrix.rst
@@ -118,7 +118,7 @@ Feature x Feature
-
-
* - :ref:`SD <spec_decode>`
-
-
- ✅
- ✗
- ✅
39 changes: 39 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -413,6 +413,45 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots):
assert out.num_batched_tokens == max_num_batched_tokens


@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
def test_chunked_prefill_spec_prefill(num_scheduler_steps):
"""Verify that the num_lookahead_slots is set appropriately for an all"""
"""prefill batch depending on whether multi-step scheduling is enabled"""
"""or not"""
block_size = 4
max_seqs = 30
max_model_len = 200
max_num_batched_tokens = 30
num_lookahead_slots = 4
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
num_lookahead_slots=num_lookahead_slots,
num_scheduler_steps=num_scheduler_steps,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None)

_, seq_group = create_dummy_prompt("1",
prompt_length=30,
block_size=block_size)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked.
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 1
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == max_num_batched_tokens
print(out.num_lookahead_slots)
Contributor: nit: can we remove this print?

Collaborator Author: Thanks for the catch!

assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
num_lookahead_slots)


def test_chunked_prefill_max_seqs():
block_size = 4
max_seqs = 2
46 changes: 0 additions & 46 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": "True",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"tensor_parallel_size": 2,
"speculative_draft_tensor_parallel_size": 2,
},
{
"tensor_parallel_size": 4,
"speculative_draft_tensor_parallel_size": 4,
},
{
"tensor_parallel_size": 8,
"speculative_draft_tensor_parallel_size": 8,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
test_llm_generator):
"""Verify that speculative decoding fails if chunked prefill is enabled for
draft model with tensor parallelism of more than 1.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError, match="with tensor parallel size 1"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
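
With this xfail test removed, and the matching guard deleted from vllm/config.py further down in this diff, a draft model with tensor parallel size greater than 1 is no longer rejected when chunked prefill is enabled. Below is a minimal offline-inference sketch of such a configuration, reusing the kwargs from the deleted test; this is an illustrative assumption, not part of the PR, it needs at least 2 GPUs, and the model choices are only examples.

from vllm import LLM, SamplingParams

# Kwargs mirror the deleted test's common_llm_kwargs; requires >= 2 GPUs.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    tensor_parallel_size=2,
    speculative_draft_tensor_parallel_size=2,
)

sampling_params = SamplingParams(max_tokens=128,
                                 ignore_eos=True,
                                 temperature=0.0)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)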
57 changes: 57 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -115,3 +115,60 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
max_output_len=32,
seed=seed,
temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor_parallel_size",
"2",
# precision
"--dtype",
"bfloat16",
]])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[["--enable-chunked-prefill", "False"],
[
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
"--max-num-seqs", "4"
]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
3 changes: 2 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
num_lookahead_slots=k)
# For prefill only batches we expect num_lookahead_slots = 0.
num_lookahead_slots=k if n_decodes > 0 else 0)

target_token_ids = torch.randint(low=0,
high=vocab_size,
10 changes: 0 additions & 10 deletions vllm/config.py
@@ -1409,16 +1409,6 @@ def maybe_create_spec_config(
draft_hf_config
)

if (enable_chunked_prefill and \
speculative_draft_tensor_parallel_size != 1):
# TODO - Investigate why the error reported in
# https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
# is happening and re-enable it.
raise ValueError(
"Chunked prefill and speculative decoding can be enabled "
"simultaneously only for draft models with tensor "
"parallel size 1.")

draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
28 changes: 19 additions & 9 deletions vllm/core/scheduler.py
@@ -1201,15 +1201,25 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
# Put prefills first due to Attention backend ordering assumption.
scheduled_seq_groups = (prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups)
num_prefill_groups = (len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups))
# If the batch contains only prefills, set num_lookahead_slots to 0.
# This lets us take the `no_spec` path in `spec_decode_worker.py`.
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
num_lookahead_slots = (0 if
(all_prefills
and not self.scheduler_config.is_multi_step)
else running_scheduled.num_lookahead_slots)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
@@ -1218,7 +1228,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
num_lookahead_slots=num_lookahead_slots,
Collaborator: @varun-sundar-rabindranath could you also review this part to see if this will break multi-step scheduling with chunked prefill?

Contributor: Thanks for the tag. I believe it will affect performance. Multi-step + chunked prefill allows for lookahead slots even when all the sequences are prefills: the sequences are processed as prefills in step 1 and as decodes in steps 2 through n. Setting lookahead_slots to 0 will force single-stepping for the all-prefills case. I can get some profiles. @andoorve, is there a way to make this update only if spec decode is enabled? I believe that would be safer.

Collaborator Author: Hi @varun-sundar-rabindranath, I think that should be possible, thanks for the feedback! Let me see how we can do that.

Collaborator Author: @varun-sundar-rabindranath @comaniac Can you check whether this condition makes sense?

Contributor: Hey @andoorve, the condition looks good 👍
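
To make the thread's resolution concrete, here is a small standalone sketch of the decision the new scheduler lines implement. The function and argument names are made up for illustration; the real logic lives in Scheduler._schedule_chunked_prefill and reads self.scheduler_config.is_multi_step. Lookahead slots are dropped to zero only when the batch is entirely prefills and multi-step scheduling is off, so multi-step + chunked prefill keeps its lookahead slots.

def pick_num_lookahead_slots(num_scheduled_groups: int,
                             num_prefill_groups: int,
                             is_multi_step: bool,
                             scheduled_lookahead_slots: int) -> int:
    """Return 0 for an all-prefill, single-step batch so the spec decode
    worker takes its `no_spec` path; otherwise keep the scheduled value."""
    all_prefills = num_scheduled_groups == num_prefill_groups
    if all_prefills and not is_multi_step:
        return 0
    return scheduled_lookahead_slots

# All-prefill batch with multi-step disabled: no lookahead slots.
assert pick_num_lookahead_slots(2, 2, is_multi_step=False,
                                scheduled_lookahead_slots=4) == 0
# Multi-step keeps its lookahead slots even for an all-prefill batch.
assert pick_num_lookahead_slots(2, 2, is_multi_step=True,
                                scheduled_lookahead_slots=4) == 4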

running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
33 changes: 27 additions & 6 deletions vllm/spec_decode/spec_decode_worker.py
@@ -408,7 +408,20 @@ def execute_model(
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots

all_prompt = True
atleast_one_prompt = False
all_zero_spec_tokens = True
for sgm in execute_model_req.seq_group_metadata_list:
all_prompt = all_prompt and sgm.is_prompt
atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
all_zero_spec_tokens = all_zero_spec_tokens and (
sgm.num_speculative_tokens == 0)

if all_prompt and execute_model_req.seq_group_metadata_list:
assert num_lookahead_slots == 0, (
"Prompt only runs should have num_lookahead_slots equal to 0. "
"This should never happen, please file a bug at "
"https://github.com/vllm-project/vllm/issues")
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
@@ -419,11 +432,8 @@
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec = all(
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
) or num_lookahead_slots == 0 or disable_all_speculation or all(
sgm.num_speculative_tokens == 0
for sgm in execute_model_req.seq_group_metadata_list)
no_spec = (num_lookahead_slots == 0 or disable_all_speculation
or all_zero_spec_tokens)

# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
@@ -442,6 +452,15 @@
num_lookahead_slots=num_lookahead_slots,
no_spec=no_spec,
disable_all_speculation=disable_all_speculation,
# When both chunked prefill and speculative decoding are enabled
# it is possible that the same batch contains both prefill
# and decodes. If that happens in the scorer we run the batch
# as one single forward pass. However, in the proposer we
# run them as 2 different batches - one for prefill and
# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill=atleast_one_prompt,
Collaborator Author: It would be great if we got a sanity check from @NickLucche or someone!

Contributor: on it, sorry for the late ack
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

@@ -653,6 +672,8 @@ def _run_non_driver_rank(self) -> bool:

if not data["no_spec"]:
self.scorer_worker.execute_model()
if data["run_spec_proposer_for_prefill"]:
self.proposer_worker.execute_model()

return True
