diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 56cb0147d9e4f..49e4a5f8150b5 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator, batch_size, max_output_len=32, force_output_len=True) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 4, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }, + ]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # This must be a good bit larger than speculative_max_model_len so that + # we can test the case where all seqs are skipped, but still small to + # ensure fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_skip_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify job failure with RuntimeError when all sequences skip speculation. + We do this by setting the max model len of the draft model to an + artificially low value, such that when the sequences grow beyond it, they + are skipped in speculative decoding. + + TODO: fix it to pass without raising Error. (#5814) + """ + with pytest.raises(RuntimeError): + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index d109d8edc1b0b..11ab09f10c1f5 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -22,6 +22,9 @@ class SpeculativeProposals: # The valid length of each proposal; can be zero. proposal_lens: torch.Tensor + # A flag to mark that there's no available proposals + no_proposals: bool = False + def __repr__(self): return (f"SpeculativeProposals(" f"proposal_token_ids={self.proposal_token_ids}, " diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3c8e3dee46831..d9e775c9ddd7f 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -109,6 +109,7 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": + allow_zero_draft_token_step = True ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -133,6 +134,8 @@ def create_worker( if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner + else: + allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( @@ -155,10 +158,12 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) - return SpecDecodeWorker(proposer_worker, - scorer_worker, - disable_by_batch_size=disable_by_batch_size, - spec_decode_sampler=spec_decode_sampler) + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler, + allow_zero_draft_token_step=allow_zero_draft_token_step) def __init__( self, @@ -167,6 +172,7 @@ def __init__( spec_decode_sampler: SpecDecodeBaseSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, + allow_zero_draft_token_step: Optional[bool] = True, ): """ Create a SpecDecodeWorker. @@ -187,11 +193,15 @@ def __init__( disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. + allow_zero_draft_token_step: whether to allow a step where the draft + model generates no draft token; should disallow when the tp of + draft model is larger than 1 (TODO: #5814) """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler + self._allow_zero_draft_token_step = allow_zero_draft_token_step self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector @@ -461,6 +471,11 @@ def _run_speculative_decoding_step( proposals = self.proposer_worker.get_spec_proposals( execute_model_req, self._seq_with_bonus_token_in_last_step) + if not self._allow_zero_draft_token_step and proposals.no_proposals: + #TODO: Fix it #5814 + raise RuntimeError("Cannot handle cases where distributed draft " + "workers generate no tokens") + proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 7b34b5d34208b..59257f7a61a4d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -108,7 +108,7 @@ def get_spec_proposals( proposal_token_ids=proposal_tokens, proposal_probs=proposal_probs, proposal_lens=proposal_lens, - ) + no_proposals=maybe_sampler_output is None) return proposals