[Core] Combined support for multi-step scheduling, chunked prefill & prefix caching (vllm-project#8804)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Andrew Feldman <[email protected]>
3 people authored Oct 2, 2024
1 parent 1570203 commit 563649a
Showing 3 changed files with 180 additions and 17 deletions.
158 changes: 158 additions & 0 deletions tests/multi_step/test_correctness_llm.py
@@ -1,5 +1,6 @@
# Test the LLMEngine with multi-step-decoding

import copy
from typing import Optional

import pytest
@@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
def test_multi_step_llm_chunked_prefill_prefix_cache(
vllm_runner,
example_prompts,
model: str,
dtype: str,
tp_size: int,
max_tokens: int,
    enforce_eager: bool,
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
    Set up a contrived scenario which tests for a possible failure mode of
    scheduling with multi-step+"single-step chunked prefill"+APC.
"single-step chunked prefill" here refers to the current vLLM multi-step+
chunked-prefill implementation, which requires that a prefill may only
be scheduled in the same step as decodes if the prefill prompt fits in a
    single chunk (note that "complete" multi-step+chunked-prefill would allow
    a prefill to span multiple chunks & multiple steps, but that is not yet
    the case).
"APC" is short for "automatic prefix caching".
This test creates a scenario where the scheduler must decide whether/how
to schedule a prefill with a prompt that exceeds the available token budget.
The correct behavior for multi-step+"single-step chunked prefill"+APC is to
put off scheduling the prefill until a future step.
Validate that:
* Multi-step kernels do not raise an exception due to incorrect scheduler
behavior
* Generated tokens match between
multi-step+"single-step chunked prefill"+APC and
single-step scheduling.
* (If logprobs are enabled) check logprobs are close enough
Args:
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
      enforce_eager: whether to enforce eager-mode (no CUDA graph) execution
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""

# Set up contrived test for correct scheduling behavior with
# multi-step+"single-step chunked prefill"+APC.
#
# Assume block_size=16
#
# Assume max_num_batched_tokens=48
# => Per-step token budget=48
#
# 1. Scheduler schedules 0th prompt (24 tokens)
# => Remaining token budget=24
# 2. Scheduler attempts to schedule 1st prompt (30 tokens)
# * 30 tokens exceeds 24 token remaining budget
# * Correct behavior: do not schedule this prompt in this step
# * Incorrect behavior: schedule prompt chunk
# * `do_sample=False` for this prompt in this step
# * Chunk size = (remaining tokens // block size) * block size
#
    # The incorrect scheduling behavior, if it occurs, will cause an exception
    # in the model runner resulting from `do_sample=False`.
assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
'inference and serving engine for LLMs.\n'
) # 24 tok
challenge_prompts[1] = (
'Briefly describe the major milestones in the '
'development of artificial intelligence from 1950 to 2020.\n'
) # 30 tok

# If necessary, adjust the length of `challenge_prompts` to match
# `num_prompts`
if len(challenge_prompts) < num_prompts:
challenge_prompts = (challenge_prompts *
((num_prompts // len(challenge_prompts)) + 1))
challenge_prompts = challenge_prompts[:num_prompts]
assert len(challenge_prompts) == num_prompts

# Single-step scheduler baseline
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps,
max_model_len=48,
max_num_batched_tokens=48,
max_num_seqs=4,
block_size=16,
) as vllm_model:
outputs_baseline = (vllm_model.generate_greedy(
challenge_prompts, max_tokens) if num_logprobs is None else
vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs))

# multi-step+"single-step chunked prefill"+APC
with vllm_runner(
model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=True,
enable_prefix_caching=True,
num_scheduler_steps=num_scheduler_steps,
max_model_len=48,
max_num_batched_tokens=48,
max_num_seqs=4,
block_size=16,
) as vllm_model:
outputs_w_features = (vllm_model.generate_greedy(
challenge_prompts, max_tokens) if num_logprobs is None else
vllm_model.generate_greedy_logprobs(
challenge_prompts, max_tokens, num_logprobs))

if num_logprobs is None:
# No-logprobs test
check_outputs_equal(
outputs_0_lst=outputs_baseline,
outputs_1_lst=outputs_w_features,
name_0="multi-step",
name_1="multi-step+features",
)
else:
# Yes-logprobs test
check_logprobs_close(
outputs_0_lst=outputs_baseline,
outputs_1_lst=outputs_w_features,
name_0="multi-step",
name_1="multi-step+features",
)
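
A quick illustration of the budget arithmetic described in the test's comments (an illustrative, self-contained sketch; the names schedule_step, TOKEN_BUDGET, and BLOCK_SIZE are hypothetical and not part of vLLM): with a 48-token per-step budget and block_size=16, the 24-token prompt leaves 24 tokens of budget, so the 30-token prompt must be postponed rather than chunked down to 16 tokens.

# Illustrative sketch only; names are hypothetical, not vLLM API.
BLOCK_SIZE = 16
TOKEN_BUDGET = 48  # max_num_batched_tokens => per-step token budget

def schedule_step(prompt_lens: list[int]) -> list[int]:
    """Tokens scheduled per prompt in one step, postponing any prompt
    that does not fit in the remaining budget (the correct behavior)."""
    remaining = TOKEN_BUDGET
    scheduled = []
    for n in prompt_lens:
        if n <= remaining:
            scheduled.append(n)
            remaining -= n
        else:
            # Correct multi-step + "single-step chunked prefill" behavior:
            # postpone the whole prompt to a future step.
            scheduled.append(0)
            # The incorrect behavior would instead schedule a partial chunk
            # of (remaining // BLOCK_SIZE) * BLOCK_SIZE = 16 tokens, which
            # implies do_sample=False and triggers the model-runner failure.
    return scheduled

assert schedule_step([24, 30]) == [24, 0]
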
35 changes: 22 additions & 13 deletions vllm/core/scheduler.py
@@ -1607,10 +1607,29 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
if self.scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
elif self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block size
# to avoid partial block matching.
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
block_size = self.cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
@@ -1623,16 +1642,6 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
elif self.scheduler_config.is_multi_step:
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
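
For readers skimming the diff, the decision made by the new multi-step branch can be summarized with the following standalone sketch (the parameters are simplified stand-ins for the scheduler state, not vLLM's actual method signature):

# Simplified sketch of the multi-step branch of _get_num_new_tokens above;
# the parameters are stand-ins for scheduler state, not vLLM's API.
def multi_step_num_new_tokens(num_new_tokens: int, prompt_limit: int,
                              remaining_token_budget: int) -> int:
    if num_new_tokens > prompt_limit:
        # Over the prompt limit: return the value as-is so the caller can
        # ignore (and eventually reject) the sequence group.
        return num_new_tokens
    # Multi-step + chunked prefill never chunks a prompt: either the whole
    # prompt fits in the remaining budget, or it is postponed (0 tokens).
    return 0 if num_new_tokens > remaining_token_budget else num_new_tokens
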
4 changes: 0 additions & 4 deletions vllm/engine/arg_utils.py
@@ -999,10 +999,6 @@ def create_engine_config(self) -> EngineConfig:
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill and self.enable_prefix_caching:
raise ValueError("Multi-Step is not supported with "
"both Chunked-Prefill and Prefix-Caching "
"enabled together.")
if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
raise ValueError("Multi-Step Chunked-Prefill is not supported "
"for pipeline-parallel-size > 1")