diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index 13b5b80cccfdc..00a2379502e6d 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
 @pytest.mark.parametrize(
     "which_tokens_accepted",
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
+@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str, seed: int,
+def test_correct_output_format(which_tokens_accepted: str,
+                               disable_bonus_tokens: bool, seed: int,
                                device: str):
     """Verify the output has correct format given predetermined accepted matrix.
     """
@@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
                                          size=(batch_size, 1),
                                          dtype=torch.int64)
 
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(
+        disable_bonus_tokens=disable_bonus_tokens)
     rejection_sampler.init_gpu_tensors(rank=0)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
         bonus_token_ids,
     )
 
-    # Bonus tokens are currently disabled. Verify they're set to -1.
+    expected_bonus_token_ids = bonus_token_ids.clone()
+    # If bonus tokens are disabled, verify they are set to -1.
     # See https://github.com/vllm-project/vllm/issues/4212
-    expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
+    if disable_bonus_tokens:
+        expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
 
     if which_tokens_accepted == "all_tokens_accepted":
         # Expect all tokens to be equal to draft tokens.
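For reference, the expectation the updated test builds is just an all--1 tensor in the bonus-token column. A minimal stand-alone sketch of both branches, using only torch (the shape here is illustrative, not taken from the test):

```python
import torch

# Shape (batch_size, 1), matching the bonus_token_ids tensor in the test.
bonus_token_ids = torch.randint(low=0, high=100, size=(4, 1), dtype=torch.int64)

# disable_bonus_tokens=True: the sampler masks the bonus column with -1,
# so multiplying a clone by zero and subtracting one yields the expected
# tensor of -1s with the same shape and dtype.
expected_disabled = bonus_token_ids.clone() * 0 - 1
assert torch.equal(expected_disabled, torch.full_like(bonus_token_ids, -1))

# disable_bonus_tokens=False: bonus tokens pass through unchanged.
expected_enabled = bonus_token_ids.clone()
assert torch.equal(expected_enabled, bonus_token_ids)
```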
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 44ef400c91d34..c2004ff061a1e 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -57,7 +57,7 @@ @pytest.mark.parametrize("output_len", [ 256, ]) -@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) def test_ngram_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py new file mode 100644 index 0000000000000..948a74b22f0ae --- /dev/null +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_batch, mock_worker + + +@pytest.mark.parametrize('queue_size', [2, 4]) +@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) +@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@torch.inference_mode() +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): + """Verify that speculative tokens are disabled when the batch size + exceeds the threshold. + """ + disable_by_batch_size = 3 + + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=rejection_sampler, + metrics_collector=metrics_collector, + disable_by_batch_size=disable_by_batch_size) + + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + running_queue_size=queue_size) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + + # When the batch size is larger than the threshold, + # we expect no speculative tokens (0). + expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 + assert seq_group_metadata_list[ + 0].num_speculative_tokens == expected_num_spec_tokens + + draft_worker.sampler_output.side_effect = ValueError(exception_secret) + + proposer = Top1Proposer( + worker=draft_worker, + device='cpu', # not used + vocab_size=100, # not used + # Must be long enough to avoid being skipped due to length. + max_proposal_len=1024, + ) + + if queue_size < disable_by_batch_size: + # Should raise exception when executing the mocked draft model. 
+        with pytest.raises(ValueError, match=exception_secret):
+            proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                num_lookahead_slots=k), )
+    else:
+        # Should not execute the draft model because spec decode is disabled
+        # for all requests. Accordingly, the proposal length should be 0.
+        proposals = proposer.get_proposals(
+            execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                num_lookahead_slots=k), )
+        assert proposals.proposal_lens.tolist() == [0] * batch_size
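In plain terms, the assertion in the middle of this test encodes the threshold rule: the field stays None until the worker disables speculation, at which point it becomes 0. A stand-alone sketch with the same constants as the test:

```python
# Same constants as the test: threshold 3, parametrized queue sizes 2 and 4.
disable_by_batch_size = 3

for queue_size in (2, 4):
    disable_all = queue_size >= disable_by_batch_size
    # None: the worker never touched the field; 0: speculation disabled.
    expected_num_spec_tokens = 0 if disable_all else None
    print(f"queue_size={queue_size} -> "
          f"num_speculative_tokens={expected_num_spec_tokens}")
# queue_size=2 -> num_speculative_tokens=None
# queue_size=4 -> num_speculative_tokens=0
```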
""" self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8245eb307f7..c99b1806c7d1d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -83,6 +83,7 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None @@ -467,6 +468,13 @@ def add_cli_args( 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + parser.add_argument( '--ngram-prompt-lookup-max', type=int, @@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index e8559b6a5c0fe..fa3480fa64837 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -93,6 +93,8 @@ def _init_spec_worker(self): spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=self.speculative_config. + speculative_disable_by_batch_size, ) assert self.parallel_config.world_size == 1, ( diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 5edbbf2c70a49..b5f1e55d0e839 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): https://arxiv.org/pdf/2302.01318.pdf. """ - def __init__(self, strict_mode: bool = False): + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): """Create a rejection sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. """ super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -312,7 +318,8 @@ def _create_output( # proposal methods that require KV cache. We can fix it by "prefilling" # the bonus token in the proposer. The following issue tracks the fix. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index bb8245eb307f7..c99b1806c7d1d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -83,6 +83,7 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
+    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None
 
@@ -467,6 +468,13 @@ def add_cli_args(
             'draft model. Sequences over this length will skip '
             'speculation.')
 
+        parser.add_argument(
+            '--speculative-disable-by-batch-size',
+            type=int,
+            default=EngineArgs.speculative_disable_by_batch_size,
+            help='Disable speculative decoding for new incoming requests '
+            'if the number of enqueued requests is larger than this value.')
+
         parser.add_argument(
             '--ngram-prompt-lookup-max',
             type=int,
@@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig:
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
             num_speculative_tokens=self.num_speculative_tokens,
+            speculative_disable_by_batch_size=self.
+            speculative_disable_by_batch_size,
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index e8559b6a5c0fe..fa3480fa64837 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -93,6 +93,8 @@ def _init_spec_worker(self):
         spec_decode_worker = SpecDecodeWorker.create_worker(
             scorer_worker=target_worker,
             draft_worker_kwargs=draft_worker_kwargs,
+            disable_by_batch_size=self.speculative_config.
+            speculative_disable_by_batch_size,
         )
 
         assert self.parallel_config.world_size == 1, (
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 5edbbf2c70a49..b5f1e55d0e839 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
     https://arxiv.org/pdf/2302.01318.pdf.
     """
 
-    def __init__(self, strict_mode: bool = False):
+    def __init__(self,
+                 disable_bonus_tokens: bool = True,
+                 strict_mode: bool = False):
         """Create a rejection sampler.
 
         Args:
+            disable_bonus_tokens: Whether or not to disable the bonus token.
+                Required when bonus tokens would corrupt the KV cache of
+                proposal methods that rely on the KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
         """
         super().__init__()
+        self._disable_bonus_tokens = disable_bonus_tokens
         self._strict_mode = strict_mode
 
         # NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -312,7 +318,8 @@ def _create_output(
         # proposal methods that require KV cache. We can fix it by "prefilling"
         # the bonus token in the proposer. The following issue tracks the fix.
         # https://github.com/vllm-project/vllm/issues/4212
-        output_with_bonus_tokens[:, -1] = -1
+        if self._disable_bonus_tokens:
+            output_with_bonus_tokens[:, -1] = -1
 
         # Fill the recovered token ids.
         output.mul_(~after_false_mask).add_(
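A sketch of how the new flag added in arg_utils.py above flows from the command line into EngineArgs and onward to SpeculativeConfig (assuming the add_cli_args/from_cli_args helpers keep the signatures they have at this revision):

```python
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    '--model', 'JackFram/llama-160m',
    '--speculative-model', 'JackFram/llama-68m',
    '--num-speculative-tokens', '5',
    '--speculative-disable-by-batch-size', '32',
    '--use-v2-block-manager',
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.speculative_disable_by_batch_size == 32
# create_engine_config() then threads the value into SpeculativeConfig,
# which GPUExecutor passes to SpecDecodeWorker.create_worker above.
```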
+ disable_all = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + if disable_all: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # This happens for prefill, or when the spec decode is disabled + # for this batch. if execute_model_req.num_lookahead_slots == 0 or len( execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req) + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all) return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to the - proposer and scorer model so that the KV cache is consistent between the - two. + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to + the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. """ - #logger.info("run proposer worker no spec") - - self.proposer_worker.execute_model(execute_model_req) + if not skip_proposer: + self.proposer_worker.execute_model(execute_model_req) - #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -244,22 +265,18 @@ def _run_speculative_decoding_step( sequence. """ - #logger.info("get spec proposals") # Generate proposals using draft worker. 
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index eb622a0e2e7f4..ee9462b68dae8 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -56,7 +56,7 @@ def get_proposals(
             proposal_lens,
             nonzero_proposal_len_seqs,
             nonzero_proposal_len_indices,
-        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
+        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
 
         if nonzero_proposal_len_seqs:
             # Speculate tokens using the draft worker for the speculative
@@ -97,17 +97,27 @@ def get_proposals(
 
         return proposals
 
-    def _split_by_max_model_len(
+    def _split_by_proposal_len(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         proposal_len: int,
     ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
-        """Determine which sequences would exceed the max model length."""
+        """Split sequences into two groups:
+        1. Sequences with a non-zero proposal length.
+        2. Sequences with a zero proposal length (due to disabled speculation
+        or exceeding the maximum model length).
+        """
 
         proposal_lens: List[int] = []
         nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
         nonzero_proposal_len_indices: List[int] = []
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            # Speculative decoding for this request has been disabled
+            # (e.g. due to high traffic).
+            if seq_group_metadata.num_speculative_tokens == 0:
+                proposal_lens.append(0)
+                continue
+
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
             seq_len = seq_data.get_len()
 
@@ -115,13 +125,14 @@ def _split_by_max_model_len(
             # are supported.
             # If max_proposal_len is defined, then we shall not exceed this
             # quota for nonzero proposals.
+            new_k = 0
             if (self.max_proposal_len is None
                     or seq_len + proposal_len < self.max_proposal_len):
-                proposal_lens.append(proposal_len)
+                new_k = proposal_len
                 nonzero_proposal_len_seqs.append(seq_group_metadata)
                 nonzero_proposal_len_indices.append(i)
-            else:
-                proposal_lens.append(0)
+            proposal_lens.append(new_k)
+            seq_group_metadata.num_speculative_tokens = new_k
 
         return (
             proposal_lens,
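An illustrative recap of the renamed split logic with plain data instead of SequenceGroupMetadata objects (the helper below is a simplified sketch, not the vLLM method): requests whose speculation was disabled keep a zero proposal length; the rest get the full proposal_len unless they would exceed max_proposal_len.

```python
from typing import List, Optional, Tuple

def split_by_proposal_len(
        seq_lens: List[int],
        spec_disabled: List[bool],
        proposal_len: int,
        max_proposal_len: Optional[int]) -> Tuple[List[int], List[int]]:
    proposal_lens: List[int] = []
    nonzero_indices: List[int] = []
    for i, (seq_len, disabled) in enumerate(zip(seq_lens, spec_disabled)):
        if disabled:
            # Mirrors the num_speculative_tokens == 0 branch above.
            proposal_lens.append(0)
            continue
        new_k = 0
        if (max_proposal_len is None
                or seq_len + proposal_len < max_proposal_len):
            new_k = proposal_len
            nonzero_indices.append(i)
        proposal_lens.append(new_k)
    return proposal_lens, nonzero_indices

print(split_by_proposal_len([10, 2000, 30], [False, False, True], 5, 1024))
# -> ([5, 0, 0], [0]): the too-long sequence and the disabled one get k == 0
```

Folding the max-model-len check and the disable check into one method keeps get_proposals agnostic to why a sequence's proposal length is zero, which is the design point of the rename.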