From 8999ec3c1632c91c194ab27df6bf274f5bcb0b5f Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 5 Mar 2024 15:35:43 -0800
Subject: [PATCH] Store `eos_token_id` in `Sequence` for easy access (#3166)

---
 tests/test_cache_block_hashing.py     |  3 +-
 vllm/core/scheduler.py                |  7 ++---
 vllm/engine/llm_engine.py             | 30 +++++++++-----------
 vllm/model_executor/layers/sampler.py |  1 -
 vllm/outputs.py                       | 41 ++++++++++++++-------------
 vllm/sequence.py                      | 11 ++++---
 6 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py
index 7c4ade7f8c8ed..c2067e52b59c0 100644
--- a/tests/test_cache_block_hashing.py
+++ b/tests/test_cache_block_hashing.py
@@ -54,7 +54,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
     for prompt in prompts:
         hashes[-1].append([])
         prompt_token_ids = tokenizer.encode(prompt)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                       tokenizer.tokenizer.eos_token_id)
 
         num_blocks = len(prompt_token_ids) // block_size
         for idx in range(num_blocks):
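With this change, test code (and any other caller) resolves the EOS id from the tokenizer up front and passes it into the `Sequence` constructor. A rough sketch of that call pattern, assuming a Hugging Face tokenizer as the test's `tokenizer.tokenizer` access implies; the model name and surrounding variables here are illustrative, not part of the patch:

    from transformers import AutoTokenizer

    # Illustrative only: the test goes through vLLM's tokenizer wrapper
    # (hence `tokenizer.tokenizer`); a raw HF tokenizer is used here instead.
    hf_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    prompt = "Hello, my name is"
    prompt_token_ids = hf_tokenizer.encode(prompt)
    eos_token_id = hf_tokenizer.eos_token_id

    # Mirrors the updated constructor call in the test
    # (seq_id and block_size are arbitrary in this sketch):
    # seq = Sequence(0, prompt, prompt_token_ids, 16, eos_token_id)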
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 1ae58f525b0fb..c96c6d62ef19d 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -59,10 +59,9 @@ def is_empty(self) -> bool:
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
     def _sort_by_lora_ids(self) -> bool:
-        self.scheduled_seq_groups = sorted(
-            self.scheduled_seq_groups,
-            key=lambda g: (g.lora_request.lora_int_id
-                           if g.lora_request else 0, g.request_id))
+        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
+                                           key=lambda g:
+                                           (g.lora_int_id, g.request_id))
 
     @property
     def lora_requests(self) -> Set[LoRARequest]:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1f518cbf39b21..52dc96e2b82e1 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -491,8 +491,10 @@ def add_request(
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
+        eos_token_id = self.tokenizer.get_lora_tokenizer(
+            lora_request).eos_token_id
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       lora_request)
+                       eos_token_id, lora_request)
 
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
@@ -548,15 +550,13 @@ def _check_beam_search_early_stopping(
         if early_stopping is True:
             return True
 
-        current_worst_score = (current_worst_seq.get_beam_search_score(
+        current_worst_score = current_worst_seq.get_beam_search_score(
             length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(
-                current_worst_seq).eos_token_id))
+            eos_token_id=current_worst_seq.eos_token_id)
         if early_stopping is False:
-            highest_attainable_score = (best_running_seq.get_beam_search_score(
+            highest_attainable_score = best_running_seq.get_beam_search_score(
                 length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(
-                    best_running_seq).eos_token_id))
+                eos_token_id=best_running_seq.eos_token_id)
         else:
             assert early_stopping == "never"
             if length_penalty > 0.0:
@@ -570,8 +570,7 @@ def _check_beam_search_early_stopping(
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id,
+                        eos_token_id=best_running_seq.eos_token_id,
                         seq_len=max_possible_length))
             else:
                 # Otherwise, beam search will prefer shorter sequences. The
@@ -580,8 +579,7 @@ def _check_beam_search_early_stopping(
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id))
+                        eos_token_id=best_running_seq.eos_token_id))
         return current_worst_score >= highest_attainable_score
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -679,8 +677,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             all_finished_seqs = existing_finished_seqs + new_finished_seqs
             # Sort the finished sequences by their scores.
             all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-                length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                    reverse=True)
             for seq, parent, is_new in all_finished_seqs[:beam_width]:
                 if is_new:
@@ -707,8 +704,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                   if not seq.is_finished()]
             # Sort the running sequences by their scores.
             running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-                length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                     reverse=True)
 
             # Check if we can stop the beam search.
@@ -1014,8 +1010,8 @@ def _check_stop(self, seq: Sequence,
             return
 
         # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
-                == self.get_tokenizer_for_seq(seq).eos_token_id):
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b48dde0318d09..320cb443524ca 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -516,7 +516,6 @@ def _get_logprobs(
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
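The engine-side effect of caching `eos_token_id` on the sequence is that the per-step stop check and the beam-search scoring keys no longer need to resolve a (possibly LoRA-specific) tokenizer. A minimal self-contained sketch of that stop check, using a stripped-down stand-in for the real `Sequence` class:

    from typing import List


    class MiniSequence:
        """Stand-in with just the fields the stop check needs."""

        def __init__(self, prompt_token_ids: List[int], eos_token_id: int) -> None:
            self.token_ids = list(prompt_token_ids)
            self.eos_token_id = eos_token_id  # cached once, at creation time

        def append_token(self, token_id: int) -> None:
            self.token_ids.append(token_id)

        def get_last_token_id(self) -> int:
            return self.token_ids[-1]


    def hit_eos(seq: MiniSequence, ignore_eos: bool = False) -> bool:
        # Before the patch this read self.get_tokenizer_for_seq(seq).eos_token_id
        # on every step; now it is a plain attribute read on the sequence.
        return (not ignore_eos) and seq.get_last_token_id() == seq.eos_token_id


    seq = MiniSequence(prompt_token_ids=[1, 42, 7], eos_token_id=2)
    seq.append_token(2)
    assert hit_eos(seq)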
diff --git a/vllm/outputs.py b/vllm/outputs.py
index a6de2a5a2257b..4f9eddee11cd4 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -90,29 +90,30 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        if seq_group.sampling_params.use_beam_search:
-            sorting_key = lambda seq: seq.get_beam_search_score(
-                seq_group.sampling_params.length_penalty)
+        if n == 1:
+            top_n_seqs = seqs
         else:
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-        top_n_seqs = sorted_seqs[:n]
+            if seq_group.sampling_params.use_beam_search:
+                sorting_key = lambda seq: seq.get_beam_search_score(
+                    seq_group.sampling_params.length_penalty)
+            else:
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+            top_n_seqs = sorted_seqs[:n]
 
         # Create the outputs.
-        outputs: List[CompletionOutput] = []
-        for seq in top_n_seqs:
-            logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs is None:
-                # NOTE: We need to take care of this case because the sequence
-                # always has the logprobs of the sampled tokens even if the
-                # logprobs are not requested.
-                logprobs = None
-            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
-            output = CompletionOutput(seqs.index(seq), seq.output_text,
-                                      seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs,
-                                      finshed_reason)
-            outputs.append(output)
+        # NOTE: We need omit logprobs here explicitly because the sequence
+        # always has the logprobs of the sampled tokens even if the
+        # logprobs are not requested.
+        include_logprobs = seq_group.sampling_params.logprobs
+        outputs = [
+            CompletionOutput(seqs.index(seq), seq.output_text,
+                             seq.get_output_token_ids(),
+                             seq.get_cumulative_logprob(),
+                             seq.output_logprobs if include_logprobs else None,
+                             SequenceStatus.get_finished_reason(seq.status))
+            for seq in top_n_seqs
+        ]
 
         # Every sequence in the sequence group should have the same prompt.
         prompt = seq_group.prompt
diff --git a/vllm/sequence.py b/vllm/sequence.py
index a110ab6b748f8..97b72fdc4cbeb 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -142,11 +142,13 @@ def __init__(
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        eos_token_id: int,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
@@ -362,12 +364,9 @@ def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        if status is None:
-            return list(self.seqs_dict.values())
-        else:
-            return [
-                seq for seq in self.seqs_dict.values() if seq.status == status
-            ]
+        return list(self.seqs_dict.values()) if status is None else [
+            seq for seq in self.seqs_dict.values() if seq.status == status
+        ]
 
     def get_unfinished_seqs(self) -> List[Sequence]:
         return [
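The `outputs.py` rework adds a fast path when only one sequence is requested and builds the `CompletionOutput` list in a single comprehension. A simplified sketch of just the selection logic, with stand-in types rather than the real vLLM classes; the ranking here uses plain cumulative logprob, whereas the real code calls `get_beam_search_score` when beam search is enabled:

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class FakeSeq:
        output_text: str
        cumulative_logprob: float


    def top_n_seqs(seqs: List[FakeSeq], n: int) -> List[FakeSeq]:
        # Fast path from the patch: with n == 1 there is nothing to rank,
        # so skip sorting entirely.
        if n == 1:
            return seqs
        # Otherwise rank and truncate (beam-search scoring omitted in this sketch).
        return sorted(seqs, key=lambda s: s.cumulative_logprob, reverse=True)[:n]


    seqs = [FakeSeq("a", -1.5), FakeSeq("b", -0.2)]
    assert top_n_seqs(seqs, 1) == seqs                 # returned as-is, unsorted
    assert top_n_seqs(seqs, 2)[0].output_text == "b"   # best first when ranking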