From bbf1484820a37b46e29538645bb0942fb1a320d7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 7 Jun 2024 16:52:21 +0000 Subject: [PATCH 01/38] Integrate Typical Acceptance Sampler into spec decode worker --- tests/spec_decode/e2e/conftest.py | 66 +++++- .../e2e/test_multistep_correctness.py | 16 +- vllm/config.py | 16 ++ vllm/engine/arg_utils.py | 32 ++- vllm/engine/metrics.py | 5 +- .../layers/rejection_sampler.py | 174 +-------------- .../layers/spec_decode_base_sampler.py | 210 ++++++++++++++++++ .../layers/typical_acceptance_sampler.py | 190 ++++++++++++++++ vllm/spec_decode/metrics.py | 23 +- vllm/spec_decode/spec_decode_worker.py | 93 ++++++-- 10 files changed, 629 insertions(+), 196 deletions(-) create mode 100644 vllm/model_executor/layers/spec_decode_base_sampler.py create mode 100644 vllm/model_executor/layers/typical_acceptance_sampler.py diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d060e265848a..9476c2c503166 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -241,7 +241,71 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero. """ - temperature = 0.0 + temperature = 0.8 + + prompts = [ + #"Hello, my name is", + #"The president of the United States is", + #"The capital of France is", + #"The future of AI is", + #"San Francisco is know for its", + #"Facebook was created in 2004 by", + #"Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + ) + start = time.time() + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + end = time.time() + + print('time for spec decode ' + str(end - start)) + + start = time.time() + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + end = time.time() + print('time for base line ' + str(end - start)) + + #assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) + + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if True: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + #print(f'{i=} {baseline_token_ids=}') + #print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids + + + +def compare_sampler_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. 
+ """ + temperature = 1.0 prompts = [ "Hello, my name is", diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94d71fb012727..1670435d2dad6 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -213,15 +213,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Print spec metrics. "disable_log_stats": False, + + #"tensor_parallel_size" : 1, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ # Try two different tiny base models. # Note that one is equal to the draft model, another isn't. - { - "model": "JackFram/llama-68m", - }, + #{ + # "model": "JackFram/llama-68m", + #}, { "model": "JackFram/llama-160m", }, @@ -230,7 +232,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, + "num_speculative_tokens": 3, }, ]) @pytest.mark.parametrize( @@ -239,9 +241,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Use small output len for fast test. 256, ]) -@pytest.mark.parametrize("batch_size", [64]) +@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_1( baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify greedy equality on a tiny model and large batch size. @@ -250,7 +252,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( test_llm_generator, batch_size, max_output_len=output_len, - force_output_len=True) + force_output_len=False) @pytest.mark.parametrize( diff --git a/vllm/config.py b/vllm/config.py index eee62d2683835..84fa2efaa9de8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -761,6 +761,9 @@ def maybe_create_spec_config( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], + draft_token_sampling_method: Optional[str], + typical_acceptance_sampler_posterior_threshold: Optional[float], + typical_acceptance_sampler_posterior_alpha: Optional[float], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -889,6 +892,11 @@ def maybe_create_spec_config( speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, + draft_token_sampling_method=draft_token_sampling_method, + typical_acceptance_sampler_posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=\ + typical_acceptance_sampler_posterior_alpha, ) @staticmethod @@ -960,6 +968,9 @@ def __init__( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], + draft_token_sampling_method: Optional[str], + typical_acceptance_sampler_posterior_threshold: Optional[float], + typical_acceptance_sampler_posterior_alpha: Optional[float], ): """Create a SpeculativeConfig object. 
@@ -981,6 +992,11 @@ def __init__( speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 + self.draft_token_sampling_method = draft_token_sampling_method + self.typical_acceptance_sampler_posterior_threshold = \ + typical_acceptance_sampler_posterior_threshold + self.typical_acceptance_sampler_posterior_alpha = \ + typical_acceptance_sampler_posterior_alpha self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b315d4d2ece29..217fa572ca419 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -96,7 +96,9 @@ class EngineArgs: speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - + speculative_draft_token_sampling_method: str = 'typical_acceptance_sampler' + typical_acceptance_sampler_posterior_threshold: float = 0.09 + typical_acceptance_sampler_posterior_alpha: float = 0.3 qlora_adapter_name_or_path: Optional[str] = None def __post_init__(self): @@ -556,6 +558,30 @@ def add_cli_args( help='Min size of window for ngram prompt lookup in speculative ' 'decoding.') + parser.add_argument( + '--speculative-draft-token-sampling-method', + type=str, + default=EngineArgs.speculative_draft_token_sampling_method, + choices=['rejection_sampler', 'typical_acceptance_sampler'], + help='The draft token sampler to use for speculative decoding.') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-threshold', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_threshold, + help='A threshold value that sets a lower bound on the ' + 'posterior probability of a token for it to be accepted. This ' + 'parameter is used by the TypicalAcceptanceSampler for making ' + 'sampling decisions during speculative decoding.') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-alpha', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_alpha, + help='A scaling factor for the entropy-based threshold for token ' + 'acceptance in the TypicalAcceptanceSampler. 
Typically defaults ' + 'to sqrt of --typical-acceptance-sampler-posterior-threshold.') + parser.add_argument('--model-loader-extra-config', type=nullable_str, default=EngineArgs.model_loader_extra_config, @@ -654,6 +680,10 @@ def create_engine_config(self, ) -> EngineConfig: use_v2_block_manager=self.use_v2_block_manager, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, + draft_token_sampling_method="rejection_sampler", + #draft_token_sampling_method="typical_acceptance_sampler", + typical_acceptance_sampler_posterior_threshold=0.09, + typical_acceptance_sampler_posterior_alpha=0.3, ) scheduler_config = SchedulerConfig( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ae7ae144bc04f..d20ba890bcc31 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -372,4 +372,7 @@ def _format_spec_decode_metrics_str( f"Number of speculative tokens: {metrics.num_spec_tokens}, " f"Number of accepted tokens: {metrics.accepted_tokens}, " f"Number of draft tokens tokens: {metrics.draft_tokens}, " - f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") + f"Number of emitted tokens tokens: {metrics.emitted_tokens}, " + f"Total Time: {metrics.total_time}, " + f"Total Calls: {metrics.total_calls}, " + f"Avg Time: {metrics.avg_time}.") diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 1f2ab7e2870ca..653e4c18c0972 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -3,10 +3,12 @@ import torch import torch.jit +import time import torch.nn as nn +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler -class RejectionSampler(nn.Module): +class RejectionSampler(SpecDecodeBaseSampler, nn.Module): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -22,39 +24,11 @@ def __init__(self, Require when bonus tokens will cause corrupt KV cache for proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. + during sampling. This catches correctness issues but adds + nontrivial latency. """ - super().__init__() - self._disable_bonus_tokens = disable_bonus_tokens - self._strict_mode = strict_mode - - # NOTE: A "bonus token" is accepted iff all proposal tokens are - # accepted. There is always only one possible bonus token. We store this - # value in a variable for readability. - self._num_bonus_tokens = 1 - - self.num_accepted_tokens: Optional[torch.Tensor] = None - self.num_emitted_tokens: Optional[torch.Tensor] = None - self.num_draft_tokens: int = 0 - - def init_gpu_tensors(self, rank: int) -> None: - assert self.num_accepted_tokens is None - device = f"cuda:{rank}" - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - @property - def probs_dtype(self): - return torch.float32 - - @property - def token_id_dtype(self): - return torch.int64 + SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) + nn.Module.__init__(self) def forward( self, @@ -99,16 +73,10 @@ def forward( """ # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. 
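+ # Lightweight timing instrumentation: total_time/total_calls are
+ # accumulated here and surfaced later through the spec decode metrics
+ # (avg_time).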
+ start = time.time() if self._strict_mode: - self._raise_if_incorrect_shape(target_probs, bonus_token_ids, + self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_inconsistent_device(target_probs, bonus_token_ids, - draft_probs, draft_token_ids) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], - bonus_token_ids, - draft_token_ids) accepted, recovered_token_ids = self._batch_modified_rejection_sampling( target_probs, @@ -122,7 +90,9 @@ def forward( draft_token_ids, bonus_token_ids, ) - + end = time.time() + self.total_time += (end - start) + self.total_calls += 1 return output_token_ids def _batch_modified_rejection_sampling( @@ -272,126 +242,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - def _create_output( - self, - accepted: torch.Tensor, # [batch_size, k] - recovered_token_ids: torch.Tensor, # [batch_size, k] - draft_token_ids: torch.Tensor, # [batch_size, k] - bonus_token_ids: torch.Tensor, # [batch_size] - ) -> torch.Tensor: - """Format output. Returns a matrix of token ids. When - a token is rejected via rejection sampling, all subsequent - token ids are set to -1 for the sequence. - - shape = [batch_size, k + num_bonus_tokens] - """ - bonus_token_ids = bonus_token_ids.squeeze() - batch_size, k = recovered_token_ids.shape - - # Determine the index of the first False value for each row. - limits = (accepted == 0).max(1).indices - limits[~(accepted == 0).any(1)] = k - - # Create masks using the indices. - indices = torch.arange(k, device=accepted.device).unsqueeze(0) - accepted_mask = indices < limits.unsqueeze(1) - after_false_mask = indices == limits.unsqueeze(1) - - # Create an extended output tensor - output_with_bonus_tokens = -torch.ones( - (batch_size, k + self._num_bonus_tokens), - dtype=self.token_id_dtype, - device=accepted.device) - output = output_with_bonus_tokens[:, :k] - - # Fill in the first k columns of the output tensor using masks and data - # tensors. - output[:, :k] = torch.where(accepted_mask, draft_token_ids, - -torch.ones_like(draft_token_ids)) - - # Fill the last column. - # We check output directly as accepted may have True values inconsistent - # with causal acceptance. - output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, - bonus_token_ids, -1) - - # We disable bonus tokens because it causes corrupt KV cache for - # proposal methods that require KV cache. We can fix it by "prefilling" - # the bonus token in the proposer. The following issue tracks the fix. - # https://github.com/vllm-project/vllm/issues/4212 - if self._disable_bonus_tokens: - output_with_bonus_tokens[:, -1] = -1 - - # Fill the recovered token ids. 
- output.mul_(~after_false_mask).add_( - recovered_token_ids.mul(after_false_mask)) - - self.num_accepted_tokens += accepted.sum() - self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() - self.num_draft_tokens += batch_size * k - - return output_with_bonus_tokens - - def _raise_if_incorrect_shape( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape - bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape - draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape - draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape - - assert draft_batch_size == target_batch_size - assert num_draft_probs == num_target_probs - assert (draft_vocab_size == target_vocab_size - ), f"{draft_vocab_size=} {target_vocab_size=}" - - assert draft_token_ids_batch_size == draft_batch_size - assert num_draft_token_ids == num_draft_probs - - assert bonus_batch_size == target_batch_size - assert num_bonus_tokens == self._num_bonus_tokens - - def _raise_if_incorrect_dtype( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert all(probs.dtype == self.probs_dtype - for probs in [target_probs, draft_probs]) - assert all(token_ids.dtype == self.token_id_dtype - for token_ids in [bonus_token_ids, draft_token_ids]) - - def _raise_if_inconsistent_device( - self, - target_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - ] - assert all([devices[0] == device for device in devices]) - - def _raise_if_out_of_bounds_vocab( - self, - vocab_size: int, - bonus_token_ids: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> None: - assert torch.all(bonus_token_ids < vocab_size) - assert torch.all(bonus_token_ids >= 0) - assert torch.all(draft_token_ids < vocab_size) - assert torch.all(draft_token_ids >= 0) - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000000000..fa3bcac9bb233 --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,210 @@ +from typing import Optional + +import torch +import torch.jit + + +class SpecDecodeBaseSampler(): + """Base class for samplers used for Speculative Decoding verification + step. + """ + + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): + """Base class constructor. + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. 
+ self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + self.total_time: float = 0 + self.total_calls: float = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. + + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] + """ + batch_size, k = substitute_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze() + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 + + # Fill the recovered token ids. 
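+ # The first rejected position (where after_false_mask is True) receives
+ # the substitute token id; every position after it stays -1.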
+ output.mul_(~after_false_mask).add_( + substitute_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_incorrect_dtype(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_inconsistent_device(target_probs, draft_token_ids, + bonus_token_ids, draft_probs) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + draft_token_ids, bonus_token_ids) + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + if t is not None + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000000000..aded4505b0749 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,190 @@ +import torch +import torch.jit +import torch.nn as nn +import time + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) + + +class 
TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 + """ + def __init__( + self, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, + posterior_threshold: float = 0.09, + posterior_alpha: float = 0.3, + ): + """Create a Typical Acceptance Sampler. + + Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + on the posterior probability of a token in target model for it + to be accepted. Default is 0.09 + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. Typically defaults to + sqrt of posterior_threshold and is set to 0.3. + """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha + super().__init__() + SpecDecodeBaseSampler.__init__( + self, + disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) + nn.Module.__init__(self) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one token will be emitted. + + In the case where all draft tokens are accepted, the bonus token will be + accepted conditioned on self._disable_bonus_tokens being false. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
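+ # As in RejectionSampler.forward, wall-clock time and call count are
+ # accumulated for the spec decode metrics.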
+ start = time.time() + if self._strict_mode: + self._raise_if_incorrect_input(target_probs, draft_token_ids, + bonus_token_ids) + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) + recovered_token_ids = self._replacement_token_ids(target_probs) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) + #print('draft_token_ids ' + str(draft_token_ids)) + #print('output_token_ids ' + str(output_token_ids)) + #print('target_probs ' + str(target_probs)) + end = time.time() + self.total_time += (end - start) + self.total_calls += 1 + return output_token_ids + + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed + token ids. + + A draft token_id x_{n+k} is accepted if it satisifies the + following condition + + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. True indicates acceptance and false indicates + rejection. + + """ + device = target_probs.device + candidates_prob = torch.gather( + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1).to(device), ).squeeze(-1) + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + 1e-5), dim=-1) + #print('posterior_entropy ' + str(posterior_entropy)) + threshold = torch.minimum( + torch.ones_like(posterior_entropy, device=device) * self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + accepted_mask = candidates_prob > threshold + return accepted_mask + + def _replacement_token_ids(self, target_probs): + """ + Generate one replacement token ID for each sequence based on target + probabilities. The replacement token is used as the fallback option + if typical acceptance sampling does not accept any draft tokens for + that particular sequence. + + This method computes the token IDs to be replaced by selecting the + token with the highest probability for each sequence in the first + position. The rest of the output is filled with -1. 
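+ For example, with batch_size=2 and k=3 the returned tensor looks like
+ [[t0, -1, -1], [t1, -1, -1]], where t0 and t1 are the argmax token ids
+ of the first target position in each sequence (illustrative values).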
+ + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the replacement + token IDs. Only the first column is set, and the rest of the + columns are filled with -1. + """ + max_indices = torch.argmax(target_probs[:, 0, :], dim=1) + output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), + dtype=self.token_id_dtype, device=target_probs.device) + output[:, 0] = max_indices + return output diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index ab1d96c558de7..da7066fb493c2 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -4,7 +4,7 @@ import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler from vllm.utils import is_pin_memory_available @@ -41,6 +41,10 @@ class SpecDecodeWorkerMetrics: # The number of speculative tokens per sequence. num_spec_tokens: int + total_time: float + total_calls: int + avg_time: float + Timer = Callable[[], float] @@ -51,10 +55,10 @@ class AsyncMetricsCollector: """ def __init__(self, - rejection_sampler: RejectionSampler, + spec_decode_base_sampler: SpecDecodeBaseSampler, timer: Optional[Timer] = None, collect_interval_s: float = 5.0): - self._rejection_sampler = rejection_sampler + self.spec_decode_base_sampler = spec_decode_base_sampler self._timer = time.time if timer is None else timer self._rank: Optional[int] = None @@ -117,13 +121,13 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: with torch.cuda.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( - self._rejection_sampler.num_accepted_tokens, non_blocking=True) + self.spec_decode_base_sampler.num_accepted_tokens, non_blocking=True) self._aggregate_num_emitted_tokens.copy_( - self._rejection_sampler.num_emitted_tokens, non_blocking=True) + self.spec_decode_base_sampler.num_emitted_tokens, non_blocking=True) # Number of draft tokens is calculated on CPU, so no copy is # required. 
self._aggregate_num_draft_tokens = ( - self._rejection_sampler.num_draft_tokens) + self.spec_decode_base_sampler.num_draft_tokens) aggregate_metrics_ready = torch.cuda.Event() aggregate_metrics_ready.record(self._copy_stream) @@ -160,6 +164,10 @@ def _collect_rejsample_metrics( else: system_efficiency = float("nan") + #print('emitted_tokens ' + str(emitted_tokens)) + #print('accepted_tokens ' + str(accepted_tokens)) + #print('draft_tokens ' + str(draft_tokens)) + return SpecDecodeWorkerMetrics( num_spec_tokens=k, draft_acceptance_rate=draft_acceptance_rate, @@ -167,6 +175,9 @@ def _collect_rejsample_metrics( accepted_tokens=accepted_tokens, draft_tokens=draft_tokens, emitted_tokens=emitted_tokens, + total_time=self.spec_decode_base_sampler.total_time, + total_calls=self.spec_decode_base_sampler.total_calls, + avg_time= self.spec_decode_base_sampler.total_time * 1.0 / self.spec_decode_base_sampler.total_calls * 1.0 ) @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 150e8db0c8aad..40cf07d85834b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -6,6 +6,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer @@ -20,6 +21,7 @@ split_batch_by_proposal_len) from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +import time logger = init_logger(__name__) @@ -50,6 +52,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_kwargs=draft_worker_kwargs, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, + draft_token_sampling_method=speculative_config. + draft_token_sampling_method, + typical_acceptance_sampler_posterior_threshold=speculative_config. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=speculative_config. 
+ typical_acceptance_sampler_posterior_alpha ) return spec_decode_worker @@ -89,6 +97,9 @@ def create_worker( scorer_worker: WorkerBase, draft_worker_kwargs: Dict[str, Any], disable_by_batch_size: Optional[int], + draft_token_sampling_method: Optional[str] = "rejection_sampler", + typical_acceptance_sampler_posterior_threshold: Optional[float] = 0.09, + typical_acceptance_sampler_posterior_alpha: Optional[float] = 0.3, ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -107,19 +118,36 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) + + rejection_sampler: RejectionSampler = None + typical_acceptance_sampler: TypicalAcceptanceSampler = None + print('draft_token_sampling_method ' + str(draft_token_sampling_method)) + + if draft_token_sampling_method == "rejection_sampler": + rejection_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, ) + elif draft_token_sampling_method == "typical_acceptance_sampler": + typical_acceptance_sampler = TypicalAcceptanceSampler( + disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, + ) return SpecDecodeWorker( proposer_worker, scorer_worker, disable_by_batch_size=disable_by_batch_size, - rejection_sampler=RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens, )) + rejection_sampler=rejection_sampler, + typical_acceptance_sampler=typical_acceptance_sampler) + def __init__( self, proposer_worker: WorkerBase, scorer_worker: WorkerBase, - rejection_sampler: RejectionSampler, + rejection_sampler: Optional[RejectionSampler] = None, + typical_acceptance_sampler: Optional[TypicalAcceptanceSampler] = None, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, ): @@ -143,13 +171,21 @@ def __init__( self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler - - self._metrics = AsyncMetricsCollector( - rejection_sampler - ) if metrics_collector is None else metrics_collector - - self.probs_dtype = self.rejection_sampler.probs_dtype - self.token_id_dtype = self.rejection_sampler.token_id_dtype + self.typical_acceptance_sampler = typical_acceptance_sampler + if self.rejection_sampler is not None: + self._metrics = AsyncMetricsCollector( + self.rejection_sampler + ) if metrics_collector is None else metrics_collector + self.probs_dtype = self.rejection_sampler.probs_dtype + self.token_id_dtype = self.rejection_sampler.token_id_dtype + elif self.typical_acceptance_sampler is not None: + self._metrics = AsyncMetricsCollector( + self.typical_acceptance_sampler + ) if metrics_collector is None else metrics_collector + self.probs_dtype = self.typical_acceptance_sampler.probs_dtype + self.token_id_dtype = self.typical_acceptance_sampler.token_id_dtype + self._total_time_in_verify = 0 + self._num_calls = 0 # Lazy initiazliation. 
self.scorer: SpeculativeScorer @@ -167,7 +203,11 @@ def init_device(self) -> None: self.proposer_worker.load_model() self._metrics.init_gpu_tensors(self.rank) - self.rejection_sampler.init_gpu_tensors(self.rank) + if (self.rejection_sampler is not None): + self.rejection_sampler.init_gpu_tensors(self.rank) + if (self.typical_acceptance_sampler is not None): + self.typical_acceptance_sampler.init_gpu_tensors(self.rank) + self.scorer = BatchExpansionTop1Scorer( scorer_worker=self.scorer_worker, device=self.device, @@ -445,13 +485,30 @@ def _verify_tokens( # Get proposed tokens. proposal_token_ids = proposals.proposal_token_ids[spec_indices] - accepted_token_ids = self.rejection_sampler( - target_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_probs=proposal_probs, - draft_token_ids=proposal_token_ids, - ) - + if self.rejection_sampler is not None: + start_time = time.time() + accepted_token_ids = self.rejection_sampler( + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, + ) + end_time = time.time() + self._total_time_in_verify += (end_time - start_time) + self._num_calls += 1 + elif self.typical_acceptance_sampler is not None: + start_time = time.time() + accepted_token_ids = self.typical_acceptance_sampler( + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_token_ids=proposal_token_ids, + ) + end_time = time.time() + self._total_time_in_verify += (end_time - start_time) + self._num_calls += 1 + + #print('_total_time_in_verify ' + str(self._total_time_in_verify)) + #print('_num_calls ' + str(self._num_calls)) # Append output tokens from non-speculative sequences to # the accepted token ids tensor. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + From 3495673574d60ec6351d05f2a4cb14abfe2f7f6a Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sun, 9 Jun 2024 20:53:18 +0000 Subject: [PATCH 02/38] Fixing tests --- tests/spec_decode/e2e/conftest.py | 76 +------- .../e2e/test_multistep_correctness.py | 88 ++++++++- tests/spec_decode/test_dynamic_spec_decode.py | 12 +- tests/spec_decode/test_spec_decode_worker.py | 176 +++++++++++------- tests/spec_decode/test_utils.py | 23 ++- vllm/config.py | 1 - vllm/engine/arg_utils.py | 16 +- .../layers/typical_acceptance_sampler.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 73 ++++---- 9 files changed, 274 insertions(+), 194 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 9476c2c503166..ac086b2482c21 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -241,71 +241,7 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero. """ - temperature = 0.8 - - prompts = [ - #"Hello, my name is", - #"The president of the United States is", - #"The capital of France is", - #"The future of AI is", - #"San Francisco is know for its", - #"Facebook was created in 2004 by", - #"Curious George is a", - "Python 3.11 brings improvements to its", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - # If the test requires that we generated max_output_len tokens, then set the - # sampling params to ignore eos token. 
- ignore_eos = force_output_len - - sampling_params = SamplingParams( - max_tokens=max_output_len, - ignore_eos=ignore_eos, - temperature=temperature, - ) - start = time.time() - spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) - end = time.time() - - print('time for spec decode ' + str(end - start)) - - start = time.time() - (baseline_batch_tokens, - baseline_batch_token_ids) = get_output_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - end = time.time() - print('time for base line ' + str(end - start)) - - #assert len(baseline_batch_token_ids) == len(prompts) - assert len(spec_batch_token_ids) == len(prompts) - - for i, (baseline_token_ids, baseline_tokens, spec_token_ids, - spec_tokens) in enumerate( - zip(baseline_batch_token_ids, baseline_batch_tokens, - spec_batch_token_ids, spec_batch_tokens)): - if True: - print(f'{i=} {baseline_tokens=}') - print(f'{i=} {spec_tokens=}') - #print(f'{i=} {baseline_token_ids=}') - #print(f'{i=} {spec_token_ids=}') - assert baseline_token_ids == spec_token_ids - - - -def compare_sampler_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - print_tokens: bool = False): - """Helper method that compares the outputs of both the baseline LLM and - the test LLM. It asserts greedy equality, e.g. that the outputs are exactly - the same when temperature is zero. - """ - temperature = 1.0 + temperature = 0.0 prompts = [ "Hello, my name is", @@ -329,15 +265,20 @@ def compare_sampler_test(baseline_llm_generator, ignore_eos=ignore_eos, temperature=temperature, ) - + start = time.time() spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( test_llm_generator, prompts, sampling_params) + end = time.time() + print('time for spec decode ' + str(end - start)) + start = time.time() (baseline_batch_tokens, baseline_batch_token_ids) = get_output_from_llm_generator( baseline_llm_generator, prompts, sampling_params) + end = time.time() + print('time for base line ' + str(end - start)) - assert len(baseline_batch_token_ids) == len(prompts) + #assert len(baseline_batch_token_ids) == len(prompts) assert len(spec_batch_token_ids) == len(prompts) for i, (baseline_token_ids, baseline_tokens, spec_token_ids, @@ -351,7 +292,6 @@ def compare_sampler_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids - def wait_for_gpu_memory_to_clear(devices: List[int], threshold_bytes: int, timeout_s: float = 120) -> None: diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 1670435d2dad6..c25800b91ad56 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -177,6 +177,14 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize( @@ -221,9 +229,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( [ # Try two different tiny base models. # Note that one is equal to the draft model, another isn't. 
- #{ - # "model": "JackFram/llama-68m", - #}, + { + "model": "JackFram/llama-68m", + }, { "model": "JackFram/llama-160m", }, @@ -233,6 +241,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize( @@ -281,6 +297,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize("max_output_len", [ @@ -322,6 +346,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize("batch_size", [1]) @@ -366,9 +398,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) -@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( "output_len", [ @@ -413,6 +453,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize( @@ -467,6 +515,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize("batch_size", [2]) @@ -508,10 +564,20 @@ def test_spec_decode_different_block_size(baseline_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. "speculative_max_model_len": 32, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. 
+ "speculative_max_model_len": 32, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -556,6 +622,15 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, "speculative_disable_by_batch_size": 2, + "speculative_draft_token_sampling_method": "rejection_sampler" + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + "speculative_draft_token_sampling_method": ( + "typical_acceptance_sampler" + ) }, ]) @pytest.mark.parametrize("batch_size", [8]) @@ -591,9 +666,12 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": k, + "speculative_draft_token_sampling_method": method, } # Try a range of common k, as well as large speculation. for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] + # Try both methods of sampling in the verifier. + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 48fa862b2e41a..e5899fc9bb8f3 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -11,25 +11,27 @@ from vllm.spec_decode.top1_proposer import Top1Proposer from .utils import create_batch, mock_worker - +from .test_utils import mock_sampler_factory @pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('k', [1]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): +def test_disable_spec_tokens( + queue_size: int, batch_size: int, k: int, mock_sampler_factory): """Verify that speculative tokens are disabled when the batch size exceeds the threshold. 
""" disable_by_batch_size = 3 - draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - rejection_sampler=rejection_sampler, + rejection_sampler=mock_sampler_factory[0], + typical_acceptance_sampler=mock_sampler_factory[1], metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index ef9d32f73d668..17ec60ad8e117 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,6 +6,7 @@ import torch from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler from vllm.model_executor.utils import set_random_seed from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals @@ -16,22 +17,24 @@ split_num_cache_blocks_evenly) from .utils import create_batch, create_sampler_output_list, mock_worker - +from .test_utils import mock_sampler_factory @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_calls_draft_model(k: int, batch_size: int): +def test_correctly_calls_draft_model( + k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker calls the draft worker with correct inputs. Everything else is mocked out. """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) - + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -52,8 +55,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_calls_target_model(k: int, batch_size: int): +def test_correctly_calls_target_model( + k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. 
""" @@ -68,7 +74,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) worker.init_device() @@ -132,8 +140,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_calls_rejection_sampler(k: int, batch_size: int): +def test_correctly_calls_rejection_sampler( + k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. """ @@ -143,16 +154,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + rejection_sampler, typical_acceptance_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, target_worker, + rejection_sampler, + typical_acceptance_sampler, metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -198,16 +209,25 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] exception_secret = 'artificial stop' - rejection_sampler.side_effect = ValueError(exception_secret) + if rejection_sampler: + rejection_sampler.side_effect = ValueError(exception_secret) + else: + typical_acceptance_sampler.side_effect = ValueError( + exception_secret) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - - assert len(rejection_sampler.call_args_list) == 1 - _, kwargs = rejection_sampler.call_args_list[0] - actual = SimpleNamespace(**kwargs) + + if rejection_sampler: + assert len(rejection_sampler.call_args_list) == 1 + _, kwargs = rejection_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) + else: + assert len(typical_acceptance_sampler.call_args_list) == 1 + _, kwargs = typical_acceptance_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) @@ -215,13 +235,17 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) - assert torch.equal(actual.draft_probs, proposal_probs) + if rejection_sampler: + assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_formats_output(k: int, 
batch_size: int): +def test_correctly_formats_output( + k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker formats sampler output correctly. Everything else is mocked out. """ @@ -231,15 +255,15 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + rejection_sampler, typical_acceptance_sampler = mock_sampler_factory + worker = SpecDecodeWorker(draft_worker, target_worker, + rejection_sampler, + typical_acceptance_sampler, metrics_collector) worker.init_device() @@ -285,24 +309,25 @@ def test_correctly_formats_output(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] - rejection_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - rejection_sampler_output[i][ + sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - - rejection_sampler.return_value = rejection_sampler_output - + if rejection_sampler: + rejection_sampler.return_value = sampler_output + else: + typical_acceptance_sampler.return_value = sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) expected_output = create_sampler_output_list( - token_ids=rejection_sampler_output.transpose(0, 1), + token_ids=sampler_output.transpose(0, 1), probs=[None for _ in range(k + 1)], logprobs=[None for _ in range(k + 1)]) @@ -343,8 +368,11 @@ def test_correctly_formats_output(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('returns_metrics', [True, False]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): +def test_collects_metrics( + k: int, batch_size: int, returns_metrics: bool, mock_sampler_factory): """Verify SpecDecodeWorker collects metrics. 
""" vocab_size = 32_000 @@ -353,8 +381,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + rejection_sampler, typical_acceptance_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -362,6 +389,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + typical_acceptance_sampler, metrics_collector) worker.init_device() @@ -407,17 +435,19 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): target_worker.execute_model.return_value = [target_output[0]] - rejection_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - rejection_sampler_output[i][ + sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - - rejection_sampler.return_value = rejection_sampler_output + if rejection_sampler: + rejection_sampler.return_value = sampler_output + else: + typical_acceptance_sampler.return_value = sampler_output mock_rejsample_metrics = MagicMock( spec=SpecDecodeWorkerMetrics) if returns_metrics else None @@ -438,15 +468,15 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): @pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_k_equals_zero(k: int, batch_size: int): +def test_k_equals_zero(k: int, batch_size: int, mock_sampler_factory): """Verify that the SpecDecodeWorker calls the draft and target workers when k is zero. This happens during prefill. """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] @@ -456,7 +486,9 @@ def test_k_equals_zero(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -478,16 +510,16 @@ def test_k_equals_zero(k: int, batch_size: int): @pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('batch_size', [0]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_empty_input_batch(k: int, batch_size: int): +def test_empty_input_batch(k: int, batch_size: int, mock_sampler_factory): """Verify that the SpecDecodeWorker calls the draft and target workers when the input batch is empty. This can happen if the engine communicates to the workers information without scheduling a batch. 
""" draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] @@ -497,7 +529,9 @@ def test_empty_input_batch(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -516,19 +550,21 @@ def test_empty_input_batch(k: int, batch_size: int): draft_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req) - +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @pytest.mark.skip_global_cleanup -def test_init_device(): +def test_init_device(mock_sampler_factory): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + rejection_sampler, typical_acceptance_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + rejection_sampler, + typical_acceptance_sampler, metrics_collector) worker.init_device() @@ -538,21 +574,26 @@ def test_init_device(): target_worker.init_device.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once() - rejection_sampler.init_gpu_tensors.assert_called_once() + if rejection_sampler: + rejection_sampler.init_gpu_tensors.assert_called_once() + else: + typical_acceptance_sampler.init_gpu_tensors.assert_called_once() +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_initialize_cache(): +def test_initialize_cache(mock_sampler_factory): """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. 
""" draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} @@ -566,11 +607,14 @@ def test_initialize_cache(): @pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) +@pytest.mark.parametrize("mock_sampler_factory", + ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @pytest.mark.skip_global_cleanup def test_determine_num_available_blocks(available_gpu_blocks: int, available_cpu_blocks: int, target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): + draft_kv_size_bytes: int, + mock_sampler_factory): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. @@ -587,7 +631,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, target_cache_block_size_bytes) draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + mock_sampler_factory[0], + mock_sampler_factory[1], metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 6b6f35a1a1d05..004c1bd8f7e30 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -1,9 +1,12 @@ from unittest.mock import MagicMock import pytest - +import torch from vllm.sequence import SequenceGroupMetadata from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler + def test_get_all_seq_ids(): @@ -109,3 +112,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): assert filtered_groups == [] assert indices == [] + +@pytest.fixture +def mock_sampler_factory(request): + def create_samplers(value): + if value == "rejection_sampler": + sampler = MagicMock(spec=RejectionSampler) + sampler.token_id_dtype = torch.int64 + return sampler, None + elif value == "typical_acceptance_sampler": + sampler = MagicMock(spec=TypicalAcceptanceSampler) + sampler.token_id_dtype = torch.int64 + return None, sampler + else: + return None, None # Return None for both samplers if the value is not recognized + + value = request.param # Get the value passed to the fixture + return create_samplers(value) + diff --git a/vllm/config.py b/vllm/config.py index 84fa2efaa9de8..267ad2639b8ce 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -689,7 +689,6 @@ def __init__( self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill self.embedding_mode = embedding_mode - self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 
217fa572ca419..a5fe21ed78a59 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -96,9 +96,9 @@ class EngineArgs: speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - speculative_draft_token_sampling_method: str = 'typical_acceptance_sampler' - typical_acceptance_sampler_posterior_threshold: float = 0.09 - typical_acceptance_sampler_posterior_alpha: float = 0.3 + speculative_draft_token_sampling_method: str = 'rejection_sampler' + typical_acceptance_sampler_posterior_threshold: float = 0.49 + typical_acceptance_sampler_posterior_alpha: float = 0.7 qlora_adapter_name_or_path: Optional[str] = None def __post_init__(self): @@ -680,10 +680,12 @@ def create_engine_config(self, ) -> EngineConfig: use_v2_block_manager=self.use_v2_block_manager, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_sampling_method="rejection_sampler", - #draft_token_sampling_method="typical_acceptance_sampler", - typical_acceptance_sampler_posterior_threshold=0.09, - typical_acceptance_sampler_posterior_alpha=0.3, + draft_token_sampling_method=self. + speculative_draft_token_sampling_method, + typical_acceptance_sampler_posterior_threshold=self. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self. + typical_acceptance_sampler_posterior_alpha, ) scheduler_config = SchedulerConfig( diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index aded4505b0749..798ed23556bbe 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -36,6 +36,7 @@ def __init__( threshold in typical acceptance sampling. Typically defaults to sqrt of posterior_threshold and is set to 0.3. 
""" + print('Hello in tas') self._posterior_threshold = posterior_threshold self._posterior_alpha = posterior_alpha super().__init__() @@ -148,7 +149,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1).to(device), ).squeeze(-1) + index=draft_token_ids.unsqueeze(-1), ).squeeze(-1) posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) #print('posterior_entropy ' + str(posterior_entropy)) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 40cf07d85834b..09fa41022963e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -121,8 +121,6 @@ def create_worker( rejection_sampler: RejectionSampler = None typical_acceptance_sampler: TypicalAcceptanceSampler = None - print('draft_token_sampling_method ' + str(draft_token_sampling_method)) - if draft_token_sampling_method == "rejection_sampler": rejection_sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens, ) @@ -132,7 +130,7 @@ def create_worker( posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, - ) + ) return SpecDecodeWorker( proposer_worker, @@ -172,21 +170,13 @@ def __init__( self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler self.typical_acceptance_sampler = typical_acceptance_sampler - if self.rejection_sampler is not None: - self._metrics = AsyncMetricsCollector( - self.rejection_sampler - ) if metrics_collector is None else metrics_collector - self.probs_dtype = self.rejection_sampler.probs_dtype - self.token_id_dtype = self.rejection_sampler.token_id_dtype - elif self.typical_acceptance_sampler is not None: - self._metrics = AsyncMetricsCollector( - self.typical_acceptance_sampler - ) if metrics_collector is None else metrics_collector - self.probs_dtype = self.typical_acceptance_sampler.probs_dtype - self.token_id_dtype = self.typical_acceptance_sampler.token_id_dtype - self._total_time_in_verify = 0 - self._num_calls = 0 - + sampler = self.rejection_sampler or self.typical_acceptance_sampler + assert sampler is not None, "Sampler is Not set, which is not expected." + self._metrics = AsyncMetricsCollector( + sampler + ) if metrics_collector is None else metrics_collector + self.probs_dtype = sampler.probs_dtype + self.token_id_dtype = sampler.token_id_dtype # Lazy initiazliation. self.scorer: SpeculativeScorer @@ -217,7 +207,7 @@ def init_device(self) -> None: def load_model(self, *args, **kwargs): pass - + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -485,27 +475,6 @@ def _verify_tokens( # Get proposed tokens. 
proposal_token_ids = proposals.proposal_token_ids[spec_indices] - if self.rejection_sampler is not None: - start_time = time.time() - accepted_token_ids = self.rejection_sampler( - target_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_probs=proposal_probs, - draft_token_ids=proposal_token_ids, - ) - end_time = time.time() - self._total_time_in_verify += (end_time - start_time) - self._num_calls += 1 - elif self.typical_acceptance_sampler is not None: - start_time = time.time() - accepted_token_ids = self.typical_acceptance_sampler( - target_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_token_ids=proposal_token_ids, - ) - end_time = time.time() - self._total_time_in_verify += (end_time - start_time) - self._num_calls += 1 #print('_total_time_in_verify ' + str(self._total_time_in_verify)) #print('_num_calls ' + str(self._num_calls)) @@ -514,16 +483,38 @@ def _verify_tokens( non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + 1).clone() non_spec_token_ids[:, 1:] = -1 + accepted_token_ids = self._get_accepted_token_ids( + proposal_verifier_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, + proposal_probs=proposal_probs, proposal_token_ids=proposal_token_ids) accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) logprobs = proposal_scores.logprobs - # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() return accepted_token_ids, logprobs + def _get_accepted_token_ids(self, proposal_verifier_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + proposal_probs: torch.Tensor, + proposal_token_ids: torch.Tensor): + if self.rejection_sampler is not None: + accepted_token_ids = self.rejection_sampler( + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, + ) + else: + assert self.typical_acceptance_sampler is not None + accepted_token_ids = self.typical_acceptance_sampler( + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_token_ids=proposal_token_ids, + ) + return accepted_token_ids + def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], From 26c7c57f5058a7d704c3dc561c8c2fe901d38fa7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 01:39:09 +0000 Subject: [PATCH 03/38] adding missing commit --- tests/spec_decode/e2e/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index ac086b2482c21..e7537e5b97b4c 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -265,18 +265,12 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ignore_eos=ignore_eos, temperature=temperature, ) - start = time.time() spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( test_llm_generator, prompts, sampling_params) - end = time.time() - print('time for spec decode ' + str(end - start)) - start = time.time() (baseline_batch_tokens, baseline_batch_token_ids) = get_output_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - end = time.time() - print('time for base line ' + str(end - start)) #assert len(baseline_batch_token_ids) == len(prompts) assert len(spec_batch_token_ids) == len(prompts) From 090f0bfc0f093df87b8e7a178822e283af242286 Mon Sep 17 00:00:00 2001 From: Sourashis Roy 
Date: Mon, 10 Jun 2024 16:22:03 +0000 Subject: [PATCH 04/38] reverting changes to conftest --- tests/spec_decode/e2e/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index e7537e5b97b4c..61f61ac8ccea4 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -272,7 +272,7 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, baseline_batch_token_ids) = get_output_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - #assert len(baseline_batch_token_ids) == len(prompts) + assert len(baseline_batch_token_ids) == len(prompts) assert len(spec_batch_token_ids) == len(prompts) for i, (baseline_token_ids, baseline_tokens, spec_token_ids, From 733cc6edf72e98848a16aaeb7289d02593152d44 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 16:23:14 +0000 Subject: [PATCH 05/38] reverting changes to conftest --- tests/spec_decode/e2e/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 61f61ac8ccea4..1d060e265848a 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -265,6 +265,7 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, ignore_eos=ignore_eos, temperature=temperature, ) + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( test_llm_generator, prompts, sampling_params) @@ -286,6 +287,7 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + def wait_for_gpu_memory_to_clear(devices: List[int], threshold_bytes: int, timeout_s: float = 120) -> None: From acf8d2c22435d48902d82371d38430fbc99e4f44 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 16:26:27 +0000 Subject: [PATCH 06/38] Dummy commit --- tests/spec_decode/e2e/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d060e265848a..d53513ab6806e 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -12,7 +12,6 @@ if (not is_hip()): from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit) - from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine From 2010b35e28762a1e9923495248e5aedd487e19f0 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 16:39:50 +0000 Subject: [PATCH 07/38] Revert unnecessary commits --- tests/spec_decode/e2e/conftest.py | 1 + tests/spec_decode/e2e/test_multistep_correctness.py | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index d53513ab6806e..1d060e265848a 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -12,6 +12,7 @@ if (not is_hip()): from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit) + from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index c25800b91ad56..2ae06d9b8d2d0 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ 
b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -221,8 +221,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Print spec metrics. "disable_log_stats": False, - - #"tensor_parallel_size" : 1, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -257,9 +255,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Use small output len for fast test. 256, ]) -@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("batch_size", [64]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_1( +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify greedy equality on a tiny model and large batch size. @@ -268,7 +266,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_1( test_llm_generator, batch_size, max_output_len=output_len, - force_output_len=False) + force_output_len=True) @pytest.mark.parametrize( @@ -408,7 +406,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( ) }, ]) -@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize( "output_len", [ @@ -564,6 +562,7 @@ def test_spec_decode_different_block_size(baseline_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. "speculative_max_model_len": 32, From dea6fbd6c7d07f622a4789509c2cb71b18c5d5bb Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 19:45:51 +0000 Subject: [PATCH 08/38] Pass only one sampler which can either be the RejectionSampler of the TypicalAcceptanceSampler --- vllm/spec_decode/spec_decode_worker.py | 31 ++++++++++---------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index d619ddb7c5b74..e58535cc8fef6 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -5,6 +5,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler from vllm.sequence import (ExecuteModelRequest, SamplerOutput, @@ -120,13 +121,12 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) - rejection_sampler: RejectionSampler = None - typical_acceptance_sampler: TypicalAcceptanceSampler = None + sampler: SpecDecodeBaseSampler = None if draft_token_sampling_method == "rejection_sampler": - rejection_sampler = RejectionSampler( + sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens, ) elif draft_token_sampling_method == "typical_acceptance_sampler": - typical_acceptance_sampler = TypicalAcceptanceSampler( + sampler = TypicalAcceptanceSampler( disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, @@ -137,16 +137,14 @@ def create_worker( proposer_worker, scorer_worker, disable_by_batch_size=disable_by_batch_size, - rejection_sampler=rejection_sampler, - typical_acceptance_sampler=typical_acceptance_sampler) + 
sampler=sampler) def __init__( self, proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, - rejection_sampler: Optional[RejectionSampler] = None, - typical_acceptance_sampler: Optional[TypicalAcceptanceSampler] = None, + sampler: SpecDecodeBaseSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, ): @@ -169,9 +167,7 @@ def __init__( self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") - self.rejection_sampler = rejection_sampler - self.typical_acceptance_sampler = typical_acceptance_sampler - sampler = self.rejection_sampler or self.typical_acceptance_sampler + self.verification_sampler = sampler assert sampler is not None, "Sampler is Not set, which is not expected." self._metrics = AsyncMetricsCollector( sampler @@ -194,10 +190,7 @@ def init_device(self) -> None: self.proposer_worker.load_model() self._metrics.init_gpu_tensors(self.rank) - if (self.rejection_sampler is not None): - self.rejection_sampler.init_gpu_tensors(self.rank) - if (self.typical_acceptance_sampler is not None): - self.typical_acceptance_sampler.init_gpu_tensors(self.rank) + self.verification_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( scorer_worker=self.scorer_worker, @@ -500,16 +493,16 @@ def _get_accepted_token_ids(self, proposal_verifier_probs: torch.Tensor, bonus_token_ids: torch.Tensor, proposal_probs: torch.Tensor, proposal_token_ids: torch.Tensor): - if self.rejection_sampler is not None: - accepted_token_ids = self.rejection_sampler( + if isinstance(self.verification_sampler, RejectionSampler): + accepted_token_ids = self.verification_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, ) else: - assert self.typical_acceptance_sampler is not None - accepted_token_ids = self.typical_acceptance_sampler( + assert isinstance(self.verification_sampler, TypicalAcceptanceSampler) + accepted_token_ids = self.verification_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_token_ids=proposal_token_ids, From c3383dbd757d51a86c75084c29a6a7b88633e46d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Mon, 10 Jun 2024 23:41:27 +0000 Subject: [PATCH 09/38] Fix test scripture --- tests/spec_decode/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 004c1bd8f7e30..16b34cef7e903 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -119,13 +119,13 @@ def create_samplers(value): if value == "rejection_sampler": sampler = MagicMock(spec=RejectionSampler) sampler.token_id_dtype = torch.int64 - return sampler, None + return sampler elif value == "typical_acceptance_sampler": sampler = MagicMock(spec=TypicalAcceptanceSampler) sampler.token_id_dtype = torch.int64 - return None, sampler + return sampler else: - return None, None # Return None for both samplers if the value is not recognized + return None # Return None for both samplers if the value is not recognized value = request.param # Get the value passed to the fixture return create_samplers(value) From b15abba01ab3d8661a4f85faa822dd8e5e6b683f Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 00:20:02 +0000 Subject: [PATCH 10/38] Fix tests --- tests/spec_decode/test_spec_decode_worker.py | 82 +++++++------------- 
1 file changed, 27 insertions(+), 55 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 17ec60ad8e117..690cc9f4d4d23 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -33,8 +33,7 @@ def test_correctly_calls_draft_model( target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], metrics_collector) + mock_sampler_factory, metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -65,8 +64,6 @@ def test_correctly_calls_target_model( """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' @@ -75,8 +72,7 @@ def test_correctly_calls_target_model( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], + mock_sampler_factory, metrics_collector) worker.init_device() @@ -143,7 +139,7 @@ def test_correctly_calls_target_model( @pytest.mark.parametrize("mock_sampler_factory", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_calls_rejection_sampler( +def test_correctly_calls_verification_sampler( k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. 
@@ -154,7 +150,7 @@ def test_correctly_calls_rejection_sampler( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler, typical_acceptance_sampler = mock_sampler_factory + verification_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -162,8 +158,7 @@ def test_correctly_calls_rejection_sampler( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - rejection_sampler, - typical_acceptance_sampler, metrics_collector) + verification_sampler, metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -209,25 +204,17 @@ def test_correctly_calls_rejection_sampler( target_worker.execute_model.return_value = [target_output[0]] exception_secret = 'artificial stop' - if rejection_sampler: - rejection_sampler.side_effect = ValueError(exception_secret) - else: - typical_acceptance_sampler.side_effect = ValueError( - exception_secret) + + verification_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - - if rejection_sampler: - assert len(rejection_sampler.call_args_list) == 1 - _, kwargs = rejection_sampler.call_args_list[0] - actual = SimpleNamespace(**kwargs) - else: - assert len(typical_acceptance_sampler.call_args_list) == 1 - _, kwargs = typical_acceptance_sampler.call_args_list[0] - actual = SimpleNamespace(**kwargs) + + assert len(verification_sampler.call_args_list) == 1 + _, kwargs = verification_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) @@ -235,7 +222,7 @@ def test_correctly_calls_rejection_sampler( actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) - if rejection_sampler: + if isinstance(verification_sampler, RejectionSampler): assert torch.equal(actual.draft_probs, proposal_probs) @@ -260,10 +247,9 @@ def test_correctly_formats_output( target_worker.device = 'cuda' set_random_seed(1) - rejection_sampler, typical_acceptance_sampler = mock_sampler_factory + verification_sampler = mock_sampler_factory worker = SpecDecodeWorker(draft_worker, target_worker, - rejection_sampler, - typical_acceptance_sampler, + verification_sampler, metrics_collector) worker.init_device() @@ -318,10 +304,8 @@ def test_correctly_formats_output( minimum_accepted_tokens = 1 sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - if rejection_sampler: - rejection_sampler.return_value = sampler_output - else: - typical_acceptance_sampler.return_value = sampler_output + + verification_sampler.return_value = sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) @@ -381,15 +365,15 @@ def test_collects_metrics( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler, typical_acceptance_sampler = mock_sampler_factory + verification_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - worker = 
SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - typical_acceptance_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, + verification_sampler, metrics_collector) worker.init_device() @@ -444,10 +428,7 @@ def test_collects_metrics( minimum_accepted_tokens = 1 sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - if rejection_sampler: - rejection_sampler.return_value = sampler_output - else: - typical_acceptance_sampler.return_value = sampler_output + verification_sampler.return_value = sampler_output mock_rejsample_metrics = MagicMock( spec=SpecDecodeWorkerMetrics) if returns_metrics else None @@ -487,8 +468,7 @@ def test_k_equals_zero(k: int, batch_size: int, mock_sampler_factory): set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], + mock_sampler_factory, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -530,8 +510,7 @@ def test_empty_input_batch(k: int, batch_size: int, mock_sampler_factory): set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], + mock_sampler_factory, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -559,12 +538,11 @@ def test_init_device(mock_sampler_factory): """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - rejection_sampler, typical_acceptance_sampler = mock_sampler_factory + verification_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - rejection_sampler, - typical_acceptance_sampler, + verification_sampler, metrics_collector) worker.init_device() @@ -574,11 +552,7 @@ def test_init_device(mock_sampler_factory): target_worker.init_device.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once() - if rejection_sampler: - rejection_sampler.init_gpu_tensors.assert_called_once() - else: - typical_acceptance_sampler.init_gpu_tensors.assert_called_once() - + verification_sampler.init_gpu_tensors.assert_called_once() @pytest.mark.parametrize("mock_sampler_factory", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @@ -592,8 +566,7 @@ def test_initialize_cache(mock_sampler_factory): metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], + mock_sampler_factory, metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} @@ -632,8 +605,7 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory[0], - mock_sampler_factory[1], + mock_sampler_factory, metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() From 6ca731cb11be2f7b9eb6162fd537122daf33365e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 00:26:06 +0000 Subject: [PATCH 11/38] Fix tests --- tests/spec_decode/test_spec_decode_worker.py | 32 +++++++++----------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 690cc9f4d4d23..73df70fa8586f 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ 
b/tests/spec_decode/test_spec_decode_worker.py @@ -295,23 +295,23 @@ def test_correctly_formats_output( target_worker.execute_model.return_value = [target_output[0]] - sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + verification_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - sampler_output[i][ + verification_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - verification_sampler.return_value = sampler_output + verification_sampler.return_value = verification_sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) expected_output = create_sampler_output_list( - token_ids=sampler_output.transpose(0, 1), + token_ids=verification_sampler_output.transpose(0, 1), probs=[None for _ in range(k + 1)], logprobs=[None for _ in range(k + 1)]) @@ -419,16 +419,16 @@ def test_collects_metrics( target_worker.execute_model.return_value = [target_output[0]] - sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + verification_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - sampler_output[i][ + verification_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - verification_sampler.return_value = sampler_output + verification_sampler.return_value = verification_sampler_output mock_rejsample_metrics = MagicMock( spec=SpecDecodeWorkerMetrics) if returns_metrics else None @@ -594,8 +594,6 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) target_worker.determine_num_available_blocks.return_value = ( From 483c671cc84a5cea770a64b66283366ba21af434 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 00:32:02 +0000 Subject: [PATCH 12/38] Pass only 1 verification_sampler which can either be rejectionSampler of TypicalAcceptanceSampler --- tests/spec_decode/test_dynamic_spec_decode.py | 4 +--- vllm/spec_decode/spec_decode_worker.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index d3a68c1a0dbae..a5d1d49518242 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -3,7 +3,6 @@ import pytest import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker @@ -30,8 +29,7 @@ def test_disable_spec_tokens( metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - rejection_sampler=mock_sampler_factory[0], - typical_acceptance_sampler=mock_sampler_factory[1], + verification_sampler=mock_sampler_factory, metrics_collector=metrics_collector, 
disable_by_batch_size=disable_by_batch_size)
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index e58535cc8fef6..78ed8397b255c 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -144,7 +144,7 @@ def __init__(
         self,
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
-        sampler: SpecDecodeBaseSampler,
+        verification_sampler: SpecDecodeBaseSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
     ):
@@ -167,13 +167,16 @@ def __init__(
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
-        self.verification_sampler = sampler
-        assert sampler is not None, "Sampler is Not set, which is not expected."
+        self.verification_sampler = verification_sampler
+        assert self.verification_sampler is not None, (
+            "Sampler is not set, "
+            "which is not expected."
+        )
         self._metrics = AsyncMetricsCollector(
-            sampler
+            self.verification_sampler
         ) if metrics_collector is None else metrics_collector
-        self.probs_dtype = sampler.probs_dtype
-        self.token_id_dtype = sampler.token_id_dtype
+        self.probs_dtype = self.verification_sampler.probs_dtype
+        self.token_id_dtype = self.verification_sampler.token_id_dtype
 
         # Lazy initiazliation.
         self.scorer: SpeculativeScorer

From 2c6d06c47aaeb4e20331863edf7dd91fed35d7ba Mon Sep 17 00:00:00 2001
From: Sourashis Roy
Date: Tue, 11 Jun 2024 00:39:25 +0000
Subject: [PATCH 13/38] Update metrics.py to take the base sampler class

---
 vllm/spec_decode/metrics.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index da7066fb493c2..1df2ad391562e 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -41,11 +41,6 @@ class SpecDecodeWorkerMetrics:
     # The number of speculative tokens per sequence.
num_spec_tokens: int - total_time: float - total_calls: int - avg_time: float - - Timer = Callable[[], float] @@ -164,10 +159,6 @@ def _collect_rejsample_metrics( else: system_efficiency = float("nan") - #print('emitted_tokens ' + str(emitted_tokens)) - #print('accepted_tokens ' + str(accepted_tokens)) - #print('draft_tokens ' + str(draft_tokens)) - return SpecDecodeWorkerMetrics( num_spec_tokens=k, draft_acceptance_rate=draft_acceptance_rate, @@ -175,9 +166,6 @@ def _collect_rejsample_metrics( accepted_tokens=accepted_tokens, draft_tokens=draft_tokens, emitted_tokens=emitted_tokens, - total_time=self.spec_decode_base_sampler.total_time, - total_calls=self.spec_decode_base_sampler.total_calls, - avg_time= self.spec_decode_base_sampler.total_time * 1.0 / self.spec_decode_base_sampler.total_calls * 1.0 ) @staticmethod From 027b4855494f178fa4745d3e84fed815b7b935a8 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 06:34:52 +0000 Subject: [PATCH 14/38] Fix tests and comments --- .../e2e/test_multistep_correctness.py | 120 +++++------------- tests/spec_decode/test_dynamic_spec_decode.py | 2 +- tests/spec_decode/test_spec_decode_worker.py | 58 ++++----- vllm/config.py | 15 ++- vllm/engine/arg_utils.py | 15 ++- vllm/engine/metrics.py | 5 +- vllm/spec_decode/metrics.py | 10 +- vllm/spec_decode/spec_decode_worker.py | 42 +++--- 8 files changed, 115 insertions(+), 152 deletions(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 2ae06d9b8d2d0..a844d3fee36bf 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -9,11 +9,15 @@ Since speculative decoding with rejection sampling guarantees that the output distribution matches the target model's output distribution (up to hardware numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. This gives us good coverage of temp=0. +equality. This gives us good coverage of temp=0. At temp=0, the +TypicalAcceptanceSampler ensures that only the tokens with the highest +probability in the target distribution are accepted. Therefore, we can +expect greedy equality for the TypicalAcceptanceSampler at temp=0. For temp>0, we rely on unit tests on the rejection sampler to verify that the output distribution is the same with spec decode vs. no spec decode (this would -be prohibitively expensive to run with a real model). +be prohibitively expensive to run with a real model). Similary, for the +TypicalAcceptance sampler, we rely on unit tests to validate temp>0 test cases. 
NOTE: Speculative decoding's distribution equality requires that the measured distributions of the target model and proposal model be deterministic given the @@ -177,15 +181,9 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -239,15 +237,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -295,15 +287,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("max_output_len", [ 256, @@ -344,15 +330,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( @@ -396,15 +376,9 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize( @@ -451,15 +425,9 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", 
"typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -513,15 +481,9 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( @@ -566,18 +528,9 @@ def test_spec_decode_different_block_size(baseline_llm_generator, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. "speculative_max_model_len": 32, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - # Artificially limit the draft model max model len; this forces vLLM - # to skip speculation once the sequences grow beyond 32-k tokens. - "speculative_max_model_len": 32, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -621,16 +574,9 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, "speculative_disable_by_batch_size": 2, - "speculative_draft_token_sampling_method": "rejection_sampler" - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_disable_by_batch_size": 2, - "speculative_draft_token_sampling_method": ( - "typical_acceptance_sampler" - ) - }, + "speculative_draft_token_sampling_method": method + } + for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("output_len", [10]) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index a5d1d49518242..5ee2480ac14df 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -29,7 +29,7 @@ def test_disable_spec_tokens( metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - verification_sampler=mock_sampler_factory, + spec_decode_sampler=mock_sampler_factory, metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 73df70fa8586f..4e8c75b999b76 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -139,7 +139,7 @@ def test_correctly_calls_target_model( @pytest.mark.parametrize("mock_sampler_factory", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_correctly_calls_verification_sampler( +def test_correctly_calls_spec_decode_sampler( k: int, batch_size: int, mock_sampler_factory): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. 
@@ -150,7 +150,7 @@ def test_correctly_calls_verification_sampler( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - verification_sampler = mock_sampler_factory + spec_decode_base_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -158,7 +158,7 @@ def test_correctly_calls_verification_sampler( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - verification_sampler, metrics_collector) + spec_decode_base_sampler, metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -205,15 +205,15 @@ def test_correctly_calls_verification_sampler( exception_secret = 'artificial stop' - verification_sampler.side_effect = ValueError(exception_secret) + spec_decode_base_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - assert len(verification_sampler.call_args_list) == 1 - _, kwargs = verification_sampler.call_args_list[0] + assert len(spec_decode_base_sampler.call_args_list) == 1 + _, kwargs = spec_decode_base_sampler.call_args_list[0] actual = SimpleNamespace(**kwargs) assert torch.equal(actual.bonus_token_ids, @@ -222,7 +222,7 @@ def test_correctly_calls_verification_sampler( actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) - if isinstance(verification_sampler, RejectionSampler): + if isinstance(spec_decode_base_sampler, RejectionSampler): assert torch.equal(actual.draft_probs, proposal_probs) @@ -247,9 +247,9 @@ def test_correctly_formats_output( target_worker.device = 'cuda' set_random_seed(1) - verification_sampler = mock_sampler_factory + spec_decode_base_sampler = mock_sampler_factory worker = SpecDecodeWorker(draft_worker, target_worker, - verification_sampler, + spec_decode_base_sampler, metrics_collector) worker.init_device() @@ -295,23 +295,23 @@ def test_correctly_formats_output( target_worker.execute_model.return_value = [target_output[0]] - verification_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + spec_decode_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - verification_sampler_output[i][ + spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - verification_sampler.return_value = verification_sampler_output + spec_decode_base_sampler.return_value = spec_decode_sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) expected_output = create_sampler_output_list( - token_ids=verification_sampler_output.transpose(0, 1), + token_ids=spec_decode_sampler_output.transpose(0, 1), probs=[None for _ in range(k + 1)], logprobs=[None for _ in range(k + 1)]) @@ -365,7 +365,7 @@ def test_collects_metrics( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - verification_sampler = mock_sampler_factory + spec_decode_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 
'cuda' target_worker.device = 'cuda' @@ -373,7 +373,7 @@ def test_collects_metrics( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - verification_sampler, + spec_decode_sampler, metrics_collector) worker.init_device() @@ -419,16 +419,16 @@ def test_collects_metrics( target_worker.execute_model.return_value = [target_output[0]] - verification_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + spec_decode_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - verification_sampler_output[i][ + spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - verification_sampler.return_value = verification_sampler_output + spec_decode_sampler.return_value = spec_decode_sampler_output mock_rejsample_metrics = MagicMock( spec=SpecDecodeWorkerMetrics) if returns_metrics else None @@ -538,11 +538,11 @@ def test_init_device(mock_sampler_factory): """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - verification_sampler = mock_sampler_factory + spec_decode_sampler = mock_sampler_factory metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - verification_sampler, + spec_decode_sampler, metrics_collector) worker.init_device() @@ -552,7 +552,7 @@ def test_init_device(mock_sampler_factory): target_worker.init_device.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once() - verification_sampler.init_gpu_tensors.assert_called_once() + spec_decode_sampler.init_gpu_tensors.assert_called_once() @pytest.mark.parametrize("mock_sampler_factory", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) diff --git a/vllm/config.py b/vllm/config.py index d84fcff008e62..31a4cb48a0d5d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -803,7 +803,20 @@ def maybe_create_spec_config( window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token window, if provided. - + draft_token_sampling_method (Optional[str]): The sampling method + to use for accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. + typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in target model for it to be accepted. + This threshold is only used when we use a + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b9d9c102a6497..db14626d0691e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -564,16 +564,21 @@ def add_cli_args( type=str, default=EngineArgs.speculative_draft_token_sampling_method, choices=['rejection_sampler', 'typical_acceptance_sampler'], - help='The draft token sampler to use for speculative decoding.') + help='Specify the draft token sampling method for speculative decoding. 
' + 'Two types of samplers are supported: ' + '1) RejectionSampler which does not allow changing the ' + 'acceptance rate of draft tokens, ' + '2) TypicalAccpetanceSampler which is configurable, allowing for a higher ' + 'acceptance rate at the cost of lower quality, and vice versa.') parser.add_argument( '--typical-acceptance-sampler-posterior-threshold', type=float, default=EngineArgs.typical_acceptance_sampler_posterior_threshold, - help='A threshold value that sets a lower bound on the ' - 'posterior probability of a token for it to be accepted. This ' - 'parameter is used by the TypicalAcceptanceSampler for making ' - 'sampling decisions during speculative decoding.') + help='Set the lower bound threshold for the posterior ' + 'probability of a token to be accepted. This threshold is ' + 'used by the TypicalAcceptanceSampler to make sampling decisions ' + 'during speculative decoding.') parser.add_argument( '--typical-acceptance-sampler-posterior-alpha', diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index d20ba890bcc31..ffc9eb5db6a48 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -372,7 +372,4 @@ def _format_spec_decode_metrics_str( f"Number of speculative tokens: {metrics.num_spec_tokens}, " f"Number of accepted tokens: {metrics.accepted_tokens}, " f"Number of draft tokens tokens: {metrics.draft_tokens}, " - f"Number of emitted tokens tokens: {metrics.emitted_tokens}, " - f"Total Time: {metrics.total_time}, " - f"Total Calls: {metrics.total_calls}, " - f"Avg Time: {metrics.avg_time}.") + f"Number of emitted tokens tokens: {metrics.emitted_tokens}, ") diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 1df2ad391562e..7899e18abab77 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -45,8 +45,8 @@ class SpecDecodeWorkerMetrics: class AsyncMetricsCollector: - """Class which copies rejection sampler metrics from the device to CPU on a - non-default Torch stream. + """Class which copies rejection/typical-acceptance sampler metrics + from the device to CPU on a non-default Torch stream. """ def __init__(self, @@ -94,7 +94,7 @@ def maybe_collect_rejsample_metrics( return None def _should_collect_rejsample_metrics(self, now: float) -> bool: - """Return whether or not this iteration should print rejection sampling + """Return whether or not this iteration should print sampling metrics. """ if self._rank != 0: @@ -106,8 +106,8 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool: return True def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection sampling metrics (number of accepted tokens, etc) to - CPU asynchronously. + """Copy rejection/typical-acceptance sampling metrics + (number of accepted tokens, etc) to CPU asynchronously. Returns a CUDA event recording when the copy is complete. """ diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 78ed8397b255c..c71bee9c7dcb0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -82,8 +82,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): welcome!). * Only top-1 proposal and scoring are implemented. Tree-attention is left as future work. - * Only lossless rejection sampling is supported. Contributions adding lossy - verification routines are welcome (e.g. Medusa's typical acceptance). * All sequences in a batch must have the same proposal length, or zero. This can be improved by having per-sequence speculation in the future. 
* The scoring forward pass is done without an MQA kernel, which is @@ -121,12 +119,12 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) - sampler: SpecDecodeBaseSampler = None + spec_decode_sampler: SpecDecodeBaseSampler = None if draft_token_sampling_method == "rejection_sampler": - sampler = RejectionSampler( + spec_decode_sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens, ) elif draft_token_sampling_method == "typical_acceptance_sampler": - sampler = TypicalAcceptanceSampler( + spec_decode_sampler = TypicalAcceptanceSampler( disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, @@ -137,14 +135,14 @@ proposer_worker, scorer_worker, disable_by_batch_size=disable_by_batch_size, - sampler=sampler) + spec_decode_sampler=spec_decode_sampler) def __init__( self, proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, - verification_sampler: SpecDecodeBaseSampler, + spec_decode_sampler: SpecDecodeBaseSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, ): @@ -157,8 +155,12 @@ scorer_worker: A worker that produces probabilities of speculative tokens according to some base model. Typically a vanilla vLLM Worker. - rejection_sampler: A Torch module used to perform modified rejection - sampling for speculative decoding. + spec_decode_sampler: A Torch module used to perform modified + sampling of the draft tokens in the verification step of + speculative decoding. Two samplers are currently supported: + RejectionSampler and TypicalAcceptanceSampler; + 'spec_decode_sampler' must be an instance of one of + these two classes. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -167,16 +169,16 @@ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") - self.verification_sampler = verification_sampler + self.spec_decode_sampler = spec_decode_sampler assert ( - self.verification_sampler is not None, + self.spec_decode_sampler is not None, "Sampler is Not set, which is not expected." ) self._metrics = AsyncMetricsCollector( - self.verification_sampler + self.spec_decode_sampler ) if metrics_collector is None else metrics_collector - self.probs_dtype = self.verification_sampler.probs_dtype - self.token_id_dtype = self.verification_sampler.token_id_dtype + self.probs_dtype = self.spec_decode_sampler.probs_dtype + self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initiazliation. self.scorer: SpeculativeScorer @@ -193,7 +195,7 @@ def init_device(self) -> None: self.proposer_worker.load_model() self._metrics.init_gpu_tensors(self.rank) - self.verification_sampler.init_gpu_tensors(self.rank) + self.spec_decode_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( scorer_worker=self.scorer_worker, @@ -208,7 +210,7 @@ def load_model(self, *args, **kwargs): def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, - which significantly reduces overhead of rejection sampling.
+ which significantly reduces overhead of sampling during verification. NOTE(cade): This breaks abstraction boundaries pretty badly. The better design is to have the "move to CPU and serialize" sampling decision be @@ -496,16 +498,16 @@ def _get_accepted_token_ids(self, proposal_verifier_probs: torch.Tensor, bonus_token_ids: torch.Tensor, proposal_probs: torch.Tensor, proposal_token_ids: torch.Tensor): - if isinstance(self.verification_sampler, RejectionSampler): - accepted_token_ids = self.verification_sampler( + if isinstance(self.spec_decode_sampler, RejectionSampler): + accepted_token_ids = self.spec_decode_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, ) else: - assert isinstance(self.verification_sampler, TypicalAcceptanceSampler) - accepted_token_ids = self.verification_sampler( + assert isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler) + accepted_token_ids = self.spec_decode_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_token_ids=proposal_token_ids, From ded92acf86d74753b8202c80b1c64398a8be76d3 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 07:00:08 +0000 Subject: [PATCH 15/38] Fix test fixture and default values of args --- tests/spec_decode/test_dynamic_spec_decode.py | 8 +- tests/spec_decode/test_spec_decode_worker.py | 79 ++++++++++--------- tests/spec_decode/test_utils.py | 4 +- vllm/engine/arg_utils.py | 4 +- 4 files changed, 49 insertions(+), 46 deletions(-) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 5ee2480ac14df..73790ed2db6b4 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -10,16 +10,16 @@ from vllm.spec_decode.top1_proposer import Top1Proposer from .utils import create_batch, mock_worker -from .test_utils import mock_sampler_factory +from .test_utils import mock_spec_decode_sampler @pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('k', [1]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_disable_spec_tokens( - queue_size: int, batch_size: int, k: int, mock_sampler_factory): + queue_size: int, batch_size: int, k: int, mock_spec_decode_sampler): """Verify that speculative tokens are disabled when the batch size exceeds the threshold. 
""" @@ -29,7 +29,7 @@ def test_disable_spec_tokens( metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - spec_decode_sampler=mock_sampler_factory, + spec_decode_sampler=mock_spec_decode_sampler, metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 4e8c75b999b76..66840ab2f62e7 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -17,15 +17,15 @@ split_num_cache_blocks_evenly) from .utils import create_batch, create_sampler_output_list, mock_worker -from .test_utils import mock_sampler_factory +from .test_utils import mock_spec_decode_sampler @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_correctly_calls_draft_model( - k: int, batch_size: int, mock_sampler_factory): + k: int, batch_size: int, mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the draft worker with correct inputs. Everything else is mocked out. """ @@ -33,7 +33,7 @@ def test_correctly_calls_draft_model( target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, metrics_collector) + mock_spec_decode_sampler, metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -54,11 +54,11 @@ def test_correctly_calls_draft_model( @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_correctly_calls_target_model( - k: int, batch_size: int, mock_sampler_factory): + k: int, batch_size: int, mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ @@ -72,7 +72,7 @@ def test_correctly_calls_target_model( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, + mock_spec_decode_sampler, metrics_collector) worker.init_device() @@ -136,11 +136,11 @@ def test_correctly_calls_target_model( @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_correctly_calls_spec_decode_sampler( - k: int, batch_size: int, mock_sampler_factory): + k: int, batch_size: int, mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. 
""" @@ -150,7 +150,7 @@ def test_correctly_calls_spec_decode_sampler( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_base_sampler = mock_sampler_factory + spec_decode_sampler = mock_spec_decode_sampler metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -158,7 +158,7 @@ def test_correctly_calls_spec_decode_sampler( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_base_sampler, metrics_collector) + spec_decode_sampler, metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -205,15 +205,15 @@ def test_correctly_calls_spec_decode_sampler( exception_secret = 'artificial stop' - spec_decode_base_sampler.side_effect = ValueError(exception_secret) + spec_decode_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - assert len(spec_decode_base_sampler.call_args_list) == 1 - _, kwargs = spec_decode_base_sampler.call_args_list[0] + assert len(spec_decode_sampler.call_args_list) == 1 + _, kwargs = spec_decode_sampler.call_args_list[0] actual = SimpleNamespace(**kwargs) assert torch.equal(actual.bonus_token_ids, @@ -222,17 +222,17 @@ def test_correctly_calls_spec_decode_sampler( actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) - if isinstance(spec_decode_base_sampler, RejectionSampler): + if isinstance(spec_decode_sampler, RejectionSampler): assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_correctly_formats_output( - k: int, batch_size: int, mock_sampler_factory): + k: int, batch_size: int, mock_spec_decode_sampler): """Verify SpecDecodeWorker formats sampler output correctly. Everything else is mocked out. 
""" @@ -247,9 +247,9 @@ def test_correctly_formats_output( target_worker.device = 'cuda' set_random_seed(1) - spec_decode_base_sampler = mock_sampler_factory + spec_decode_sampler = mock_spec_decode_sampler worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_base_sampler, + spec_decode_sampler, metrics_collector) worker.init_device() @@ -305,7 +305,7 @@ def test_correctly_formats_output( spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - spec_decode_base_sampler.return_value = spec_decode_sampler_output + spec_decode_sampler.return_value = spec_decode_sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) @@ -352,11 +352,12 @@ def test_correctly_formats_output( @pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('returns_metrics', [True, False]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() def test_collects_metrics( - k: int, batch_size: int, returns_metrics: bool, mock_sampler_factory): + k: int, batch_size: int, returns_metrics: bool, + mock_spec_decode_sampler): """Verify SpecDecodeWorker collects metrics. """ vocab_size = 32_000 @@ -365,7 +366,7 @@ def test_collects_metrics( vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_sampler = mock_sampler_factory + spec_decode_sampler = mock_spec_decode_sampler metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -449,10 +450,11 @@ def test_collects_metrics( @pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_k_equals_zero(k: int, batch_size: int, mock_sampler_factory): +def test_k_equals_zero( + k: int, batch_size: int, mock_spec_decode_sampler): """Verify that the SpecDecodeWorker calls the draft and target workers when k is zero. This happens during prefill. """ @@ -468,7 +470,7 @@ def test_k_equals_zero(k: int, batch_size: int, mock_sampler_factory): set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, + mock_spec_decode_sampler, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -490,10 +492,11 @@ def test_k_equals_zero(k: int, batch_size: int, mock_sampler_factory): @pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('batch_size', [0]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_empty_input_batch(k: int, batch_size: int, mock_sampler_factory): +def test_empty_input_batch( + k: int, batch_size: int, mock_spec_decode_sampler): """Verify that the SpecDecodeWorker calls the draft and target workers when the input batch is empty. This can happen if the engine communicates to the workers information without scheduling a batch. 
@@ -510,7 +513,7 @@ def test_empty_input_batch(k: int, batch_size: int, mock_sampler_factory): set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, + mock_spec_decode_sampler, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, @@ -529,16 +532,16 @@ def test_empty_input_batch(k: int, batch_size: int, mock_sampler_factory): draft_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @pytest.mark.skip_global_cleanup -def test_init_device(mock_sampler_factory): +def test_init_device(mock_spec_decode_sampler): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - spec_decode_sampler = mock_sampler_factory + spec_decode_sampler = mock_spec_decode_sampler metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, @@ -554,10 +557,10 @@ def test_init_device(mock_sampler_factory): metrics_collector.init_gpu_tensors.assert_called_once() spec_decode_sampler.init_gpu_tensors.assert_called_once() -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @torch.inference_mode() -def test_initialize_cache(mock_sampler_factory): +def test_initialize_cache(mock_spec_decode_sampler): """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. """ @@ -566,7 +569,7 @@ def test_initialize_cache(mock_sampler_factory): metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, + mock_spec_decode_sampler, metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} @@ -580,14 +583,14 @@ def test_initialize_cache(mock_sampler_factory): @pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@pytest.mark.parametrize("mock_sampler_factory", +@pytest.mark.parametrize("mock_spec_decode_sampler", ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) @pytest.mark.skip_global_cleanup def test_determine_num_available_blocks(available_gpu_blocks: int, available_cpu_blocks: int, target_cache_block_size_bytes: int, draft_kv_size_bytes: int, - mock_sampler_factory): + mock_spec_decode_sampler): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. 
@@ -603,7 +606,7 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes worker = SpecDecodeWorker(draft_worker, target_worker, - mock_sampler_factory, + mock_spec_decode_sampler, metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 16b34cef7e903..7dc866ffbb796 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -114,7 +114,7 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): assert indices == [] @pytest.fixture -def mock_sampler_factory(request): +def mock_spec_decode_sampler(request): def create_samplers(value): if value == "rejection_sampler": sampler = MagicMock(spec=RejectionSampler) @@ -125,7 +125,7 @@ def create_samplers(value): sampler.token_id_dtype = torch.int64 return sampler else: - return None # Return None for both samplers if the value is not recognized + return None # Return None if the value is not recognized value = request.param # Get the value passed to the fixture return create_samplers(value) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index db14626d0691e..189f5e75dfd08 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -98,8 +98,8 @@ class EngineArgs: ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None speculative_draft_token_sampling_method: str = 'rejection_sampler' - typical_acceptance_sampler_posterior_threshold: float = 0.49 - typical_acceptance_sampler_posterior_alpha: float = 0.7 + typical_acceptance_sampler_posterior_threshold: float = 0.09 + typical_acceptance_sampler_posterior_alpha: float = 0.3 qlora_adapter_name_or_path: Optional[str] = None def __post_init__(self): From 738871edce4482dde265109c16ff77534d9efa9e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 18:22:20 +0000 Subject: [PATCH 16/38] Small misc fixes --- tests/spec_decode/e2e/test_multistep_correctness.py | 10 ++++++---- tests/spec_decode/test_spec_decode_worker.py | 3 +++ tests/spec_decode/test_utils.py | 5 +++++ vllm/config.py | 4 ++-- vllm/engine/arg_utils.py | 2 +- vllm/engine/metrics.py | 2 +- 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index a844d3fee36bf..7a2a7e9cc8d9e 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -9,15 +9,17 @@ Since speculative decoding with rejection sampling guarantees that the output distribution matches the target model's output distribution (up to hardware numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. This gives us good coverage of temp=0. At temp=0, the -TypicalAcceptanceSampler ensures that only the tokens with the highest -probability in the target distribution are accepted. Therefore, we can +equality. This gives us good coverage of temp=0. + +At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the +highest probability in the target distribution are accepted. Therefore, we can expect greedy equality for the TypicalAcceptanceSampler at temp=0. For temp>0, we rely on unit tests on the rejection sampler to verify that the output distribution is the same with spec decode vs. 
no spec decode (this would be prohibitively expensive to run with a real model). Similary, for the -TypicalAcceptance sampler, we rely on unit tests to validate temp>0 test cases. +TypicalAcceptanceSampler, we also rely on unit tests to validate temp>0 +test cases. NOTE: Speculative decoding's distribution equality requires that the measured distributions of the target model and proposal model be deterministic given the diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 66840ab2f62e7..295e067bd33a2 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -223,6 +223,9 @@ def test_correctly_calls_spec_decode_sampler( target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) if isinstance(spec_decode_sampler, RejectionSampler): + # The draft probabilities are used only by the RejectionSampler. + # Ensure that if the sampler is a RejectionSampler then the + # draft probs are being passed. assert torch.equal(actual.draft_probs, proposal_probs) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 7dc866ffbb796..db58251fef760 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -115,6 +115,11 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): @pytest.fixture def mock_spec_decode_sampler(request): + """ + Returns a mock of either RejectionSampler or TypicalAcceptanceSampler + depending on whether the value is 'rejection_sampler' or + 'typical_acceptance_sampler' respectively. + """ def create_samplers(value): if value == "rejection_sampler": sampler = MagicMock(spec=RejectionSampler) diff --git a/vllm/config.py b/vllm/config.py index 31a4cb48a0d5d..41b90ca786b9d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -810,8 +810,8 @@ def maybe_create_spec_config( respectively. typical_acceptance_sampler_posterior_threshold (Optional[float]): A threshold value that sets a lower bound on the posterior - probability of a token in target model for it to be accepted. - This threshold is only used when we use a + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 189f5e75dfd08..abd52c98becf8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -568,7 +568,7 @@ def add_cli_args( 'Two types of samplers are supported: ' '1) RejectionSampler which does not allow changing the ' 'acceptance rate of draft tokens, ' - '2) TypicalAccpetanceSampler which is configurable, allowing for a higher ' + '2) TypicalAcceptanceSampler which is configurable, allowing for a higher ' 'acceptance rate at the cost of lower quality, and vice versa.') parser.add_argument( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ffc9eb5db6a48..28dfc2f9adb88 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -372,4 +372,4 @@ def _format_spec_decode_metrics_str( f"Number of speculative tokens: {metrics.num_spec_tokens}, " f"Number of accepted tokens: {metrics.accepted_tokens}, " f"Number of draft tokens tokens: {metrics.draft_tokens}, " - f"Number of emitted tokens tokens: {metrics.emitted_tokens}, ") + f"Number of emitted tokens tokens: {metrics.emitted_tokens}. ") From 50e8771a0fe4e5bbf4dfed0c0dd0517e46861ccd Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 11 Jun 2024 18:37:34 +0000 Subject: [PATCH 17/38] Fix spec_decode/test_metrics.py --- tests/spec_decode/test_metrics.py | 90 +++++++++++++------------- vllm/spec_decode/metrics.py | 10 +-- vllm/spec_decode/spec_decode_worker.py | 3 - 3 files changed, 50 insertions(+), 53 deletions(-) diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 312878804b86e..c1d5aa60d8d1b 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -10,16 +10,16 @@ def test_initial_call_returns_none(): """Expect first call to get metrics to return None. """ - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(rej_sampler) + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(spec_decode_sampler) collector.init_gpu_tensors(rank=0) maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) assert maybe_metrics is None @@ -28,14 +28,14 @@ def test_initial_call_returns_none(): def test_second_call_returns_metrics(): """Expect second call to not return None. 
""" - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 collect_interval_s = 5.0 timer = MagicMock() @@ -43,7 +43,7 @@ def test_second_call_returns_metrics(): 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) @@ -56,16 +56,16 @@ def test_second_call_returns_metrics(): def test_nonzero_rank_noop(rank): """Verify nonzero ranks don't collect metrics. """ - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(rej_sampler) + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(spec_decode_sampler) collector.init_gpu_tensors(rank=rank) _ = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5) @@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank): def test_noop_until_time(): """Verify metrics aren't collected until enough time passes. 
""" - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 collect_interval_s = 5.0 timer = MagicMock() @@ -91,7 +91,7 @@ def test_noop_until_time(): collect_interval_s + 0.1, collect_interval_s + 0.1 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) @@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool): max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( num_draft_tokens, k) - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, dtype=torch.long, device='cuda') - rej_sampler.num_draft_tokens = num_draft_tokens + spec_decode_sampler.num_draft_tokens = num_draft_tokens collect_interval_s = 5.0 timer = MagicMock() @@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 7899e18abab77..67100ebf1a7de 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -50,10 +50,10 @@ class AsyncMetricsCollector: """ def __init__(self, - spec_decode_base_sampler: SpecDecodeBaseSampler, + spec_decode_sampler: SpecDecodeBaseSampler, timer: Optional[Timer] = None, collect_interval_s: float = 5.0): - self.spec_decode_base_sampler = spec_decode_base_sampler + self.spec_decode_sampler = spec_decode_sampler self._timer = time.time if timer is None else timer self._rank: Optional[int] = None @@ -116,13 +116,13 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: with torch.cuda.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( - self.spec_decode_base_sampler.num_accepted_tokens, non_blocking=True) + self.spec_decode_sampler.num_accepted_tokens, non_blocking=True) self._aggregate_num_emitted_tokens.copy_( - self.spec_decode_base_sampler.num_emitted_tokens, non_blocking=True) + self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) # Number of draft tokens is calculated on CPU, so no copy is # required. 
self._aggregate_num_draft_tokens = ( - self.spec_decode_base_sampler.num_draft_tokens) + self.spec_decode_sampler.num_draft_tokens) aggregate_metrics_ready = torch.cuda.Event() aggregate_metrics_ready.record(self._copy_stream) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c71bee9c7dcb0..25bc0e9a45801 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -474,9 +474,6 @@ def _verify_tokens( # Get proposed tokens. proposal_token_ids = proposals.proposal_token_ids[spec_indices] - - #print('_total_time_in_verify ' + str(self._total_time_in_verify)) - #print('_num_calls ' + str(self._num_calls)) # Append output tokens from non-speculative sequences to # the accepted token ids tensor. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + From cc760a0533ee794e66e782e80cc08ea6c967040e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 25 Jun 2024 07:07:57 +0000 Subject: [PATCH 18/38] Make rejection_sampler.py and typical_acceptance_sampler.py implement the same interface --- .../test_typical_acceptance_sampler.py | 28 +++++++++++------ .../layers/rejection_sampler.py | 8 ++--- .../layers/spec_decode_base_sampler.py | 15 ++++++++- .../layers/typical_acceptance_sampler.py | 17 +++++----- vllm/spec_decode/spec_decode_worker.py | 31 +++++-------------- 5 files changed, 54 insertions(+), 45 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 87cf37bc926bc..86c3b0d83399d 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -76,7 +76,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + typical_acceptance_sampler(target_probs, bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @@ -126,7 +128,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, with pytest.raises(AssertionError): typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @@ -165,7 +168,8 @@ def test_uniform_target_distribution_accepts_all_tokens( dtype=torch.int64) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) # We are using a uniform target probability distribution. # For a uniform distribution the entropy is very high and it # should lead to all draft tokens being accepted. Verify that. @@ -226,7 +230,8 @@ def test_temperature_zero_target_distribution(seed: int, # Verify the same. 
output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, -1] == -1) @@ -279,7 +284,8 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, dtype=torch.int64) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) # verify the shape of output_token_ids assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) @@ -341,7 +347,8 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, dtype=torch.int64) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) @@ -359,7 +366,8 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) @@ -404,7 +412,8 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, dtype=torch.int64) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 1:-1] == -1) @@ -420,7 +429,8 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, typical_acceptance_sampler.init_gpu_tensors(rank=0) output_token_ids = typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 1763d18771d28..e5b829adb91a4 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) -class RejectionSampler(SpecDecodeBaseSampler, nn.Module): +class RejectionSampler(SpecDecodeBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -29,8 +29,9 @@ def __init__(self, during sampling. This catches correctness issues but adds nontrivial latency. """ - SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) - nn.Module.__init__(self) + super().__init__( + disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) def forward( self, @@ -244,7 +245,6 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny - # torch.multinomial forces a GPU<->CPU sync. 
# Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 8eda86c5ccdbb..da61170349ece 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,10 +1,12 @@ from typing import Optional +from abc import ABC, abstractmethod import torch import torch.jit +import torch.nn as nn -class SpecDecodeBaseSampler(): +class SpecDecodeBaseSampler(nn.Module): """Base class for samplers used for Speculative Decoding verification step. """ @@ -54,6 +56,17 @@ def probs_dtype(self): def token_id_dtype(self): return torch.int64 + @abstractmethod + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + def _create_output( self, accepted: torch.Tensor, # [batch_size, k] diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index d1920db4ad91a..64a484bc35715 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -7,7 +7,7 @@ SpecDecodeBaseSampler) -class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): +class TypicalAcceptanceSampler(SpecDecodeBaseSampler): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" @@ -38,17 +38,16 @@ def __init__( """ self._posterior_threshold = posterior_threshold self._posterior_alpha = posterior_alpha - super().__init__() - SpecDecodeBaseSampler.__init__( - self, + super().__init__( disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) - nn.Module.__init__(self) + def forward( self, target_probs: torch.Tensor, bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, ) -> torch.Tensor: """Sample token ids using typical acceptance sampling. This accepts @@ -70,6 +69,8 @@ def forward( speculative tokens in a sequence are accepted. shape = [batch_size, num_bonus_tokens] + draft_probs: This parameter is unused by the acceptance sampler. + draft_token_ids: The token ids that were sampled from the draft probabilities. shape = [batch_size, num_speculative_tokens] @@ -145,8 +146,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): posterior_entropy = -torch.sum( target_probs * torch.log(target_probs + 1e-5), dim=-1) threshold = torch.minimum( - torch.ones_like(posterior_entropy, device=device) * self._posterior_threshold, - index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) # A small constant added to prevent computing the logarithm of zero, # which can lead to undefined values. 
epsilon = 1e-5 diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 66cd6a093a8cd..c6f0b10aef3f2 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -504,9 +504,12 @@ def _verify_tokens( non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + 1).clone() non_spec_token_ids[:, 1:] = -1 - accepted_token_ids = self._get_accepted_token_ids( - proposal_verifier_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, - proposal_probs=proposal_probs, proposal_token_ids=proposal_token_ids) + accepted_token_ids = self.spec_decode_sampler( + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, + ) accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) logprobs = proposal_scores.logprobs @@ -529,27 +532,7 @@ def _verify_tokens( hidden_states) return accepted_token_ids, logprobs - - def _get_accepted_token_ids(self, proposal_verifier_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - proposal_probs: torch.Tensor, - proposal_token_ids: torch.Tensor): - if isinstance(self.spec_decode_sampler, RejectionSampler): - accepted_token_ids = self.spec_decode_sampler( - target_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_probs=proposal_probs, - draft_token_ids=proposal_token_ids, - ) - else: - assert isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler) - accepted_token_ids = self.spec_decode_sampler( - target_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_token_ids=proposal_token_ids, - ) - return accepted_token_ids - + def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], From 360ce0bc4ea0f30af1b45ec7193cc47ffae444e6 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 25 Jun 2024 07:19:13 +0000 Subject: [PATCH 19/38] Raise exception instead of returning None for invalid sampler name --- tests/spec_decode/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 69097f6e41e0b..c43fc4f3e76d3 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -4,7 +4,7 @@ import torch from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids -from vllm.spec_decode.util import get_all_seq_ids, split_batch_by_proposal_len +from vllm.spec_decode.util import split_batch_by_proposal_len from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler @@ -130,7 +130,7 @@ def create_samplers(value): sampler.token_id_dtype = torch.int64 return sampler else: - return None # Return None if the value is not recognized + raise ValueError(f"Invalid sampler name {value}") value = request.param # Get the value passed to the fixture return create_samplers(value) From 6572ba407e75e4e7ff32032b107c4937bdf52ac0 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Tue, 25 Jun 2024 07:22:12 +0000 Subject: [PATCH 20/38] Adding log about type of sampler --- vllm/spec_decode/spec_decode_worker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c6f0b10aef3f2..38f74267a5a0d 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -137,6 +137,9 @@ def 
create_worker( typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) + logger.info("Configuring SpecDecodeWorker with sampler=%s", + type(spec_decode_sampler)) + return SpecDecodeWorker( proposer_worker, From be85f072a5ce5ad75d727955aa7bc11c703b27f7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 06:03:29 +0000 Subject: [PATCH 21/38] Misc comment fixes --- .../test_typical_acceptance_sampler.py | 26 +++++--- .../e2e/test_multistep_correctness.py | 65 +++++++++++++------ .../spec_decode/e2e/test_ngram_correctness.py | 46 +++++++++++++ vllm/config.py | 47 +++++++++++--- vllm/engine/arg_utils.py | 19 +++--- .../layers/typical_acceptance_sampler.py | 13 +--- vllm/spec_decode/spec_decode_worker.py | 8 +-- 7 files changed, 164 insertions(+), 60 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 86c3b0d83399d..1725bdb41c578 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -51,6 +51,16 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, break return draft_token_ids +def get_acceptance_sampler( + posterior_threshold: float = 0.03, + posterior_alpha: float = 0.9, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, +) -> TypicalAcceptanceSampler: + return TypicalAcceptanceSampler( + posterior_threshold, posterior_alpha, disable_bonus_tokens, strict_mode) + + @pytest.mark.parametrize("k", list(range(1, 6))) @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @@ -64,7 +74,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, different combinations of k, vocab_size, batch_size and num devices. 
""" torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler() + typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -96,7 +106,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -154,7 +164,7 @@ def test_uniform_target_distribution_accepts_all_tokens( batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -207,7 +217,7 @@ def test_temperature_zero_target_distribution(seed: int, vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Simulate temperature 0 probability distribution for target probabilities @@ -266,7 +276,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, batch_size = 4 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # For sequences 0 and 2 set the distribution to a temperature @@ -332,7 +342,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, batch_size = 1 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Create a temperature zero target probability distribution and ensure @@ -392,7 +402,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, batch_size = 1 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Simulate temperature 0 probability distribution for target @@ -461,7 +471,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py 
b/tests/spec_decode/e2e/test_multistep_correctness.py index 7a2a7e9cc8d9e..fddc9db838cca 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -183,9 +183,7 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -239,9 +237,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -289,9 +285,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("max_output_len", [ 256, @@ -332,9 +326,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( @@ -378,9 +370,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize( @@ -427,9 +417,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize( "output_len", @@ -483,9 +471,7 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( @@ -530,9 +516,7 @@ def test_spec_decode_different_block_size(baseline_llm_generator, # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
"speculative_max_model_len": 32, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -576,9 +560,7 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, "speculative_disable_by_batch_size": 2, - "speculative_draft_token_sampling_method": method } - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("output_len", [10]) @@ -613,12 +595,9 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": k, - "speculative_draft_token_sampling_method": method, } # Try a range of common k, as well as large speculation. for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] - # Try both methods of sampling in the verifier. - for method in ["rejection_sampler", "typical_acceptance_sampler"] ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( @@ -638,3 +617,47 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "spec_decoding_acceptance_method": "typical_acceptance_sampler" + } + # Try a range of common k, as well as large speculation. + for k in [1, 2, 63] + ]) +@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_typical_acceptance_sampling( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify that speculative decoding produces exact equality to without spec + decode with many different values of k. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index d475d37af6425..b20ac2ac68498 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -211,3 +211,49 @@ def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + "spec_decoding_acceptance_method": "typical_acceptance_sampler" + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("batch_size", [1, 32]) +def test_ngram_typical_acceptance_sampling( + baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/vllm/config.py b/vllm/config.py index a7cf15ae0b537..091e65eb88557 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -803,7 +803,7 @@ def maybe_create_spec_config( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], - draft_token_sampling_method: Optional[str], + draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: Optional[float], typical_acceptance_sampler_posterior_alpha: Optional[float], ) -> Optional["SpeculativeConfig"]: @@ -840,8 +840,8 @@ def maybe_create_spec_config( window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token window, if provided. - draft_token_sampling_method (Optional[str]): The sampling method - to use for accepting draft tokens. This can take two possible + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible values 'rejection_sampler' and 'typical_acceptance_sampler' for RejectionSampler and TypicalAcceptanceSampler respectively. 
@@ -961,6 +961,11 @@ def maybe_create_spec_config( "num_speculative_tokens must be provided with " "speculative_model unless the draft model config contains an " "n_predict parameter.") + + if typical_acceptance_sampler_posterior_threshold is None: + typical_acceptance_sampler_posterior_threshold = 0.09 + if typical_acceptance_sampler_posterior_alpha is None: + typical_acceptance_sampler_posterior_alpha = 0.3 return SpeculativeConfig( draft_model_config, @@ -969,7 +974,7 @@ def maybe_create_spec_config( speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, - draft_token_sampling_method=draft_token_sampling_method, + draft_token_acceptance_method=draft_token_acceptance_method, typical_acceptance_sampler_posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ @@ -1045,9 +1050,9 @@ def __init__( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], - draft_token_sampling_method: Optional[str], - typical_acceptance_sampler_posterior_threshold: Optional[float], - typical_acceptance_sampler_posterior_alpha: Optional[float], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, ): """Create a SpeculativeConfig object. @@ -1061,6 +1066,19 @@ def __init__( enqueue requests is larger than this value. ngram_prompt_lookup_max: Max size of ngram token window. ngram_prompt_lookup_min: Min size of ngram token window. + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. + typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1069,7 +1087,7 @@ def __init__( speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 - self.draft_token_sampling_method = draft_token_sampling_method + self.draft_token_acceptance_method = draft_token_acceptance_method self.typical_acceptance_sampler_posterior_threshold = \ typical_acceptance_sampler_posterior_threshold self.typical_acceptance_sampler_posterior_alpha = \ @@ -1085,6 +1103,19 @@ def _verify_args(self) -> None: if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( self.draft_parallel_config) + # Validate and set draft token acceptance related settings. + + if (self.draft_token_acceptance_method is None): + raise ValueError("draft_token_acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.draft_token_acceptance_method != 'rejection_sampler' + and self.draft_token_acceptance_method != 'typical_acceptance_sampler'): + raise ValueError("Expected draft_token_acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. 
Instead it " + f"is {self.draft_token_acceptance_method}") + @property def num_lookahead_slots(self) -> int: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4d581ef364b28..37e6f6fa5da78 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -99,9 +99,9 @@ class EngineArgs: speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - speculative_draft_token_sampling_method: str = 'rejection_sampler' - typical_acceptance_sampler_posterior_threshold: float = 0.09 - typical_acceptance_sampler_posterior_alpha: float = 0.3 + spec_decoding_acceptance_method: str = 'rejection_sampler' + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None qlora_adapter_name_or_path: Optional[str] = None otlp_traces_endpoint: Optional[str] = None @@ -570,12 +570,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'decoding.') parser.add_argument( - '--speculative-draft-token-sampling-method', + '--spec-decoding-acceptance-routine', type=str, - default=EngineArgs.speculative_draft_token_sampling_method, + default=EngineArgs.spec_decoding_acceptance_method, choices=['rejection_sampler', 'typical_acceptance_sampler'], - help='Specify the draft token sampling method for speculative decoding. ' - 'Two types of samplers are supported: ' + help='Specify the acceptance method to use during draft token ' + 'verification in speculative decoding. Two types of acceptance ' + 'routines are supported: ' '1) RejectionSampler which does not allow changing the ' 'acceptance rate of draft tokens, ' '2) TypicalAcceptanceSampler which is configurable, allowing for a higher ' @@ -725,8 +726,8 @@ def create_engine_config(self, ) -> EngineConfig: use_v2_block_manager=self.use_v2_block_manager, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, - draft_token_sampling_method=self. - speculative_draft_token_sampling_method, + draft_token_acceptance_method=\ + self.spec_decoding_acceptance_method, typical_acceptance_sampler_posterior_threshold=self. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=self. diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 64a484bc35715..132802ce75539 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -15,10 +15,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler): """ def __init__( self, + posterior_threshold: float, + posterior_alpha: float, disable_bonus_tokens: bool = False, strict_mode: bool = False, - posterior_threshold: float = 0.09, - posterior_alpha: float = 0.3, ): """Create a Typical Acceptance Sampler. 
@@ -142,14 +142,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): device = target_probs.device candidates_prob = torch.gather( target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1), ).squeeze(-1) - posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + 1e-5), dim=-1) - threshold = torch.minimum( - torch.ones_like(posterior_entropy, device=device) * - self._posterior_threshold, - torch.exp(-posterior_entropy) * self._posterior_alpha, - ) + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) # A small constant added to prevent computing the logarithm of zero, # which can lead to undefined values. epsilon = 1e-5 diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 38f74267a5a0d..e0c8a384a0c5a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -58,7 +58,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, draft_token_sampling_method=speculative_config. - draft_token_sampling_method, + draft_token_acceptance_method, typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. @@ -133,9 +133,9 @@ def create_worker( elif draft_token_sampling_method == "typical_acceptance_sampler": spec_decode_sampler = TypicalAcceptanceSampler( disable_bonus_tokens=disable_bonus_tokens, - posterior_threshold=\ - typical_acceptance_sampler_posterior_threshold, - posterior_alpha=typical_acceptance_sampler_posterior_alpha, + #posterior_threshold=\ + # typical_acceptance_sampler_posterior_threshold, + #posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) logger.info("Configuring SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) From 6dc9efe96ec6da73433ce06259266af5197ce113 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 06:30:43 +0000 Subject: [PATCH 22/38] Misc fixes --- .../layers/rejection_sampler.py | 5 +--- .../layers/spec_decode_base_sampler.py | 2 -- .../layers/typical_acceptance_sampler.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 27 +++++++++---------- 4 files changed, 14 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index e5b829adb91a4..f6b0e61e20f32 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -76,7 +76,6 @@ def forward( """ # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. 
- start = time.time() if self._strict_mode: self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) @@ -93,9 +92,7 @@ def forward( draft_token_ids, bonus_token_ids, ) - end = time.time() - self.total_time += (end - start) - self.total_calls += 1 + return output_token_ids def _batch_modified_rejection_sampling( diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index da61170349ece..fa57125f3e514 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -35,8 +35,6 @@ def __init__(self, self.num_accepted_tokens: Optional[torch.Tensor] = None self.num_emitted_tokens: Optional[torch.Tensor] = None self.num_draft_tokens: int = 0 - self.total_time: float = 0 - self.total_calls: float = 0 def init_gpu_tensors(self, rank: int) -> None: assert self.num_accepted_tokens is None diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 132802ce75539..16389dca0319f 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -110,7 +110,7 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): A tensor of shape (batch_size, k) representing the proposed token ids. - A draft token_id x_{n+k} is accepted if it satisifies the + A draft token_id x_{n+k} is accepted if it satisfies the following condition .. math:: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e0c8a384a0c5a..d7e0b2f8dfef7 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -100,9 +100,9 @@ def create_worker( scorer_worker: WorkerBase, draft_worker_kwargs: Dict[str, Any], disable_by_batch_size: Optional[int], - draft_token_sampling_method: Optional[str] = "rejection_sampler", - typical_acceptance_sampler_posterior_threshold: Optional[float] = 0.09, - typical_acceptance_sampler_posterior_alpha: Optional[float] = 0.3, + draft_token_sampling_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -133,9 +133,9 @@ def create_worker( elif draft_token_sampling_method == "typical_acceptance_sampler": spec_decode_sampler = TypicalAcceptanceSampler( disable_bonus_tokens=disable_bonus_tokens, - #posterior_threshold=\ - # typical_acceptance_sampler_posterior_threshold, - #posterior_alpha=typical_acceptance_sampler_posterior_alpha, + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) logger.info("Configuring SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) @@ -180,10 +180,6 @@ def __init__( self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") self.spec_decode_sampler = spec_decode_sampler - assert ( - self.spec_decode_sampler is not None, - "Sampler is Not set, which is not expected." - ) self._metrics = AsyncMetricsCollector( self.spec_decode_sampler ) if metrics_collector is None else metrics_collector @@ -502,17 +498,18 @@ def _verify_tokens( # Get proposed tokens. proposal_token_ids = proposals.proposal_token_ids[spec_indices] - # Append output tokens from non-speculative sequences to - # the accepted token ids tensor. 
- non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + - 1).clone() - non_spec_token_ids[:, 1:] = -1 accepted_token_ids = self.spec_decode_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, ) + + # Append output tokens from non-speculative sequences to + # the accepted token ids tensor. + non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + + 1).clone() + non_spec_token_ids[:, 1:] = -1 accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) logprobs = proposal_scores.logprobs From 512fad9b9546e43abe9ec89b586efeeaa93b776d Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 07:04:44 +0000 Subject: [PATCH 23/38] Misc fixes --- .../e2e/test_multistep_correctness.py | 27 ++++++++++--------- .../spec_decode/e2e/test_ngram_correctness.py | 4 +-- vllm/spec_decode/spec_decode_worker.py | 10 +++---- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index fddc9db838cca..7368a2f1f776e 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -183,7 +183,7 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize( "output_len", @@ -236,8 +236,8 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( @pytest.mark.parametrize("test_llm_kwargs", [ { "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - } + "num_speculative_tokens": 5, + }, ]) @pytest.mark.parametrize( "output_len", @@ -285,7 +285,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize("max_output_len", [ 256, @@ -326,7 +326,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize( @@ -370,7 +370,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize( @@ -417,7 +417,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize( "output_len", @@ -471,7 +471,7 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - } + }, ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize( @@ -512,11 +512,11 @@ def test_spec_decode_different_block_size(baseline_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - + # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
"speculative_max_model_len": 32, - } + }, ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -560,7 +560,7 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, "speculative_disable_by_batch_size": 2, - } + }, ]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("output_len", [10]) @@ -621,7 +621,7 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, @pytest.mark.parametrize( "common_llm_kwargs", [{ - "model": "JackFram/llama-68m", + "model": "JackFram/llama-160m", # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -654,7 +654,8 @@ def test_typical_acceptance_sampling( baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify that speculative decoding produces exact equality to without spec - decode with many different values of k. + decode with many TypicalAcceptanceSampler as the draft token acceptance + sampling method. """ run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index b20ac2ac68498..2c98639334992 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -249,8 +249,8 @@ def test_ngram_typical_acceptance_sampling( baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify that ngram speculative decoding produces exact equality - to without spec decode with many different values of k and - different ngram_prompt_lookup_max. + to without spec decode with many different values of k, batch_size and + using TypicalAcceptanceSampler as the draft token acceptance method. """ run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index d7e0b2f8dfef7..16d10721ef326 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -57,7 +57,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_kwargs=draft_worker_kwargs, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, - draft_token_sampling_method=speculative_config. + draft_token_acceptance_method=speculative_config. draft_token_acceptance_method, typical_acceptance_sampler_posterior_threshold=speculative_config. 
typical_acceptance_sampler_posterior_threshold, @@ -100,7 +100,7 @@ def create_worker( scorer_worker: WorkerBase, draft_worker_kwargs: Dict[str, Any], disable_by_batch_size: Optional[int], - draft_token_sampling_method: str, + draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": @@ -127,10 +127,10 @@ def create_worker( type(proposer_worker)) spec_decode_sampler: SpecDecodeBaseSampler = None - if draft_token_sampling_method == "rejection_sampler": + if draft_token_acceptance_method == "rejection_sampler": spec_decode_sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens, ) - elif draft_token_sampling_method == "typical_acceptance_sampler": + elif draft_token_acceptance_method == "typical_acceptance_sampler": spec_decode_sampler = TypicalAcceptanceSampler( disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=\ @@ -165,7 +165,7 @@ def __init__( scorer_worker: A worker that produces probabilities of speculative tokens according to some base model. Typically a vanilla vLLM Worker. - spec_decode_sampler: A Torch module used to perform modified + spec_decode_sampler: A Torch module used to perform acceptance sampling of the draft tokens in the verification step of speculative decoding. Currently we support two different types of sampler namely RejectionSampler and From b1d510c4ee8884286c0eacb16993c0786dc723e7 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 07:11:31 +0000 Subject: [PATCH 24/38] Misc fixes --- tests/spec_decode/test_spec_decode_worker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 264bcd4025289..5ba085529d67d 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -223,11 +223,7 @@ def test_correctly_calls_spec_decode_sampler( actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) assert torch.equal(actual.draft_token_ids, proposal_token_ids) - if isinstance(spec_decode_sampler, RejectionSampler): - # The draft probabilites is used only by the RejectionSampler. - # Ensure that if the sampler is a RejectionSampler then the - # draft probs are being passed. - assert torch.equal(actual.draft_probs, proposal_probs) + assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) From f4b9e4dea7a0eedf4733d5063d802e9785819bea Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 07:13:24 +0000 Subject: [PATCH 25/38] Misc fixes --- tests/spec_decode/e2e/test_multistep_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 7368a2f1f776e..14432cbe37afd 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -512,7 +512,7 @@ def test_spec_decode_different_block_size(baseline_llm_generator, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, - + # Artificially limit the draft model max model len; this forces vLLM # to skip speculation once the sequences grow beyond 32-k tokens. 
"speculative_max_model_len": 32, From 0ea94087d2dd1c3aca10553b54a0ad4d2b1ca735 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 07:50:06 +0000 Subject: [PATCH 26/38] Documentation --- tests/samplers/test_typical_acceptance_sampler.py | 7 +++++-- vllm/engine/arg_utils.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 1725bdb41c578..dee74271788cd 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -57,8 +57,11 @@ def get_acceptance_sampler( disable_bonus_tokens: bool = False, strict_mode: bool = False, ) -> TypicalAcceptanceSampler: - return TypicalAcceptanceSampler( - posterior_threshold, posterior_alpha, disable_bonus_tokens, strict_mode) + """ + Initializes and returns a TypicalAcceptanceSampler. + """ + return TypicalAcceptanceSampler( + posterior_threshold, posterior_alpha, disable_bonus_tokens, strict_mode) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 37e6f6fa5da78..bb8c970d0efd7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -589,7 +589,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Set the lower bound threshold for the posterior ' 'probability of a token to be accepted. This threshold is ' 'used by the TypicalAcceptanceSampler to make sampling decisions ' - 'during speculative decoding.') + 'during speculative decoding. Defaults to 0.09') parser.add_argument( '--typical-acceptance-sampler-posterior-alpha', @@ -597,7 +597,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.typical_acceptance_sampler_posterior_alpha, help='A scaling factor for the entropy-based threshold for token ' 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' - 'to sqrt of --typical-acceptance-sampler-posterior-threshold.') + 'to sqrt of --typical-acceptance-sampler-posterior-threshold i.e 0.3') parser.add_argument('--model-loader-extra-config', type=nullable_str, From 5772d04c1e28a738d487d3c19b550edffc9a9c2b Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 07:59:41 +0000 Subject: [PATCH 27/38] Fix comments --- vllm/model_executor/layers/typical_acceptance_sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 16389dca0319f..dc44ae523463d 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -31,10 +31,9 @@ def __init__( nontrivial latency. posterior_threshold : A threshold value that sets a lower bound on the posterior probability of a token in target model for it - to be accepted. Default is 0.09 + to be accepted. posterior_alpha : A scaling factor for the entropy-based - threshold in typical acceptance sampling. Typically defaults to - sqrt of posterior_threshold and is set to 0.3. + threshold in typical acceptance sampling. 
""" self._posterior_threshold = posterior_threshold self._posterior_alpha = posterior_alpha From b7254e71725613fcc3016543802b0d34812aa63c Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 15:14:34 +0000 Subject: [PATCH 28/38] Fix arg name --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8c970d0efd7..4f370ba086c8c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -570,7 +570,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'decoding.') parser.add_argument( - '--spec-decoding-acceptance-routine', + '--spec-decoding-acceptance-method', type=str, default=EngineArgs.spec_decoding_acceptance_method, choices=['rejection_sampler', 'typical_acceptance_sampler'], From ef93081ec6427881b9f3d017906e731543e87eda Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 15:52:19 +0000 Subject: [PATCH 29/38] Fixing a test --- tests/spec_decode/e2e/test_multistep_correctness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 14432cbe37afd..6f0d66f9ce5e0 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -640,9 +640,9 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, "spec_decoding_acceptance_method": "typical_acceptance_sampler" } # Try a range of common k, as well as large speculation. - for k in [1, 2, 63] + for k in [1, 2, 3] ]) -@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize( "output_len", [ From 01658428fbdcc64bd1756984a3e5777be19156da Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 16:51:19 +0000 Subject: [PATCH 30/38] Fix comment --- tests/spec_decode/e2e/test_multistep_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 6f0d66f9ce5e0..c88b81ee3c836 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -654,7 +654,7 @@ def test_typical_acceptance_sampling( baseline_llm_generator, test_llm_generator, batch_size: int, output_len: int): """Verify that speculative decoding produces exact equality to without spec - decode with many TypicalAcceptanceSampler as the draft token acceptance + decode with TypicalAcceptanceSampler as the draft token acceptance sampling method. 
""" run_greedy_equality_correctness_test(baseline_llm_generator, From 510974bf444da861620c43bd3c0ebd4dbb5e3f33 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 19:33:41 +0000 Subject: [PATCH 31/38] Fix formatting --- .../test_typical_acceptance_sampler.py | 83 +++++++++-------- .../e2e/test_multistep_correctness.py | 9 +- .../spec_decode/e2e/test_ngram_correctness.py | 7 +- tests/spec_decode/test_dynamic_spec_decode.py | 9 +- tests/spec_decode/test_metrics.py | 4 +- tests/spec_decode/test_spec_decode_worker.py | 93 +++++++++---------- tests/spec_decode/test_utils.py | 16 ++-- vllm/config.py | 32 +++++-- vllm/engine/arg_utils.py | 12 ++- .../layers/rejection_sampler.py | 21 ++--- .../layers/spec_decode_base_sampler.py | 3 +- .../layers/typical_acceptance_sampler.py | 9 +- vllm/spec_decode/metrics.py | 7 +- vllm/spec_decode/spec_decode_worker.py | 35 ++++--- 14 files changed, 181 insertions(+), 159 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index dee74271788cd..4f6290795b2ce 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -51,18 +51,18 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, break return draft_token_ids + def get_acceptance_sampler( - posterior_threshold: float = 0.03, - posterior_alpha: float = 0.9, - disable_bonus_tokens: bool = False, - strict_mode: bool = False, + posterior_threshold: float = 0.03, + posterior_alpha: float = 0.9, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, ) -> TypicalAcceptanceSampler: """ Initializes and returns a TypicalAcceptanceSampler. """ - return TypicalAcceptanceSampler( - posterior_threshold, posterior_alpha, disable_bonus_tokens, strict_mode) - + return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha, + disable_bonus_tokens, strict_mode) @pytest.mark.parametrize("k", list(range(1, 6))) @@ -89,7 +89,8 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_probs, bonus_token_ids, + typical_acceptance_sampler(target_probs, + bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -140,7 +141,8 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, bonus_token_ids, + typical_acceptance_sampler(target_probs, + bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -179,10 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) # We are using a uniform target probability distribution. # For a uniform distribution the entropy is very high and it # should lead to all draft tokens being accepted. Verify that. @@ -241,10 +244,11 @@ def test_temperature_zero_target_distribution(seed: int, # 1.0 tokens in the target distribution we will reject all of them and # fallback to the greedy sampling for selecting 1 token for each sequence. # Verify the same. 
- output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, -1] == -1) @@ -295,10 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) # verify the shape of output_token_ids assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) @@ -358,10 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) @@ -377,10 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) @@ -423,10 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 1:-1] == -1) @@ -440,10 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, posterior_threshold=0.0, posterior_alpha=0.0) typical_acceptance_sampler.init_gpu_tensors(rank=0) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 
c88b81ee3c836..ef8d737ae8418 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -17,7 +17,7 @@ For temp>0, we rely on unit tests on the rejection sampler to verify that the output distribution is the same with spec decode vs. no spec decode (this would -be prohibitively expensive to run with a real model). Similary, for the +be prohibitively expensive to run with a real model). Similarly, for the TypicalAcceptance sampler also, we rely on unit tests to validate temp>0 test cases. @@ -618,6 +618,7 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, max_output_len=output_len, force_output_len=True) + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -650,9 +651,9 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, 32, ]) @pytest.mark.parametrize("seed", [1]) -def test_typical_acceptance_sampling( - baseline_llm_generator, test_llm_generator, batch_size: int, - output_len: int): +def test_typical_acceptance_sampling(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify that speculative decoding produces exact equality to without spec decode with TypicalAcceptanceSampler as the draft token acceptance sampling method. diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 2c98639334992..43aef11e053b9 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -212,6 +212,7 @@ def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, max_output_len=output_len, force_output_len=True) + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -245,9 +246,9 @@ def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, ]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("batch_size", [1, 32]) -def test_ngram_typical_acceptance_sampling( - baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): +def test_ngram_typical_acceptance_sampling(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): """Verify that ngram speculative decoding produces exact equality to without spec decode with many different values of k, batch_size and using TypicalAcceptanceSampler as the draft token acceptance method. diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 73790ed2db6b4..8ded6ab702f21 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -10,16 +10,17 @@ from vllm.spec_decode.top1_proposer import Top1Proposer from .utils import create_batch, mock_worker -from .test_utils import mock_spec_decode_sampler + @pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('k', [1]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_disable_spec_tokens( - queue_size: int, batch_size: int, k: int, mock_spec_decode_sampler): +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, + mock_spec_decode_sampler): """Verify that speculative tokens are disabled when the batch size exceeds the threshold. 
""" diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index c1d5aa60d8d1b..2918fabddc900 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -127,8 +127,8 @@ def test_initial_metrics_has_correct_values(has_data: bool): dtype=torch.long, device='cuda') spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, - dtype=torch.long, - device='cuda') + dtype=torch.long, + device='cuda') spec_decode_sampler.num_draft_tokens = num_draft_tokens collect_interval_s = 5.0 diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 5ba085529d67d..c6da98c553295 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,8 +6,6 @@ import pytest import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler from vllm.model_executor.utils import set_random_seed from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput from vllm.spec_decode.interfaces import SpeculativeProposals @@ -18,15 +16,16 @@ split_num_cache_blocks_evenly) from .utils import create_batch, create_sampler_output_list, mock_worker -from .test_utils import mock_spec_decode_sampler + @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_correctly_calls_draft_model( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_correctly_calls_draft_model(k: int, batch_size: int, + mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the draft worker with correct inputs. Everything else is mocked out. """ @@ -56,10 +55,11 @@ def test_correctly_calls_draft_model( @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_correctly_calls_target_model( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_correctly_calls_target_model(k: int, batch_size: int, + mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. 
""" @@ -73,8 +73,7 @@ def test_correctly_calls_target_model( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, - metrics_collector) + mock_spec_decode_sampler, metrics_collector) worker.init_device() vocab_size = 32_000 @@ -138,10 +137,11 @@ def test_correctly_calls_target_model( @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_correctly_calls_spec_decode_sampler( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, + mock_spec_decode_sampler): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. """ @@ -158,8 +158,8 @@ def test_correctly_calls_spec_decode_sampler( set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, + metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -229,10 +229,11 @@ def test_correctly_calls_spec_decode_sampler( @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_correctly_formats_output( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_correctly_formats_output(k: int, batch_size: int, + mock_spec_decode_sampler): """Verify SpecDecodeWorker formats sampler output correctly. Everything else is mocked out. """ @@ -248,8 +249,7 @@ def test_correctly_formats_output( set_random_seed(1) spec_decode_sampler = mock_spec_decode_sampler - worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -304,7 +304,7 @@ def test_correctly_formats_output( minimum_accepted_tokens = 1 spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - + spec_decode_sampler.return_value = spec_decode_sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, @@ -359,11 +359,11 @@ def test_correctly_formats_output( @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('returns_metrics', [True, False]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_collects_metrics( - k: int, batch_size: int, returns_metrics: bool, - mock_spec_decode_sampler): +def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, + mock_spec_decode_sampler): """Verify SpecDecodeWorker collects metrics. 
""" vocab_size = 32_000 @@ -379,8 +379,7 @@ def test_collects_metrics( set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -457,10 +456,10 @@ def test_collects_metrics( @pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_k_equals_zero( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_k_equals_zero(k: int, batch_size: int, mock_spec_decode_sampler): """Verify that the SpecDecodeWorker calls the draft and target workers when k is zero. This happens during prefill. """ @@ -478,8 +477,7 @@ def test_k_equals_zero( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, - metrics_collector) + mock_spec_decode_sampler, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -501,10 +499,10 @@ def test_k_equals_zero( @pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('batch_size', [0]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() -def test_empty_input_batch( - k: int, batch_size: int, mock_spec_decode_sampler): +def test_empty_input_batch(k: int, batch_size: int, mock_spec_decode_sampler): """Verify that the SpecDecodeWorker calls the draft and target workers when the input batch is empty. This can happen if the engine communicates to the workers information without scheduling a batch. 
@@ -523,8 +521,7 @@ def test_empty_input_batch( set_random_seed(1) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, - metrics_collector) + mock_spec_decode_sampler, metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -542,8 +539,10 @@ def test_empty_input_batch( draft_worker.execute_model.assert_called_once_with(execute_model_req) target_worker.execute_model.assert_called_once_with(execute_model_req) + @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @pytest.mark.skip_global_cleanup def test_init_device(mock_spec_decode_sampler): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as @@ -554,8 +553,7 @@ def test_init_device(mock_spec_decode_sampler): spec_decode_sampler = mock_spec_decode_sampler metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, - spec_decode_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -567,8 +565,10 @@ def test_init_device(mock_spec_decode_sampler): metrics_collector.init_gpu_tensors.assert_called_once() spec_decode_sampler.init_gpu_tensors.assert_called_once() + @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @torch.inference_mode() def test_initialize_cache(mock_spec_decode_sampler): """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer @@ -579,8 +579,7 @@ def test_initialize_cache(mock_spec_decode_sampler): metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, - metrics_collector) + mock_spec_decode_sampler, metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -594,7 +593,8 @@ def test_initialize_cache(mock_spec_decode_sampler): @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], indirect=True) + ["rejection_sampler", "typical_acceptance_sampler"], + indirect=True) @pytest.mark.skip_global_cleanup def test_determine_num_available_blocks(available_gpu_blocks: int, available_cpu_blocks: int, @@ -616,8 +616,7 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, - metrics_collector) + mock_spec_decode_sampler, metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index c43fc4f3e76d3..9d56875121b82 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -1,12 +1,13 @@ from unittest.mock import MagicMock import pytest - import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids from 
vllm.spec_decode.util import split_batch_by_proposal_len -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler def test_get_all_seq_ids(): @@ -113,16 +114,18 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): assert filtered_groups == [] assert indices == [] + @pytest.fixture def mock_spec_decode_sampler(request): """ Returns either a RejectionSampler or TypicalAcceptanceSampler - object depending on wether value is 'rejection_sampler' or + object depending on whether value is 'rejection_sampler' or 'typical_acceptance_sampler' respectively. """ + def create_samplers(value): if value == "rejection_sampler": - sampler = MagicMock(spec=RejectionSampler) + sampler = MagicMock(spec=RejectionSampler) sampler.token_id_dtype = torch.int64 return sampler elif value == "typical_acceptance_sampler": @@ -131,7 +134,6 @@ def create_samplers(value): return sampler else: raise ValueError(f"Invalid sampler name {value}") - + value = request.param # Get the value passed to the fixture return create_samplers(value) - diff --git a/vllm/config.py b/vllm/config.py index 091e65eb88557..58dd3578e2403 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -805,7 +805,7 @@ def maybe_create_spec_config( ngram_prompt_lookup_min: Optional[int], draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: Optional[float], - typical_acceptance_sampler_posterior_alpha: Optional[float], + typical_acceptance_sampler_posterior_alpha: Optional[float], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -961,7 +961,7 @@ def maybe_create_spec_config( "num_speculative_tokens must be provided with " "speculative_model unless the draft model config contains an " "n_predict parameter.") - + if typical_acceptance_sampler_posterior_threshold is None: typical_acceptance_sampler_posterior_threshold = 0.09 if typical_acceptance_sampler_posterior_alpha is None: @@ -978,7 +978,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ - typical_acceptance_sampler_posterior_alpha, + typical_acceptance_sampler_posterior_alpha, ) @staticmethod @@ -1052,7 +1052,7 @@ def __init__( ngram_prompt_lookup_min: Optional[int], draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, - typical_acceptance_sampler_posterior_alpha: float, + typical_acceptance_sampler_posterior_alpha: float, ): """Create a SpeculativeConfig object. @@ -1103,19 +1103,31 @@ def _verify_args(self) -> None: if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( self.draft_parallel_config) - # Validate and set draft token acceptance related settings. + # Validate and set draft token acceptance related settings. if (self.draft_token_acceptance_method is None): raise ValueError("draft_token_acceptance_method is not set. " - "Expected values are rejection_sampler or " + "Expected values are rejection_sampler or " "typical_acceptance_sampler.") if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method != 'typical_acceptance_sampler'): - raise ValueError("Expected draft_token_acceptance_method to be either " - "rejection_sampler or typical_acceptance_sampler. 
Instead it " - f"is {self.draft_token_acceptance_method}") + and self.draft_token_acceptance_method != + 'typical_acceptance_sampler'): + raise ValueError( + "Expected draft_token_acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.draft_token_acceptance_method}") + if (self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0): + raise ValueError( + "Expected typical_acceptance_sampler_posterior_threshold " + "and typical_acceptance_sampler_posterior_alpha to be > 0. " + "Instead found " + f"typical_acceptance_sampler_posterior_threshold = " + f"{self.typical_acceptance_sampler_posterior_threshold} and " + f"typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_alpha}") @property def num_lookahead_slots(self) -> int: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4f370ba086c8c..3efd252fcb049 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -579,8 +579,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'routines are supported: ' '1) RejectionSampler which does not allow changing the ' 'acceptance rate of draft tokens, ' - '2) TypicalAcceptanceSampler which is configurable, allowing for a higher ' - 'acceptance rate at the cost of lower quality, and vice versa.') + '2) TypicalAcceptanceSampler which is configurable, allowing for ' + 'a higher acceptance rate at the cost of lower quality, ' + 'and vice versa.') parser.add_argument( '--typical-acceptance-sampler-posterior-threshold', @@ -596,8 +597,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=float, default=EngineArgs.typical_acceptance_sampler_posterior_alpha, help='A scaling factor for the entropy-based threshold for token ' - 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' - 'to sqrt of --typical-acceptance-sampler-posterior-threshold i.e 0.3') + 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' + 'to sqrt of --typical-acceptance-sampler-posterior-threshold ' + 'i.e. 0.3') parser.add_argument('--model-loader-extra-config', type=nullable_str, @@ -731,7 +733,7 @@ def create_engine_config(self, ) -> EngineConfig: typical_acceptance_sampler_posterior_threshold=self. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=self. - typical_acceptance_sampler_posterior_alpha, + typical_acceptance_sampler_posterior_alpha, ) scheduler_config = SchedulerConfig( diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index f6b0e61e20f32..e189610461a70 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -3,13 +3,11 @@ import torch import torch.jit -import time -import torch.nn as nn -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) + class RejectionSampler(SpecDecodeBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" @@ -29,9 +27,8 @@ def __init__(self, during sampling. This catches correctness issues but adds nontrivial latency. 
""" - super().__init__( - disable_bonus_tokens=disable_bonus_tokens, - strict_mode=strict_mode) + super().__init__(disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) def forward( self, @@ -80,11 +77,12 @@ def forward( self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - accepted, recovered_token_ids = self._batch_modified_rejection_sampling( - target_probs, - draft_probs, - draft_token_ids, - ) + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + )) output_token_ids = self._create_output( accepted, @@ -242,6 +240,7 @@ def _smallest_positive_value(self) -> float: """ return torch.finfo(self.probs_dtype).tiny + # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead that skips the sync. # Note that we always sample with replacement. diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index fa57125f3e514..692024056495c 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,5 +1,5 @@ +from abc import abstractmethod from typing import Optional -from abc import ABC, abstractmethod import torch import torch.jit @@ -64,7 +64,6 @@ def forward( ) -> torch.Tensor: raise NotImplementedError - def _create_output( self, accepted: torch.Tensor, # [batch_size, k] diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index dc44ae523463d..9bf3c84a161c5 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -1,7 +1,5 @@ import torch import torch.jit -import torch.nn as nn -import time from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) @@ -13,6 +11,7 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler): Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774 """ + def __init__( self, posterior_threshold: float, @@ -37,10 +36,8 @@ def __init__( """ self._posterior_threshold = posterior_threshold self._posterior_alpha = posterior_alpha - super().__init__( - disable_bonus_tokens=disable_bonus_tokens, - strict_mode=strict_mode) - + super().__init__(disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) def forward( self, diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 67100ebf1a7de..2c4ae0b22744b 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -4,7 +4,8 @@ import torch -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) from vllm.utils import is_pin_memory_available @@ -41,6 +42,7 @@ class SpecDecodeWorkerMetrics: # The number of speculative tokens per sequence. 
num_spec_tokens: int + Timer = Callable[[], float] @@ -116,7 +118,8 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: with torch.cuda.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( - self.spec_decode_sampler.num_accepted_tokens, non_blocking=True) + self.spec_decode_sampler.num_accepted_tokens, + non_blocking=True) self._aggregate_num_emitted_tokens.copy_( self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) # Number of draft tokens is calculated on CPU, so no copy is diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 16d10721ef326..e66d68fcc918a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -6,9 +6,11 @@ from vllm.config import SpeculativeConfig from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceGroupMetadata, get_all_seq_ids) @@ -26,7 +28,6 @@ split_batch_by_proposal_len) from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase -import time logger = init_logger(__name__) @@ -62,8 +63,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold=speculative_config. typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. 
- typical_acceptance_sampler_posterior_alpha - ) + typical_acceptance_sampler_posterior_alpha) return spec_decode_worker @@ -101,8 +101,8 @@ def create_worker( draft_worker_kwargs: Dict[str, Any], disable_by_batch_size: Optional[int], draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: float, - typical_acceptance_sampler_posterior_alpha: float, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -125,7 +125,7 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) - + spec_decode_sampler: SpecDecodeBaseSampler = None if draft_token_acceptance_method == "rejection_sampler": spec_decode_sampler = RejectionSampler( @@ -136,17 +136,14 @@ def create_worker( posterior_threshold=\ typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, - ) + ) logger.info("Configuring SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) - - return SpecDecodeWorker( - proposer_worker, - scorer_worker, - disable_by_batch_size=disable_by_batch_size, - spec_decode_sampler=spec_decode_sampler) - + return SpecDecodeWorker(proposer_worker, + scorer_worker, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler) def __init__( self, @@ -206,7 +203,7 @@ def init_device(self) -> None: self._metrics.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank) - + self.scorer = BatchExpansionTop1Scorer( scorer_worker=self.scorer_worker, device=self.device, @@ -216,7 +213,7 @@ def init_device(self) -> None: def load_model(self, *args, **kwargs): pass - + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. 
This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -532,7 +529,7 @@ def _verify_tokens( hidden_states) return accepted_token_ids, logprobs - + def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], From 396fa547aa3b9328ea836976b49febdf4b066a37 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Wed, 26 Jun 2024 23:05:49 +0000 Subject: [PATCH 32/38] Fixing tests and lint failures --- tests/spec_decode/test_dynamic_spec_decode.py | 11 +- tests/spec_decode/test_spec_decode_worker.py | 111 +++++++++--------- tests/spec_decode/test_utils.py | 32 ++--- 3 files changed, 74 insertions(+), 80 deletions(-) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 8ded6ab702f21..29ed96999cb4c 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -9,18 +9,18 @@ from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.top1_proposer import Top1Proposer +from .test_utils import mock_spec_decode_sampler from .utils import create_batch, mock_worker @pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('k', [1]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify that speculative tokens are disabled when the batch size exceeds the threshold. """ @@ -30,7 +30,8 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - spec_decode_sampler=mock_spec_decode_sampler, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index c6da98c553295..527e7eddd7e33 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -15,25 +15,26 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) +from .test_utils import mock_spec_decode_sampler from .utils import create_batch, create_sampler_output_list, mock_worker @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_correctly_calls_draft_model(k: int, batch_size: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the draft worker with correct inputs. Everything else is mocked out. 
""" draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -54,12 +55,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_correctly_calls_target_model(k: int, batch_size: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ @@ -72,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) worker.init_device() vocab_size = 32_000 @@ -136,12 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. """ @@ -151,7 +151,7 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -228,12 +228,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_correctly_formats_output(k: int, batch_size: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker formats sampler output correctly. Everything else is mocked out. 
""" @@ -248,7 +247,7 @@ def test_correctly_formats_output(k: int, batch_size: int, target_worker.device = 'cuda' set_random_seed(1) - spec_decode_sampler = mock_spec_decode_sampler + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -358,12 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int, @pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('returns_metrics', [True, False]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker collects metrics. """ vocab_size = 32_000 @@ -372,7 +370,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -455,11 +453,11 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, @pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_k_equals_zero(k: int, batch_size: int, mock_spec_decode_sampler): +def test_k_equals_zero(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify that the SpecDecodeWorker calls the draft and target workers when k is zero. This happens during prefill. """ @@ -476,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int, mock_spec_decode_sampler): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -498,11 +497,11 @@ def test_k_equals_zero(k: int, batch_size: int, mock_spec_decode_sampler): @pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('batch_size', [0]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_empty_input_batch(k: int, batch_size: int, mock_spec_decode_sampler): +def test_empty_input_batch(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify that the SpecDecodeWorker calls the draft and target workers when the input batch is empty. This can happen if the engine communicates to the workers information without scheduling a batch. 
@@ -520,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int, mock_spec_decode_sampler): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -540,17 +540,16 @@ def test_empty_input_batch(k: int, batch_size: int, mock_spec_decode_sampler): target_worker.execute_model.assert_called_once_with(execute_model_req) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @pytest.mark.skip_global_cleanup -def test_init_device(mock_spec_decode_sampler): +def test_init_device(acceptance_sampler_method: str): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, @@ -566,11 +565,10 @@ def test_init_device(mock_spec_decode_sampler): spec_decode_sampler.init_gpu_tensors.assert_called_once() -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_initialize_cache(mock_spec_decode_sampler): +def test_initialize_cache(acceptance_sampler_method): """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. """ @@ -578,8 +576,9 @@ def test_initialize_cache(mock_spec_decode_sampler): target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -592,15 +591,14 @@ def test_initialize_cache(mock_spec_decode_sampler): @pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@pytest.mark.parametrize("mock_spec_decode_sampler", - ["rejection_sampler", "typical_acceptance_sampler"], - indirect=True) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @pytest.mark.skip_global_cleanup def test_determine_num_available_blocks(available_gpu_blocks: int, available_cpu_blocks: int, target_cache_block_size_bytes: int, draft_kv_size_bytes: int, - mock_spec_decode_sampler): + acceptance_sampler_method: str): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. 
@@ -615,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, target_cache_block_size_bytes) draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler, metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 9d56875121b82..18dbdd5bc952f 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -115,25 +115,19 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): assert indices == [] -@pytest.fixture -def mock_spec_decode_sampler(request): +def mock_spec_decode_sampler(acceptance_sampler_method): """ Returns either a RejectionSampler or TypicalAcceptanceSampler - object depending on whether value is 'rejection_sampler' or - 'typical_acceptance_sampler' respectively. + object depending on whether acceptance_sampler_method is + 'rejection_sampler' or 'typical_acceptance_sampler' respectively. """ - - def create_samplers(value): - if value == "rejection_sampler": - sampler = MagicMock(spec=RejectionSampler) - sampler.token_id_dtype = torch.int64 - return sampler - elif value == "typical_acceptance_sampler": - sampler = MagicMock(spec=TypicalAcceptanceSampler) - sampler.token_id_dtype = torch.int64 - return sampler - else: - raise ValueError(f"Invalid sampler name {value}") - - value = request.param # Get the value passed to the fixture - return create_samplers(value) + if acceptance_sampler_method == "rejection_sampler": + sampler = MagicMock(spec=RejectionSampler) + sampler.token_id_dtype = torch.int64 + return sampler + elif acceptance_sampler_method == "typical_acceptance_sampler": + sampler = MagicMock(spec=TypicalAcceptanceSampler) + sampler.token_id_dtype = torch.int64 + return sampler + else: + raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") From f8cc895d39174da6af55a10ce9474afac5825ea8 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 27 Jun 2024 18:17:10 +0000 Subject: [PATCH 33/38] Removing e2e test for TypicalAcceptanceSampler from test_ngram_correctness.py --- .../spec_decode/e2e/test_ngram_correctness.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 43aef11e053b9..d475d37af6425 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -211,50 +211,3 @@ def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_model": "[ngram]", - "num_speculative_tokens": k, - "ngram_prompt_lookup_max": 3, - "spec_decoding_acceptance_method": "typical_acceptance_sampler" - } - # Try a range of common k, as well as large speculation. 
- for k in [1, 3, 5] - ]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("batch_size", [1, 32]) -def test_ngram_typical_acceptance_sampling(baseline_llm_generator, - test_llm_generator, batch_size: int, - output_len: int): - """Verify that ngram speculative decoding produces exact equality - to without spec decode with many different values of k, batch_size and - using TypicalAcceptanceSampler as the draft token acceptance method. - """ - run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True) From 439117d351265627f65cb1ff2ef0f30bbdb59b94 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 27 Jun 2024 18:40:12 +0000 Subject: [PATCH 34/38] Fix a comment --- tests/spec_decode/e2e/test_multistep_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index ef8d737ae8418..94cc36f22875a 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -640,7 +640,7 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, "num_speculative_tokens": k, "spec_decoding_acceptance_method": "typical_acceptance_sampler" } - # Try a range of common k, as well as large speculation. + # Try a range of common k. for k in [1, 2, 3] ]) @pytest.mark.parametrize("batch_size", [1, 32]) From 75f034f609176f071cb2de960f2b4648a39e207e Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 27 Jun 2024 23:11:20 +0000 Subject: [PATCH 35/38] Dummy commit --- tests/spec_decode/e2e/test_multistep_correctness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94cc36f22875a..7641375841564 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -663,3 +663,4 @@ def test_typical_acceptance_sampling(baseline_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) + From 3082255b671fe6bc892205e57af6977e2a1256fa Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Thu, 27 Jun 2024 23:17:53 +0000 Subject: [PATCH 36/38] Fix format error --- tests/spec_decode/e2e/test_multistep_correctness.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 7641375841564..94cc36f22875a 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -663,4 +663,3 @@ def test_typical_acceptance_sampling(baseline_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) - From d26c624371924503771ddb072086ce50eb6cdb20 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sat, 29 Jun 2024 02:29:47 +0000 Subject: [PATCH 37/38] Dummy fix --- tests/spec_decode/e2e/test_multistep_correctness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94cc36f22875a..7641375841564 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -663,3 +663,4 @@ def 
test_typical_acceptance_sampling(baseline_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) + From f186844885ae73211202d2f23defe0adbcbea771 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Fri, 28 Jun 2024 20:34:53 -0700 Subject: [PATCH 38/38] Update test_multistep_correctness.py --- tests/spec_decode/e2e/test_multistep_correctness.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 7641375841564..94cc36f22875a 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -663,4 +663,3 @@ def test_typical_acceptance_sampling(baseline_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) -
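
For reference, the entropy-scaled acceptance rule that the new --typical-acceptance-sampler-posterior-threshold and --typical-acceptance-sampler-posterior-alpha settings tune (per the Medusa paper the sampler cites, https://arxiv.org/pdf/2401.10774) can be sketched as follows. This is a minimal illustration assuming the Medusa-style criterion; the tensor names, the standalone function, and the 1e-5 log stabilizer are illustrative only and do not claim to mirror the exact TypicalAcceptanceSampler code in this series.

    import torch

    def typical_acceptance_mask(target_probs: torch.Tensor,
                                draft_token_ids: torch.Tensor,
                                posterior_threshold: float = 0.09,
                                posterior_alpha: float = 0.3) -> torch.Tensor:
        """Return a [batch_size, k] bool mask of accepted draft tokens.

        target_probs: [batch_size, k, vocab_size] target-model posteriors.
        draft_token_ids: [batch_size, k] tokens proposed by the draft model.
        Defaults follow the config defaults added in this series
        (0.09, and alpha roughly sqrt(0.09) i.e. 0.3).
        """
        # Probability the target model assigns to each proposed draft token.
        candidate_probs = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # Entropy of the target posterior at each position; the small
        # constant keeps the log finite when a probability is exactly zero.
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + 1e-5), dim=-1)
        # Accept when the token's probability clears an entropy-scaled
        # threshold, capped by posterior_threshold: confident (low-entropy)
        # posteriors demand a higher bar than uncertain ones.
        threshold = torch.minimum(
            torch.full_like(posterior_entropy, posterior_threshold),
            posterior_alpha * torch.exp(-posterior_entropy))
        return candidate_probs > threshold

    if __name__ == "__main__":
        batch_size, k, vocab_size = 2, 3, 8
        target_probs = torch.softmax(
            torch.randn(batch_size, k, vocab_size), dim=-1)
        draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))
        print(typical_acceptance_mask(target_probs, draft_token_ids))

In this sketch, lowering posterior_threshold and posterior_alpha lowers the bar a candidate probability must clear, so more draft tokens are accepted; raising them does the opposite. That is the acceptance-rate versus output-quality trade-off described in the --spec-decoding-acceptance-method help text above, in contrast to RejectionSampler, whose acceptance rate is not configurable.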