diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py
index de305c4030aa9..88b40d1eb4674 100644
--- a/tests/spec_decode/test_ngram_worker.py
+++ b/tests/spec_decode/test_ngram_worker.py
@@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
@@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find no candidate
@@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
 
     assert proposals.proposal_lens.shape == torch.Size([5])
+    # the first sequence has no match so proposal_len should be overwritten to 0
     assert proposals.proposal_lens.tolist(
-    ) == [proposal_len for _ in range(4)] + [0]
+    ) == [0] + [proposal_len for _ in range(3)] + [0]
 
     for i in range(proposal_len):
-        assert proposals.proposal_token_ids[0][i] == 0
+        assert proposals.proposal_token_ids[0][i] == -1
         assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
         assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
         assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
@@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
         max_proposal_len=20,
     )
 
-    # set ngram window (0, 3], which is window=1/2/3
-    ngram_worker.set_ngram_window_size(0, 3)
+    # set ngram window [1, 3], which is window=1/2/3
+    ngram_worker.set_ngram_window_size(1, 3)
 
     prompts = [
         # shall find candidate 12,13,14,15,16
diff --git a/vllm/config.py b/vllm/config.py
index fab9cfbf41a2d..435f47dc9459a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -784,12 +784,15 @@ def maybe_create_spec_config(
         draft_quantization = None
 
         if speculative_model == "[ngram]":
-            assert (ngram_prompt_lookup_max is not None
-                    and ngram_prompt_lookup_max > 0)
             if ngram_prompt_lookup_min is None:
-                ngram_prompt_lookup_min = 0
-            else:
-                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
+                ngram_prompt_lookup_min = 1
+            if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
+                raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
+            if ngram_prompt_lookup_min < 1:
+                raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
+            if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
+                raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
+                                 f"larger than {ngram_prompt_lookup_max=}")
 
             # TODO: current we still need extract vocab_size from target model
             # config, in future, we may try refactor it out, and set
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index 6cd50fcc1a041..9628f7af5315a 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -77,9 +77,11 @@ def sampler_output(
         """
        self._raise_if_unsupported(execute_model_req)
 
-        arr = []
         has_spec_out = False
-        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+        token_id_list = []
+        token_prob_list = []
+        for idx, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -89,59 +91,64 @@ def sampler_output(
 
             for ngram_size in range(
                     min(self.ngram_prompt_lookup_max, input_length - 1),
-                    self.ngram_prompt_lookup_min,
+                    self.ngram_prompt_lookup_min - 1,
                     -1,
             ):
-                ngram_tensor = input_ids[-1 * ngram_size:]
-                windows = input_ids.unfold(dimension=0,
-                                           size=ngram_size,
-                                           step=1)
-                matches = (windows == ngram_tensor).all(dim=1)
-                match_indices = matches.nonzero(as_tuple=True)[0]
-                if match_indices.size()[0] > 1:
+                ngram_tensor = input_ids[-ngram_size:]
+                proposal_start_idx = None
+                if ngram_size == 1:
+                    # Do not match itself and do not use unfold and all
+                    matches = (input_ids[:-1] == ngram_tensor)
+                else:
+                    windows = input_ids.unfold(dimension=0,
+                                               size=ngram_size,
+                                               step=1)
+                    # Do not match itself
+                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
+
+                # first_match includes "values" (bool), indicating whether
+                # the match is found, and "indices", indicating the index
+                # of the first match.
+                # Note that "first_match.values.item()" triggers GPU-CPU
+                # sync so it is a bit inefficient, but we have not found
+                # a better way to do this.
+                first_match = matches.max(dim=-1)
+                if first_match.values.item():
+                    proposal_start_idx = first_match.indices.add_(ngram_size)
+                    spec_indices = (
+                        proposal_start_idx).repeat(sample_len) + torch.arange(
+                            sample_len, device=self.device)
+                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
+                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    token_id_list.append(res)
+                    token_prob_list.append(
+                        torch.nn.functional.one_hot(
+                            res,
+                            num_classes=self.vocab_size).to(torch.float32))
                     has_spec_out = True
-                    res = seq_data.get_token_ids()
-                    res = res[match_indices[0] + ngram_size:match_indices[0] +
-                              ngram_size + sample_len]
-                    res_len = len(res)
-                    # pad 0 towards output as sample_len tokens required
-                    res += [0] * (sample_len - res_len)
-
                     break
             else:
-                # if no candidate found, fill with 0
-                res = [0] * sample_len
-
-            arr.append(res)
+                token_id_list.append(None)
+                token_prob_list.append(None)
 
         if not has_spec_out:
             return None, False
 
-        outputs = []
-        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
-        indices = token_ids.unsqueeze(2)
+        outputs: List[Optional[SamplerOutput]] = []
+        for idx in range(len(execute_model_req.seq_group_metadata_list)):
+            if token_id_list[idx] is None:
+                outputs.append(None)
+            else:
+                outputs.append(
+                    SamplerOutput(
+                        outputs=None,
+                        sampled_token_probs=token_prob_list[idx],
+                        logprobs=torch.zeros((sample_len, self.vocab_size),
+                                             dtype=torch.float32,
+                                             device=self.device),
+                        sampled_token_ids=token_id_list[idx],
+                    ))
 
-        token_probs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        token_probs.scatter_(2, indices, 1)
-        token_logprobs = torch.zeros(
-            (len(execute_model_req.seq_group_metadata_list), sample_len,
-             self.vocab_size),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        for i in range(len(execute_model_req.seq_group_metadata_list)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_probs[i],
-                    logprobs=token_logprobs[i],
-                    sampled_token_ids=token_ids[i],
-                ))
         return outputs, False
 
     def get_spec_proposals(
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index ee9462b68dae8..6c7e22207f6b2 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -73,6 +73,14 @@ def get_proposals(
                 execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
+            (
+                proposal_lens,
+                maybe_sampler_output,
+                nonzero_proposal_len_indices,
+            ) = self._remove_no_proposal_seqs(proposal_lens,
+                                              maybe_sampler_output,
+                                              nonzero_proposal_len_indices,
+                                              transposed)
         else:
             # If no sequences can be speculated, set sampler output to None.
             maybe_sampler_output = None
@@ -140,6 +148,61 @@ def _split_by_proposal_len(
             nonzero_proposal_len_indices,
         )
 
+    def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
+                                 nonzero_proposal_len_indices, transposed):
+        """Remove sequences from nonzero_proposal_len_indices and reset
+        their proposal_len to 0 when the draft worker does not provide a
+        proposal (maybe_sampler_output=None). This can avoid scoring overheads.
+        """
+
+        # If maybe_sampler_output is None, then the draft worker did not
+        # provide a proposal for any sequence and thus no action is needed.
+        # Also, we do not support transposed maybe_sampler_output for now
+        # because it is not straightforward for draft workers that output
+        # transposed sampler outputs to handle the case of no proposal.
+        if maybe_sampler_output is None or transposed:
+            return (proposal_lens, maybe_sampler_output,
+                    nonzero_proposal_len_indices)
+
+        new_proposal_lens: List[int] = []
+        new_nonzero_proposal_len_indices: List[int] = []
+        new_maybe_sampler_output: List[SamplerOutput] = []
+        nonzero_proposal_len_idx_ptr = 0
+        seq_idx = 0
+        while seq_idx < len(
+                proposal_lens) and nonzero_proposal_len_idx_ptr < len(
+                    nonzero_proposal_len_indices):
+            if seq_idx < nonzero_proposal_len_indices[
+                    nonzero_proposal_len_idx_ptr]:
+                # Sequence is not in the original nonzero_proposal_len_indices,
+                # meaning that it has a proposal length of 0 before sending to
+                # the draft worker.
+                assert proposal_lens[seq_idx] == 0
+                new_proposal_lens.append(0)
+            else:
+                # Sequence is in the original nonzero_proposal_len_indices
+                if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
+                    # but does not have a proposal from the draft worker.
+                    new_proposal_lens.append(0)
+                else:
+                    # and has a proposal from the draft worker. Add it to the
+                    # new nonzero proposal list and keep the sampler output.
+                    new_proposal_lens.append(proposal_lens[seq_idx])
+                    new_nonzero_proposal_len_indices.append(seq_idx)
+                    new_maybe_sampler_output.append(
+                        maybe_sampler_output[nonzero_proposal_len_idx_ptr])
+                nonzero_proposal_len_idx_ptr += 1
+            seq_idx += 1
+
+        # The remaining sequences should have proposal length of 0.
+        new_proposal_lens.extend(proposal_lens[seq_idx:])
+
+        # We assume sampler_output will not be a list of all Nones.
+        # If that were the case, this function should not have been called.
+        assert new_maybe_sampler_output
+        return (new_proposal_lens, new_maybe_sampler_output,
+                new_nonzero_proposal_len_indices)
+
     def _merge_outputs(
         self,
         batch_size: int,
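# --- Illustrative sketch, not part of the patch -----------------------------
# A minimal, self-contained rendering of the per-sequence lookup that the
# reworked sampler_output() performs: find an earlier occurrence of the last
# `ngram_size` tokens (excluding the trailing n-gram itself) and speculate the
# `sample_len` tokens that follow it. The function name `propose_from_ngram`
# is hypothetical; only plain PyTorch ops from the diff above are used.
import torch


def propose_from_ngram(input_ids: torch.Tensor, ngram_size: int,
                       sample_len: int):
    if input_ids.numel() <= ngram_size:
        return None
    ngram_tensor = input_ids[-ngram_size:]
    if ngram_size == 1:
        # Element-wise compare; drop the last token so the query n-gram
        # does not match itself.
        matches = input_ids[:-1] == ngram_tensor
    else:
        # Sliding windows of length ngram_size; again exclude the final
        # window, which is the query n-gram itself.
        windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
        matches = (windows[:-1] == ngram_tensor).all(dim=-1)

    # max() over the bool tensor: `values` is True iff any position matches,
    # `indices` gives the position of the first match (as the patch relies on).
    first_match = matches.max(dim=-1)
    if not first_match.values.item():
        return None  # no match -> no proposal for this sequence

    start = first_match.indices + ngram_size
    spec_indices = start + torch.arange(sample_len)
    # Clamp so a match near the end of the prompt repeats the last token
    # instead of indexing out of bounds.
    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
    return input_ids.gather(dim=-1, index=spec_indices)


tokens = torch.tensor([1, 2, 3, 4, 5, 6, 2, 3])
print(propose_from_ngram(tokens, ngram_size=2, sample_len=3))
# -> tensor([4, 5, 6]): "2 3" last occurred at index 1, followed by 4, 5, 6.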
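# --- Illustrative sketch, not part of the patch -----------------------------
# The bookkeeping done by Top1Proposer._remove_no_proposal_seqs, restated with
# plain Python lists so the intended behavior is easy to check by hand. The
# helper name `drop_no_proposal_seqs` is hypothetical and not the vLLM API;
# the real method also leaves transposed sampler outputs untouched.
def drop_no_proposal_seqs(proposal_lens, sampler_outputs, nonzero_indices):
    """Zero out proposal lengths for sequences whose draft output is None."""
    output_by_seq = dict(zip(nonzero_indices, sampler_outputs))
    new_lens, new_outputs, new_indices = [], [], []
    for seq_idx, proposal_len in enumerate(proposal_lens):
        output = output_by_seq.get(seq_idx)
        if output is None:
            # Either the sequence was never sent to the draft worker, or the
            # draft worker found no n-gram match for it; skip scoring.
            new_lens.append(0)
        else:
            new_lens.append(proposal_len)
            new_outputs.append(output)
            new_indices.append(seq_idx)
    return new_lens, new_outputs, new_indices


# Sequence 0 was never speculated; sequence 1 got no n-gram match.
print(drop_no_proposal_seqs([0, 5, 5, 5], [None, "out2", "out3"], [1, 2, 3]))
# -> ([0, 0, 5, 5], ['out2', 'out3'], [2, 3])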