diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py
index ff413e8e2da3f..f45428675bde8 100644
--- a/tests/multi_step/test_correctness_llm.py
+++ b/tests/multi_step/test_correctness_llm.py
@@ -1,5 +1,6 @@
 # Test the LLMEngine with multi-step-decoding
 
+import copy
 from typing import Optional
 
 import pytest
@@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs", [None, 5])
+def test_multi_step_llm_chunked_prefill_prefix_cache(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: int,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+) -> None:
+    """Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
+
+    Set up contrived scenario which tests for a possible failure mode of
+    scheduling with multi-step+"single-step chunked prefill"+APC
+
+    "single-step chunked prefill" here refers to the current vLLM multi-step+
+    chunked-prefill implementation, which requires that a prefill may only
+    be scheduled in the same step as decodes if the prefill prompt fits in a
+    single chunk (note that "complete" multi-step+chunked-prefill would allow
+    a prefill to span multiple chunks & multiple steps but that is not yet
+    the case.)
+
+    "APC" is short for "automatic prefix caching".
+
+    This test creates a scenario where the scheduler must decide whether/how
+    to schedule a prefill with a prompt that exceeds the available token budget.
+    The correct behavior for multi-step+"single-step chunked prefill"+APC is to
+    put off scheduling the prefill until a future step.
+
+    Validate that:
+    * Multi-step kernels do not raise an exception due to incorrect scheduler
+      behavior
+    * Generated tokens match between
+      multi-step+"single-step chunked prefill"+APC and
+      single-step scheduling.
+    * (If logprobs are enabled) check logprobs are close enough
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+        GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+        completions endpoint; `None` -> 1 logprob returned.
+    """
+
+    # Set up contrived test for correct scheduling behavior with
+    # multi-step+"single-step chunked prefill"+APC.
+    #
+    # Assume block_size=16
+    #
+    # Assume max_num_batched_tokens=48
+    #   => Per-step token budget=48
+    #
+    # 1. Scheduler schedules 0th prompt (24 tokens)
+    #      => Remaining token budget=24
+    # 2. Scheduler attempts to schedule 1st prompt (30 tokens)
+    #    * 30 tokens exceeds 24 token remaining budget
+    #    * Correct behavior: do not schedule this prompt in this step
+    #    * Incorrect behavior: schedule prompt chunk
+    #        * `do_sample=False` for this prompt in this step
+    #        * Chunk size = (remaining tokens // block size) * block size
+    #
+    # The Incorrect scheduling behavior - if it occurs - will cause an exception
+    # in the model runner resulting from `do_sample=False`.
+    assert len(example_prompts) >= 2
+    challenge_prompts = copy.deepcopy(example_prompts)
+    challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
+                            'inference and serving engine for LLMs.\n'
+                            )  # 24 tok
+    challenge_prompts[1] = (
+        'Briefly describe the major milestones in the '
+        'development of artificial intelligence from 1950 to 2020.\n'
+    )  # 30 tok
+
+    # If necessary, adjust the length of `challenge_prompts` to match
+    # `num_prompts`
+    if len(challenge_prompts) < num_prompts:
+        challenge_prompts = (challenge_prompts *
+                             ((num_prompts // len(challenge_prompts)) + 1))
+    challenge_prompts = challenge_prompts[:num_prompts]
+    assert len(challenge_prompts) == num_prompts
+
+    # Single-step scheduler baseline
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_baseline = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                            vllm_model.generate_greedy_logprobs(
+                                challenge_prompts, max_tokens, num_logprobs))
+
+    # multi-step+"single-step chunked prefill"+APC
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            num_scheduler_steps=num_scheduler_steps,
+            max_model_len=48,
+            max_num_batched_tokens=48,
+            max_num_seqs=4,
+            block_size=16,
+    ) as vllm_model:
+        outputs_w_features = (vllm_model.generate_greedy(
+            challenge_prompts, max_tokens) if num_logprobs is None else
+                              vllm_model.generate_greedy_logprobs(
+                                  challenge_prompts, max_tokens, num_logprobs))
+
+    if num_logprobs is None:
+        # No-logprobs test
+        check_outputs_equal(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
+    else:
+        # Yes-logprobs test
+        check_logprobs_close(
+            outputs_0_lst=outputs_baseline,
+            outputs_1_lst=outputs_w_features,
+            name_0="multi-step",
+            name_1="multi-step+features",
+        )
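
The comment block in the test above walks through the token-budget arithmetic behind the contrived scenario. As a quick sanity check, here is a minimal standalone sketch of that arithmetic; the constants simply mirror the test's block_size=16 and max_num_batched_tokens=48, and nothing below is vLLM API or part of the patch:

# Minimal sketch of the budget walkthrough in the test comment above;
# the constants mirror the test configuration and are illustrative only.
BLOCK_SIZE = 16
TOKEN_BUDGET = 48                      # max_num_batched_tokens => per-step budget

remaining = TOKEN_BUDGET - 24          # 0th prompt (24 tokens) is scheduled
assert remaining == 24
second_prompt_len = 30                 # 1st prompt (30 tokens)
assert second_prompt_len > remaining   # it does not fit in this step

# Correct behavior: postpone the 1st prompt to a later scheduler step.
# Incorrect behavior: schedule a partial chunk of
#     (remaining // BLOCK_SIZE) * BLOCK_SIZE == 16
# tokens with do_sample=False, which trips an exception in the multi-step
# model runner.
assert (remaining // BLOCK_SIZE) * BLOCK_SIZE == 16
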
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 5b7587d150843..f3a5016d0e62a 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1607,10 +1607,29 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
         # in a decode phase. Do not chunk.
         if enable_chunking and len(seqs) == 1:
             remaining_token_budget = budget.remaining_token_budget()
-            if self.cache_config.enable_prefix_caching:
+            if self.scheduler_config.is_multi_step:
+                # The current multi-step + chunked prefill capability does
+                # not actually support chunking prompts.
+                #
+                # Therefore, `num_new_tokens` is computed in the same fashion
+                # for both multi-step+chunked-prefill &
+                # multi-step+chunked-prefill+APC
+                #
+                # Prompts with more tokens than the current remaining budget
+                # are postponed to future scheduler steps
+                if num_new_tokens > self._get_prompt_limit(seq_group):
+                    # If the seq_group is in prompt-stage, pass the
+                    # num_new_tokens as-is so the caller can ignore
+                    # the sequence.
+                    pass
+                else:
+                    num_new_tokens = 0 \
+                        if num_new_tokens > remaining_token_budget \
+                        else num_new_tokens
+            elif self.cache_config.enable_prefix_caching:
                 # When prefix caching is enabled, we always allocate
-                # the number of new tokens that is dividable by the block size
-                # to avoid partial block matching.
+                # the number of new tokens that is dividable by the block
+                # size to avoid partial block matching.
                 block_size = self.cache_config.block_size
                 remainder = budget.token_budget % block_size
                 if remainder != 0:
@@ -1623,16 +1642,6 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
                 if remaining_token_budget < num_new_tokens:
                     num_new_tokens = (remaining_token_budget //
                                       block_size) * block_size
-            elif self.scheduler_config.is_multi_step:
-                if num_new_tokens > self._get_prompt_limit(seq_group):
-                    # If the seq_group is in prompt-stage, pass the
-                    # num_new_tokens as-is so the caller can ignore
-                    # the sequence.
-                    pass
-                else:
-                    num_new_tokens = 0 \
-                        if num_new_tokens > remaining_token_budget \
-                        else num_new_tokens
             else:
                 num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 64fa7360b95b8..c97b6ffb093f7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -999,10 +999,6 @@ def create_engine_config(self) -> EngineConfig:
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")
-            if self.enable_chunked_prefill and self.enable_prefix_caching:
-                raise ValueError("Multi-Step is not supported with "
-                                 "both Chunked-Prefill and Prefix-Caching "
-                                 "enabled together.")
             if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                 raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                  "for pipeline-parallel-size > 1")
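
For readers skimming the scheduler.py hunks above: after the reordering, the multi-step branch of Scheduler._get_num_new_tokens takes precedence over the prefix-caching branch, so the postponement rule applies whether or not APC is enabled. Below is a rough standalone sketch of that rule, ignoring the surrounding enable_chunking checks; the function name and bare arguments are illustrative stand-ins for scheduler state, not vLLM API:

def sketch_num_new_tokens(num_new_tokens: int,
                          remaining_token_budget: int,
                          prompt_limit: int,
                          is_multi_step: bool,
                          enable_prefix_caching: bool,
                          block_size: int) -> int:
    # Illustrative stand-in for the reordered branches in
    # Scheduler._get_num_new_tokens; not the actual vLLM implementation.
    if is_multi_step:
        # Multi-step + "single-step chunked prefill" cannot split a prompt
        # across chunks: either the whole prompt fits in this step's budget,
        # or it is postponed (0 tokens scheduled now).
        if num_new_tokens > prompt_limit:
            # Over-long prompt: return it unchanged so the caller rejects it.
            return num_new_tokens
        return 0 if num_new_tokens > remaining_token_budget else num_new_tokens
    if enable_prefix_caching:
        # APC without multi-step: only hand out whole blocks so that partial
        # blocks never participate in prefix matching.
        if remaining_token_budget < num_new_tokens:
            return (remaining_token_budget // block_size) * block_size
        return num_new_tokens
    return min(num_new_tokens, remaining_token_budget)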
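
The arg_utils.py hunk removes the guard that previously rejected multi-step combined with chunked prefill and prefix caching, which is exactly the combination the new test exercises. A hedged usage sketch of that combination through the offline API follows; the model name, step count, and sampling values are arbitrary examples mirroring the flags the test passes to vllm_runner, not recommendations:

# Hedged usage sketch of the feature combination this change unblocks.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",       # example model; use any supported model
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    num_scheduler_steps=8,             # multi-step decoding
    use_v2_block_manager=True,
)
outputs = llm.generate(
    ["vLLM is a high-throughput and memory-efficient inference engine."],
    SamplingParams(temperature=0.0, max_tokens=16),
)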