From e95cd879598b834f85e70ebcd23db316ae430540 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 13:09:21 -0700 Subject: [PATCH] [Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894) --- tests/core/block/e2e/test_correctness.py | 70 ++++ tests/core/utils.py | 17 +- .../output_processor/test_multi_step.py | 270 +++++++++++++ tests/spec_decode/e2e/test_correctness.py | 127 +++++- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/spec_decode/test_spec_decode_worker.py | 32 +- tests/spec_decode/utils.py | 2 +- vllm/core/block/block_table.py | 1 - vllm/core/scheduler.py | 8 +- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 371 +++--------------- vllm/engine/output_processor/__init__.py | 0 vllm/engine/output_processor/interfaces.py | 69 ++++ vllm/engine/output_processor/multi_step.py | 126 ++++++ vllm/engine/output_processor/single_step.py | 276 +++++++++++++ vllm/engine/output_processor/stop_checker.py | 101 +++++ vllm/engine/output_processor/util.py | 16 + vllm/executor/cpu_executor.py | 3 +- vllm/executor/executor_base.py | 5 +- vllm/executor/gpu_executor.py | 79 +++- vllm/executor/neuron_executor.py | 5 +- vllm/executor/ray_gpu_executor.py | 3 +- vllm/sequence.py | 13 + vllm/spec_decode/batch_expansion.py | 23 +- vllm/spec_decode/multi_step_worker.py | 16 +- vllm/spec_decode/spec_decode_worker.py | 44 ++- vllm/spec_decode/util.py | 26 ++ vllm/worker/cpu_worker.py | 8 +- vllm/worker/neuron_worker.py | 11 +- vllm/worker/worker.py | 11 +- vllm/worker/worker_base.py | 13 +- 31 files changed, 1347 insertions(+), 407 deletions(-) create mode 100644 tests/engine/output_processor/test_multi_step.py create mode 100644 vllm/engine/output_processor/__init__.py create mode 100644 vllm/engine/output_processor/interfaces.py create mode 100644 vllm/engine/output_processor/multi_step.py create mode 100644 vllm/engine/output_processor/single_step.py create mode 100644 vllm/engine/output_processor/stop_checker.py 
create mode 100644 vllm/engine/output_processor/util.py diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 94b65401e1dd4..0ee78a9b0a8ea 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, + ]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [ + { + "use_v2_block_manager": False, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "use_v2_block_manager": True, + "num_lookahead_slots": 0, + }, + { + "use_v2_block_manager": True, + "num_lookahead_slots": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_chunked_prefill_block_manager_v2(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify that chunked prefill works with BlockManagerV2, with and without + lookahead scheduling. 
+ """ + output_len = 32 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with BlockManagerV1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with BlockManagerV2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/utils.py b/tests/core/utils.py index fbbdb07cb8e6e..22c1d3826dff4 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from vllm import SamplingParams from vllm.lora.request import LoRARequest @@ -31,14 +31,17 @@ def create_dummy_prompt( def create_seq_group( - seq_prompt_len=1024, - seq_output_lens=(128, ), - request_id='0', - seq_id_start=0, -) -> SequenceGroup: + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: assert len(seq_output_lens) > 0 + if sampling_params is None: + sampling_params = SamplingParams() + prompt_token_ids = [0] * seq_prompt_len seqs = [] @@ -60,7 +63,7 @@ def create_seq_group( seq_group = SequenceGroup( request_id=request_id, 
seqs=seqs, - sampling_params=SamplingParams(), + sampling_params=sampling_params, arrival_time=time.time(), ) diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py new file mode 100644 index 0000000000000..6da3da091db78 --- /dev/null +++ b/tests/engine/output_processor/test_multi_step.py @@ -0,0 +1,270 @@ +import random +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from tests.core.utils import create_seq_group +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + + +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [1, 12]) +@pytest.mark.skip_global_cleanup +def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): + """Verify multi-step decoding appends token ids correctly. + + We append token ids and verify all the token ids were appended correctly. + Note that ignore_eos=True. 
+ """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=1024, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=seq_output_len + + num_new_tokens, + ignore_eos=True), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids + output_processor.process_outputs(seq_group, outputs) + assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) +@pytest.mark.parametrize("max_tokens", [128 + 3]) +@pytest.mark.skip_global_cleanup +def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, max_tokens: int): + """Verify tokens after max_tokens are dropped and not appended to the + sequence. 
+ """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=max_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go over max tokens in len. + assert seq.get_len() == seq_prompt_len + max_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """Verify the eos token id is included in the sequence, but subsequent + tokens are dropped (not appended to sequence). 
+ """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go beyond provided eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) + + # Expect the correct tokens were appended. 
+ expected_appended_tokens = new_token_ids[:eos_index + 1] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """When sampling parameters dictate that we should ignore the eos token id, + ensure all token ids are appended even if the eos token id is emitted. + """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, + ignore_eos=True, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to go beyond eos. 
+ assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - + seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +def mock_tokenizer(eos_token_id=1000): + tokenizer = MagicMock(spec=PreTrainedTokenizer) + tokenizer.eos_token_id = eos_token_id + return tokenizer diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index b5a6fcb7900a3..a8ebd66841eb2 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,4 +1,8 @@ +from itertools import cycle +from typing import List, Tuple + import pytest +from transformers import AutoTokenizer from vllm import SamplingParams @@ -7,18 +11,47 @@ "common_llm_kwargs", [{ # Use a small model for a fast test. - "model": "facebook/opt-125m", - "speculative_model": "facebook/opt-125m", - "num_speculative_tokens": 5, + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip real loading for fast test. + "load_format": "dummy", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, # Required for spec decode. "use_v2_block_manager": True }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 1, + }, + { + # No spec decode. + }, + ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [1]) +# NOTE: We should run more permutations of this test (more BS, more seeds). But +# because our spec decode generates gibberish token ids, the likelihood of +# emitting an invalid token combination is nontrivial. 
This causes divergence in +# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf- +# start" bytes are emitted. @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_config(test_llm_generator): - output_len = 1024 +def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): + """Run generation with speculative decoding on a batch. Verify the engine + generates the correct number of tokens (via ignore_eos=True), and that the + detokenization matches HF transformers. + """ + output_len = 32 temperature = 0.0 prompts = [ @@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator): "The future of AI is", ] + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) + + batch_tokens, batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + # Expect a generation for each prompt in the batch. + assert len(batch_token_ids) == len(prompts) + + # Expect each generation to have expected number of tokens (note + # ignore_eos=True). + assert all(len(token_ids) == output_len for token_ids in batch_token_ids) + + # Expect detokenized string to match. + tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") + for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): + expected_tokens = tok.decode(actual_token_ids) + print(f"{actual_token_ids=}") + assert actual_tokens.strip() == expected_tokens.strip() + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Skip real loading for fast test. + "load_format": "dummy", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "worker_use_ray": True, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail(test_llm_generator): + """Verify that speculative decoding with Ray fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + sampling_params = SamplingParams( max_tokens=output_len, ignore_eos=True, temperature=temperature, ) - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for GPU backend"): - get_token_ids_from_llm_generator(test_llm_generator, prompts, - sampling_params) + with pytest.raises(AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] del llm - return token_ids + return tokens, token_ids diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index f4d44108b47c2..d6edbab579afd 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -125,7 +125,7 @@ def test_same_output_for_single_step(): zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), ) + **single_step_execute_model_data.to_dict(), )[0] actual_token_ids = [ output.samples[0].output_token for output in 
actual_output @@ -219,7 +219,7 @@ def test_same_output_for_multi_step(): continuations=continuations, final_seq_lens=final_seq_lens)) - single_step_output.append( + single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) # Append output tokens to new sequence data. diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 47aff8f575413..0a3110775e2d6 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,6 +6,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): execute_model_data, _, _ = create_batch(batch_size, k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 @@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) seen_contexts = [] @@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] 
exception_secret = 'artifical stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 args, _ = rejection_sampler.call_args_list[0] @@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) expected_output = create_sampler_output_list( rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) @@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): mock_rejsample_metrics) output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + 
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" @@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) @@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" @@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 
edba4c226b289..d04b6029493f4 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -212,7 +212,7 @@ def create_sampler_output_list( SequenceOutput( output_token=token_id, parent_seq_id=seq_ids[seq_index], - logprobs={token_id: 0}, + logprobs={token_id: Logprob(0)}, ) ], prompt_logprobs=None, diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index ba061bbc4fbcb..560267e55ea3a 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -104,7 +104,6 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated - assert token_ids, "can't append empty token ids" self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 18ddcd1d6d466..4198550621030 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -762,9 +762,7 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, ) def _schedule_chunked_prefill(self): @@ -850,9 +848,7 @@ def _schedule_chunked_prefill(self): blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, ) def _schedule(self) -> SchedulerOutputs: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1dbf58904541c..27192449bf15a 100644 --- a/vllm/engine/async_llm_engine.py +++ 
b/vllm/engine/async_llm_engine.py @@ -217,7 +217,9 @@ async def step_async(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + return self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) async def encode_request_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 563694946d16e..c3de57e249ff8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,5 @@ import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, List, Optional, Type, Union from transformers import PreTrainedTokenizer @@ -11,6 +11,10 @@ from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger @@ -18,8 +22,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) + SequenceGroup) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -187,6 +190,21 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + # Create sequence output processor, e.g. for beam search or + # speculative decoding. 
+ self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + self.get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + self.get_tokenizer_for_seq, + ), + )) + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -412,240 +430,32 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. 
- highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score - - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and seq_group.sampling_params.detokenize: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - parent_child_dict[sample.parent_seq_id].append(sample) - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. - for child_sample in child_samples[:-1]: - new_child_seq_id = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. 
- last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) - else: - new_char_count = 0 - self._check_stop(seq, new_char_count, seq_group.sampling_params) - - # Non-beam search case - if not seq_group.sampling_params.use_beam_search: - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - return - - # Beam search case - # Select the child sequences to keep in the sequence group. - selected_child_seqs = [] - unselected_child_seqs = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty - - # Select the newly finished sequences with the highest scores - # to replace existing finished sequences. - # Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] - all_finished_seqs = existing_finished_seqs + new_finished_seqs - # Sort the finished sequences by their scores. 
- all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - for seq, parent, is_new in all_finished_seqs[:beam_width]: - if is_new: - # A newly generated child sequence finishes and has a high - # score, so we will add it into the sequence group. - selected_child_seqs.append((seq, parent)) - for seq, parent, is_new in all_finished_seqs[beam_width:]: - if is_new: - # A newly generated child sequence finishes but has a low - # score, so we will not add it into the sequence group. - # Additionally, if this sequence is a continuation of a - # parent sequence, we will need remove the parent sequence - # from the sequence group. - unselected_child_seqs.append((seq, parent)) - else: - # An existing finished sequence has a low score, so we will - # remove it from the sequence group. - seq_group.remove(seq.seq_id) - - # select the top beam_width sequences from the running - # sequences for the next iteration to continue the beam - # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] - # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - - # Check if we can stop the beam search. - if len(running_child_seqs) == 0: - # No running sequences, stop the beam search. - stop_beam_search = True - elif len(all_finished_seqs) < beam_width: - # Not enough finished sequences, continue the beam search. 
- stop_beam_search = False - else: - # Check the early stopping criteria - best_running_seq = running_child_seqs[0][0] - current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) - - if stop_beam_search: - # Stop the beam search and remove all the running sequences from - # the sequence group. - unselected_child_seqs.extend(running_child_seqs) - else: - # Continue the beam search and select the top beam_width sequences - # to continue the beam search. - selected_child_seqs.extend(running_child_seqs[:beam_width]) - # The remaining running sequences will not be used in the next - # iteration. Again, if these sequences are continuations of - # parent sequences, we will need to remove the parent sequences - # from the sequence group. - unselected_child_seqs.extend(running_child_seqs[beam_width:]) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in selected_child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - for seq, parent in selected_child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - - # Remove the unselected parent sequences from the sequence group and - # free their memory in block manager. 
- for seq, parent in unselected_child_seqs: - if seq is parent: - # Remove the parent sequence if it is not selected for next - # iteration - seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) - def _process_model_outputs( - self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + self, output: List[SamplerOutput], + scheduled_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: + """Apply the model output to the sequences in the scheduled seq groups. + + Returns RequestOutputs that can be returned to the client. + """ + now = time.time() + + # Organize outputs by [sequence group][step] instead of + # [step][sequence group]. + output_by_sequence_group = create_output_by_sequence_group( + sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + # Update the scheduled sequence groups with the model outputs. - scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, + output_by_sequence_group): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) # If uncomputed tokens > 0, it means prefill is chunked. # We don't need to process outputs in that case. if seq_group.get_num_uncomputed_tokens() == 0: - self._process_sequence_group_outputs(seq_group, outputs) + self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. 
self.scheduler.free_finished_seq_groups() @@ -657,13 +467,9 @@ def _process_model_outputs( seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - for seq_group in scheduler_outputs.ignored_seq_groups: + for seq_group in ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - - # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs def step(self) -> List[RequestOutput]: @@ -721,13 +527,23 @@ def step(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): output = self.model_executor.execute_model( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots) else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) + + # Log stats. + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs)) + + return request_outputs def do_log_stats(self) -> None: """Forced log when no requests active.""" @@ -807,87 +623,6 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences. 
- - new_char_count is the number of chars added to the - sequence's output text for the newly generated token - """ - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - # Check if a stop token was encountered. - # This assumes a single token produced per step. - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - # Remove last token - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if any stop strings are matched. - stop_str = self._check_stop_strings(seq, new_char_count, - sampling_params) - if stop_str is not None: - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - @staticmethod - def _check_stop_strings(seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> Optional[str]: - """Check if any stop strings are matched and truncate sequence - output text accordingly. - - Returns the stop string if matched or else None. - """ - if not new_char_count: - return None - - for stop_str in sampling_params.stop: - stop_string_len = len(stop_str) - # Avoid searching already-searched text. 
class SequenceGroupOutputProcessor(ABC):
    """Contract for applying freshly generated token ids to a sequence group.

    Implementations take care of detokenization, stop checking, and
    freeing/forking sequences with the scheduler.

    The logic lives outside the LLMEngine only to keep that class small; it is
    tightly coupled with the engine and should be read as an extension of it.
    Two implementations exist: single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
    ):
        """Build the output processor matching the scheduler configuration.

        A multi-step output processor is returned when
        ``scheduler_config.num_lookahead_slots`` is non-zero; otherwise a
        single-step output processor is returned.
        """
        wants_multi_step = scheduler_config.num_lookahead_slots != 0

        if wants_multi_step:
            # Imported lazily to avoid an import cycle.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

        # Imported lazily to avoid an import cycle.
        from vllm.engine.output_processor.single_step import (
            SingleStepOutputProcessor)
        return SingleStepOutputProcessor(
            scheduler_config,
            detokenizer,
            scheduler,
            seq_counter,
            stop_checker,
        )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Apply new token ids to ``sequence_group``.

        Covers detokenization, stop checking, and freeing/forking sequences
        in the scheduler.
        """
class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            detokenizer: Used to detokenize a sequence in place after each
                appended token.
            scheduler: Used to free a sequence once it is finished.
            seq_counter: Stored for parity with the single-step processor;
                not read by this class.
            get_tokenizer_for_seq: Returns the tokenizer for a sequence; its
                ``eos_token_id`` is used to truncate tokens after EOS.
            stop_checker: Decides whether a sequence should be stopped.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization,
        including freeing finished sequences. It also handles cases where there
        are tokens emitted after the EOS token.

        Args:
            sequence_group: Group holding exactly one running sequence.
            outputs: One entry per step; only the first sample of each step
                is consumed.
        """
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert seqs, "expected running sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        # Since there's only one sequence per sequence group, we can take the
        # first sample of every step.
        samples = [step_output.samples[0] for step_output in outputs]

        # -1 means the output token is not valid (eg. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples if sample.output_token != -1
        ]
        assert valid_samples, "expected at least one valid output token"

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the valid tokens to ``seq`` one at a time.

        The new tokens are first truncated to the ``max_tokens`` budget and to
        the first EOS token (unless ``ignore_eos`` is set). Each remaining
        token is then appended, detokenized (if enabled) and checked against
        the stop conditions; processing ends as soon as the sequence is
        finished, at which point it is freed in the scheduler.
        """
        output_token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary. A negative remainder is the
        # number of surplus tokens, so it doubles as the slice end.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for position, token_id in enumerate(output_token_ids):
                if token_id == eos_token_id:
                    output_token_ids = output_token_ids[:position + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id in output_token_ids:
            seq.append_token_id(
                token_id=output_token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={output_token_id: Logprob(0.0)},
            )

            new_char_count = 0
            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            scheduler_config: Its ``max_model_len`` is read by the beam search
                early-stopping check.
            detokenizer: Detokenizes prompt logprobs and sequences in place.
            scheduler: Used to fork and free sequences.
            seq_counter: Source of ids for newly forked child sequences.
            stop_checker: Decides whether a sequence should be stopped.
        """
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        ``outputs`` must contain exactly one step; anything else trips the
        assertion below.
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0])

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of samples to ``seq_group``.

        Steps, in order:
          1. Detokenize and store prompt logprobs (when present and
             detokenization is enabled).
          2. Attach every sample to its parent sequence, forking the parent
             for all but the last sample; parents without samples are
             aborted, removed from the group and freed.
          3. Detokenize each resulting sequence and run the stop checker.
          4. Without beam search: add and fork the new children, then free
             finished parents. With beam search: keep the best ``best_of``
             finished and running candidates, and remove/free the rest.
        """

        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        # Map each running parent's seq_id to the samples generated from it.
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        # Detokenize (if enabled) and evaluate stop conditions for every
        # sequence that received a token in this step.
        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
                                                  seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need to remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search can stop for this sequence group.

        Returns True immediately when ``early_stopping`` is True. Otherwise
        returns whether the score of ``current_worst_seq`` is at least the
        highest score ``best_running_seq`` is considered able to attain:
          * ``early_stopping`` False: the running sequence's current score.
          * ``early_stopping`` "never": with a positive length penalty, the
            score evaluated at ``max_possible_length``; otherwise the current
            score.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                # NOTE(review): this takes the max (not min) of
                # prompt_len + max_tokens and max_model_len - confirm intended.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score
class StopChecker:
    """Stop-condition logic factored out of the LLMEngine.

    Covers the conditions under which a sequence is marked finished: the eos
    token was emitted, a stop token or stop string was emitted, the max model
    len was exceeded, or max_tokens has been consumed.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence],
                                                 PreTrainedTokenizer]):
        self.max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> None:
        """Mark ``seq`` as finished if it meets a stop condition.

        ``new_char_count`` is the number of chars the newly generated token
        added to the sequence's output text.
        """
        # Until min_tokens have been generated, none of the checks below
        # apply.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # EOS token (unless the request asked to ignore it).
        eos_enabled = not sampling_params.ignore_eos
        if eos_enabled and seq.get_last_token_id() == seq.eos_token_id:
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Stop token. This assumes a single token produced per step.
        newest_token_id = seq.get_last_token_id()
        if newest_token_id in sampling_params.stop_token_ids:
            drop_token_text = (new_char_count and
                               not sampling_params.include_stop_str_in_output)
            if drop_token_text:
                # Strip the text contributed by the stop token.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = newest_token_id
            return

        # Stop strings.
        matched_stop = self._check_stop_strings(seq, new_char_count,
                                                sampling_params)
        if matched_stop is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = matched_stop
            return

        # Length caps: max_model_len is checked first, then max_tokens.
        if (seq.get_len() > self.max_model_len
                or seq.get_output_len() == sampling_params.max_tokens):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Look for a stop string in the newly produced text.

        On a match, ``seq.output_text`` is truncated to the beginning of the
        stop string, or to its end when ``include_stop_str_in_output`` is set.

        Returns the matched stop string, or None when nothing matched.
        """
        if not new_char_count:
            return None

        for candidate in sampling_params.stop:
            candidate_len = len(candidate)
            # Only scan the tail that could contain a new match; earlier text
            # has already been searched.
            search_start = -new_char_count - candidate_len
            cut_at = seq.output_text.find(candidate, search_start)
            if cut_at == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Keep the stop string itself in the output.
                cut_at += candidate_len
                if cut_at >= len(seq.output_text):
                    # The stop string already ends the text; nothing to cut.
                    return candidate

            seq.output_text = seq.output_text[:cut_at]
            return candidate

        return None
+ seq.output_text = seq.output_text[:stop_index] + return stop_str + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000000000..5fbb09a857a46 --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,16 @@ +from typing import List + +from vllm.sequence import SamplerOutput + + +def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], + num_seq_groups: int): + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group = [[] for _ in range(num_seq_groups)] + for step in sampler_outputs: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + return output_by_sequence_group diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 426e2c41d8427..f925a6fc93dcd 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -74,7 +74,8 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 8cc04c5299ca1..1839b5603ff3e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -72,8 +72,9 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + blocks_to_copy: Dict[int, List[int]], + 
num_lookahead_slots: int) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences.""" raise NotImplementedError @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3a9537effe6d9..6e4a765e2ffd5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -13,13 +13,17 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for GPU backend" + """Initialize the worker and load the model. - # Instantiate the worker and load the model to GPU. - self._init_worker() + If speculative decoding is enabled, we instead create the speculative + worker. + """ + if self.speculative_config is None: + self._init_non_spec_worker() + else: + self._init_spec_worker() - def _init_worker(self): + def _init_non_spec_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker @@ -46,6 +50,57 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() + def _init_spec_worker(self): + """Initialize a SpecDecodeWorker, using a draft model for proposals. 
+ """ + assert self.speculative_config is not None + + from vllm.spec_decode.multi_step_worker import MultiStepWorker + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.worker.worker import Worker + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + target_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + draft_worker = MultiStepWorker( + model_config=self.speculative_config.draft_model_config, + parallel_config=self.speculative_config.draft_parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + spec_decode_worker = SpecDecodeWorker.from_workers( + proposer_worker=draft_worker, scorer_worker=target_worker) + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = spec_decode_worker + + # Load model handled in spec decode worker. + self.driver_worker.init_device() + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. 
@@ -63,16 +118,20 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, + ) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + num_lookahead_slots=num_lookahead_slots, ) return output diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 273b17a927efd..7cc187e297c9f 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -48,10 +48,13 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} and blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") + assert num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4065c0868d79a..5f859fdc9c078 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -242,7 +242,8 @@ def execute_model(self, 
seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int = 0) -> SamplerOutput: all_outputs = self._run_workers( "execute_model", driver_kwargs={ diff --git a/vllm/sequence.py b/vllm/sequence.py index dcde81df19923..92362a9a5d2a3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -693,3 +693,16 @@ def __len__(self): def __eq__(self, other: object): return isinstance(other, self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e0b75837e8a39..88af1dd360155 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,10 +6,10 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, - sampler_output_to_torch, +from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, + nvtx_range, sampler_output_to_torch, split_batch_by_proposal_len) -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -31,7 +31,8 @@ class 
BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. """ - def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): self._scorer_worker = scorer_worker self._device = device self._vocab_size = vocab_size @@ -83,7 +84,9 @@ def score_proposals( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + ) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] all_tokens, all_probs = self._contract_batch( original_bs=len(seq_group_metadata_list), @@ -142,6 +145,16 @@ def _contract_batch(self, original_bs: int, This maps the scores of speculative tokens back to their original sequences. """ + + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + maybe_mock_device_tensors( + sampler_output=target_sampler_output, + batch_size=len(non_spec_indices) + num_scoring_tokens, + vocab_size=self._vocab_size, + device=self._device, + ) + (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 73b6e201c67a9..ce63c329a40aa 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,7 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.spec_decode.util import (maybe_mock_device_tensors, + sampler_output_to_torch) from vllm.worker.worker import Worker @@ -69,6 +70,9 @@ def execute_model_multi_step( 
blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] self._append_new_tokens(model_output, copied_seq_group_metadata_list) @@ -341,6 +345,16 @@ def _merge_outputs( sampler_output = maybe_sampler_output + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + for step_output in sampler_output: + maybe_mock_device_tensors( + sampler_output=step_output, + batch_size=len(proposal_lens), + vocab_size=self._vocab_size, + device=self._device, + ) + proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 885bf537568e3..be3af7be93864 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,8 +3,9 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -13,8 +14,9 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase + +logger = init_logger(__name__) class SpecDecodeWorker(LoraNotSupportedWorkerBase): @@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): More info here 
https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. """ + @classmethod + def from_workers(cls, proposer_worker: MultiStepWorker, + scorer_worker: WorkerBase) -> "SpecDecodeWorker": + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + # TODO(cade) disable strict mode for speedup. + rejection_sampler=RejectionSampler(strict_mode=True), + ) + def __init__( self, proposer_worker: MultiStepWorker, - scorer_worker: Worker, + scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, ): @@ -87,6 +99,10 @@ def init_device(self) -> None: self.scorer_worker.init_device() self.proposer_worker.init_device() + # NOTE(cade): load_model is not part of the WorkerBase interface. + self.scorer_worker.load_model() + self.proposer_worker.load_model() + self._metrics.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( @@ -131,7 +147,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]], blocks_to_swap_out: Optional[Dict[int, int]], blocks_to_copy: Optional[Dict[int, List[int]]], - num_spec_tokens: int, + num_lookahead_slots: int, ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ @@ -140,9 +156,11 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}") + # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. 
- if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: + if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: return self._run_no_spec( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -155,7 +173,7 @@ def execute_model( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - k=num_spec_tokens, + k=num_lookahead_slots, ) @nvtx_range("spec_decode_worker._run_no_spec") @@ -170,20 +188,24 @@ def _run_no_spec( proposer and scorer model so that the KV cache is consistent between the two. """ + logger.info("run proposer worker no spec") self.proposer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + ) + logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert len(sampler_output) == 1 + sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. @@ -209,11 +231,13 @@ def _run_speculative_decoding_step( sequence. """ + logger.info("get spec proposals") # Generate proposals using draft worker. 
proposals = self.proposer_worker.get_spec_proposals( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) + logger.info("score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, @@ -223,9 +247,11 @@ def _run_speculative_decoding_step( proposals, ) + logger.info("verify proposals") accepted_token_ids = self._verify_tokens(seq_group_metadata_list, proposal_scores, proposals, k) + logger.info("create output list") return self._create_output_sampler_list(seq_group_metadata_list, accepted_token_ids, k) @@ -311,7 +337,7 @@ def _create_output_sampler_list( parent_seq_id=seq_id, output_token=token_id, # TODO Add verifier logprobs. - logprobs={token_id: 0.0}, + logprobs={token_id: Logprob(0.0)}, ) ], prompt_logprobs=None, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 406568a4bc08c..eb6d4ca1da8e6 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -82,6 +82,32 @@ def sampler_output_to_torch( return sampled_token_ids, sampled_token_probs +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). + return + + # Softmax to ensure valid probs. 
+ sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, + size=(batch_size, ), + dtype=torch.long, + device=device) + + @contextmanager def nvtx_range(msg, *args, **kwargs): """ diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 6610b9c4be876..afc4a1e1f4630 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -251,7 +251,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) @@ -274,11 +274,13 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.cpu_cache) - return output + + # CPU worker only supports single-step execution. + return [output] def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 2f22f82c045db..142c6c97f5194 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional, Tuple +from typing import List, Tuple import torch import torch.distributed @@ -73,15 +73,18 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: num_seq_groups = len(seq_group_metadata_list) # If there is no input, we don't need to execute the model. 
if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list) - return output + + # Neuron worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6a79285f60579..e2b47530d41e4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -210,7 +210,9 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + num_lookahead_slots: int = 0, + ) -> List[SamplerOutput]: + if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) @@ -235,11 +237,14 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache) - return output + + # Worker only supports single-step execution. Wrap the output in a list + # to conform to interface. 
+ return [output] def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d8c9febb11584..a92f5aea76059 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -40,12 +40,13 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + def execute_model( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, + int], + blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" raise NotImplementedError @abstractmethod