Skip to content

Commit

Permalink
[Bugfix] Fix for Spec model TP + Chunked Prefill (vllm-project#10232)
Browse files Browse the repository at this point in the history
Signed-off-by: andoorve <[email protected]>
Signed-off-by: Sourashis Roy <[email protected]>
Co-authored-by: Sourashis Roy <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
  • Loading branch information
2 people authored and afeldman-nm committed Dec 2, 2024
1 parent f7833f3 commit 704d635
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 73 deletions.
2 changes: 1 addition & 1 deletion docs/source/serving/compatibility_matrix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Feature x Feature
-
-
* - :ref:`SD <spec_decode>`
-
-
- ✅
- ✗
- ✅
Expand Down
39 changes: 39 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,45 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots):
assert out.num_batched_tokens == max_num_batched_tokens


@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
def test_chunked_prefill_spec_prefill(num_scheduler_steps):
"""Verify that the num_lookahead_slots is set appropriately for an all"""
"""prefill batch depending on whether multi-step scheduling is enabled"""
"""or not"""
block_size = 4
max_seqs = 30
max_model_len = 200
max_num_batched_tokens = 30
num_lookahead_slots = 4
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
num_lookahead_slots=num_lookahead_slots,
num_scheduler_steps=num_scheduler_steps,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None)

_, seq_group = create_dummy_prompt("1",
prompt_length=30,
block_size=block_size)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked.
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 1
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == max_num_batched_tokens
print(out.num_lookahead_slots)
assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
num_lookahead_slots)


def test_chunked_prefill_max_seqs():
block_size = 4
max_seqs = 2
Expand Down
46 changes: 0 additions & 46 deletions tests/spec_decode/e2e/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": "True",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"tensor_parallel_size": 2,
"speculative_draft_tensor_parallel_size": 2,
},
{
"tensor_parallel_size": 4,
"speculative_draft_tensor_parallel_size": 4,
},
{
"tensor_parallel_size": 8,
"speculative_draft_tensor_parallel_size": 8,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
test_llm_generator):
"""Verify that speculative decoding fails if chunked prefill is enabled for
draft model with tensor parallelism of more than 1.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError, match="with tensor parallel size 1"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
57 changes: 57 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,60 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
max_output_len=32,
seed=seed,
temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor_parallel_size",
"2",
# precision
"--dtype",
"bfloat16",
]])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[["--enable-chunked-prefill", "False"],
[
"--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
"--max-num-seqs", "4"
]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
3 changes: 2 additions & 1 deletion tests/spec_decode/test_spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
num_lookahead_slots=k)
# For prefill only batches we expect num_lookahead_slots = 0.
num_lookahead_slots=k if n_decodes > 0 else 0)

target_token_ids = torch.randint(low=0,
high=vocab_size,
Expand Down
10 changes: 0 additions & 10 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,16 +1409,6 @@ def maybe_create_spec_config(
draft_hf_config
)

if (enable_chunked_prefill and \
speculative_draft_tensor_parallel_size != 1):
# TODO - Investigate why the error reported in
# https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
# is happening and re-enable it.
raise ValueError(
"Chunked prefill and speculative decoding can be enabled "
"simultaneously only for draft models with tensor "
"parallel size 1.")

draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
Expand Down
28 changes: 19 additions & 9 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,15 +1201,25 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
# Put prefills first due to Attention backend ordering assumption.
scheduled_seq_groups = (prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups)
num_prefill_groups = (len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups))
# If all prompts, then we set num_lookahead_slots to 0
# this allows us to go through the `no_spec` path in
# `spec_decode_worker.py`
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
num_lookahead_slots = (0 if
(all_prefills
and not self.scheduler_config.is_multi_step)
else running_scheduled.num_lookahead_slots)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens +
budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
Expand All @@ -1218,7 +1228,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
Expand Down
33 changes: 27 additions & 6 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,20 @@ def execute_model(
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots

all_prompt = True
atleast_one_prompt = False
all_zero_spec_tokens = True
for sgm in execute_model_req.seq_group_metadata_list:
all_prompt = all_prompt and sgm.is_prompt
atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
all_zero_spec_tokens = all_zero_spec_tokens and (
sgm.num_speculative_tokens == 0)

if all_prompt and execute_model_req.seq_group_metadata_list:
assert num_lookahead_slots == 0, (
"Prompt only runs should have num_lookahead_slots equal to 0. "
"This should never happen, please file a bug at "
"https://github.com/vllm-project/vllm/issues")
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
Expand All @@ -419,11 +432,8 @@ def execute_model(
# In any of these cases, the proposer and scorer workers
# are called normally.
# We expect `num_speculative_tokens` to be None for prefills.
no_spec = all(
sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
) or num_lookahead_slots == 0 or disable_all_speculation or all(
sgm.num_speculative_tokens == 0
for sgm in execute_model_req.seq_group_metadata_list)
no_spec = (num_lookahead_slots == 0 or disable_all_speculation
or all_zero_spec_tokens)

# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
Expand All @@ -442,6 +452,15 @@ def execute_model(
num_lookahead_slots=num_lookahead_slots,
no_spec=no_spec,
disable_all_speculation=disable_all_speculation,
# When both chunked prefill and speculative decoding are enabled
# it is possible that the same batch contains both prefill
# and decodes. If that happens in the scorer we run the batch
# as one single forward pass. However, in the proposer we
# run them as 2 different batches - one for prefill and
# the other for decodes. The variable indicates to the non-driver
# worker that there are prefills as part of the speculative batch
# and hence it needs to run an extra prefill forward pass.
run_spec_proposer_for_prefill=atleast_one_prompt,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

Expand Down Expand Up @@ -653,6 +672,8 @@ def _run_non_driver_rank(self) -> bool:

if not data["no_spec"]:
self.scorer_worker.execute_model()
if data["run_spec_proposer_for_prefill"]:
self.proposer_worker.execute_model()

return True

Expand Down

0 comments on commit 704d635

Please sign in to comment.