Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] Raise an error for no draft token case when draft_tp>1 #6369

Merged
merged 13 commits into from
Jul 19, 2024
62 changes: 62 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist_tp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 4,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,

# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.

TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(RuntimeError):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
3 changes: 3 additions & 0 deletions vllm/spec_decode/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
proposal_lens: torch.Tensor

# A flag to mark that there's no available proposals
no_proposals: bool = False

def __repr__(self):
return (f"SpeculativeProposals("
f"proposal_token_ids={self.proposal_token_ids}, "
Expand Down
23 changes: 19 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def create_worker(
typical_acceptance_sampler_posterior_alpha: float,
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
Expand All @@ -133,6 +134,8 @@ def create_worker(
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)

proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
Expand All @@ -155,10 +158,12 @@ def create_worker(
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))

return SpecDecodeWorker(proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)

def __init__(
self,
Expand All @@ -167,6 +172,7 @@ def __init__(
spec_decode_sampler: SpecDecodeBaseSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
):
"""
Create a SpecDecodeWorker.
Expand All @@ -187,11 +193,15 @@ def __init__(
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
self._metrics = AsyncMetricsCollector(
self.spec_decode_sampler
) if metrics_collector is None else metrics_collector
Expand Down Expand Up @@ -461,6 +471,11 @@ def _run_speculative_decoding_step(
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)

if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")

proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
Expand Down
2 changes: 1 addition & 1 deletion vllm/spec_decode/top1_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_spec_proposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
no_proposals=maybe_sampler_output is None)

return proposals

Expand Down
Loading