From bfdb1ba5c3fb14387c69acb1f5067102d8028e56 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Fri, 22 Mar 2024 13:44:12 -0700
Subject: [PATCH] [Core] Improve detokenization performance for prefill (#3469)

Co-authored-by: MeloYang
---
 tests/tokenization/test_detokenize.py  | 163 +++++++++++++++++++++++--
 vllm/engine/llm_engine.py              |  66 ++--------
 vllm/transformers_utils/detokenizer.py | 155 +++++++++++++++++++++++
 vllm/transformers_utils/tokenizer.py   |  90 +++++++++++---
 4 files changed, 385 insertions(+), 89 deletions(-)
 create mode 100644 vllm/transformers_utils/detokenizer.py

diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index 4421739390e3b..082034083aebd 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -1,13 +1,17 @@
 import pytest
 from transformers import AutoTokenizer
+from typing import List, Dict
 
+from vllm.sequence import Sequence, Logprob, SamplingParams, SequenceGroup
+from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer import Detokenizer
 
 
 TRUTH = [
-    "Hello here, this is a simple test",  # noqa: E501
-    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa: E501
-    "我很感谢你的热情"  # noqa: E501
+    "Hello here, this is a simple test",
+    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
+    "我很感谢你的热情"
 ]
 TOKENIZERS = [
     "facebook/opt-125m",
@@ -24,12 +28,12 @@
 
 
 def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool):
+                            skip_special_tokens: bool, starting_index: int):
     decoded_text = ""
     offset = 0
     token_offset = 0
     prev_tokens = None
-    for i in range(len(all_input_ids)):
+    for i in range(starting_index, len(all_input_ids)):
         new_tokens, text, offset, token_offset = detokenize_incrementally(
             tokenizer,
             all_input_ids[:i + 1],
@@ -46,17 +50,152 @@ def _run_incremental_decode(tokenizer, all_input_ids,
 
 
 @pytest.mark.parametrize("truth", TRUTH)
+@pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False))
-def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
+def test_decode_streaming(tokenizer_id, truth, with_prompt,
+                          skip_special_tokens):
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
-    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+    if with_prompt:
+        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        prompt_input_ids = truth_tokens[:len(truth) // 2]
+        generated_input_ids = truth_tokens[len(truth) // 2:]
+        all_input_ids = prompt_input_ids + generated_input_ids
+        starting_index = len(prompt_input_ids)
+        prompt = tokenizer.decode(prompt_input_ids,
+                                  skip_special_tokens=skip_special_tokens)
+        generated = truth[len(prompt):]
+    else:
+        generated = truth
+        starting_index = 0
+        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
     if skip_special_tokens:
-        all_input_ids = ([tokenizer.bos_token_id]
-                         if tokenizer.bos_token_id is not None else
-                         []) + all_input_ids + [tokenizer.eos_token_id]
+        if tokenizer.bos_token_id is not None:
+            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
+            starting_index += 1
+        all_input_ids = all_input_ids + [tokenizer.eos_token_id]
 
     decoded_text = _run_incremental_decode(
-        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=skip_special_tokens,
+        starting_index=starting_index)
 
-    assert decoded_text == truth
+    assert decoded_text == generated
+
+
+@pytest.fixture
+def detokenizer(tokenizer_name: str) -> Detokenizer:
+    init_kwargs = dict(
+        tokenizer_id=tokenizer_name,
+        enable_lora=False,
+        max_num_seqs=100,
+        max_input_length=None,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        revision=None,
+    )
+
+    tokenizer_group = get_tokenizer_group(
+        None,
+        **init_kwargs,
+    )
+
+    return Detokenizer(tokenizer_group)
+
+
+@pytest.fixture(name="complete_sequence_token_ids")
+def create_complete_sequence_token_ids(complete_sequence: str,
+                                       tokenizer_name: str) -> List[int]:
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
+    return complete_sequence_token_ids
+
+
+def create_sequence(prompt_token_ids=None):
+    prompt_token_ids = prompt_token_ids or [1]
+    return Sequence(
+        seq_id=0,
+        prompt="",
+        prompt_token_ids=prompt_token_ids,
+        block_size=16,
+    )
+
+
+def create_dummy_logprobs(
+        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
+    return [{
+        token_id: Logprob(logprob=0.0),
+        token_id + 1: Logprob(logprob=0.1)
+    } for token_id in complete_sequence_token_ids]
+
+
+@pytest.mark.parametrize("complete_sequence", TRUTH)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", [True, False])
+def test_decode_sequence_logprobs(complete_sequence: str,
+                                  complete_sequence_token_ids: List[int],
+                                  detokenizer: Detokenizer,
+                                  skip_special_tokens: bool):
+    """Verify Detokenizer decodes logprobs correctly."""
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+                                     logprobs=2)
+
+    # Run sequentially.
+    seq = create_sequence()
+    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    sequential_logprobs_text_chosen_token = []
+    sequential_logprobs_text_other_token = []
+    for new_token, logprobs in zip(complete_sequence_token_ids,
+                                   dummy_logprobs):
+        seq.append_token_id(new_token, logprobs)
+        detokenizer.decode_sequence_inplace(seq, sampling_params)
+        sequential_logprobs_text_chosen_token.append(
+            seq.output_logprobs[-1][new_token].decoded_token)
+        sequential_logprobs_text_other_token.append(
+            seq.output_logprobs[-1][new_token + 1].decoded_token)
+    sequential_result = seq.output_text
+
+    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
+    assert sequential_result != "".join(sequential_logprobs_text_other_token)
+
+    if skip_special_tokens:
+        # Text for logprobs for the chosen token should be the same as the
+        # generated text. Note that this will only be true if we skip
+        # special tokens.
+        assert sequential_result == complete_sequence
+
+
+@pytest.mark.parametrize("complete_sequence", TRUTH)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", [True])
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer,
+                                skip_special_tokens: bool):
+    """Verify Detokenizer decodes prompt logprobs correctly."""
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+                                     prompt_logprobs=1)
+
+    # Run sequentially.
+    seq = create_sequence(complete_sequence_token_ids)
+    seq_group = SequenceGroup(request_id="1",
+                              seqs=[seq],
+                              sampling_params=sampling_params,
+                              arrival_time=0.0)
+    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
+    decoded_prompt_logprobs = dummy_logprobs
+
+    if skip_special_tokens:
+        # Text for logprobs for the chosen token should be the same as the
+        # prompt text. Note that this will only be true if we skip
+        # special tokens.
+        assert complete_sequence == "".join([
+            logprobs[token_id].decoded_token for token_id, logprobs in zip(
+                complete_sequence_token_ids, decoded_prompt_logprobs)
+        ])
+        assert complete_sequence != "".join([
+            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
+                complete_sequence_token_ids, decoded_prompt_logprobs)
+        ])
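
The rewritten test_decode_streaming above hinges on `starting_index`: when a prompt is present, only the generated suffix is decoded token by token. A rough standalone sketch of that idea (not part of the patch), using only a Hugging Face tokenizer; "gpt2" and the half-and-half prompt split are arbitrary choices, and the decode-the-whole-prefix loop is the naive baseline whose cost and boundary issues the prefix/read offsets in the real helpers are meant to avoid:

    # Naive incremental decode of only the generated part of a sequence.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    truth = "Hello here, this is a simple test"
    ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    prompt_len = len(ids) // 2  # pretend the first half is the prompt

    decoded = ""
    for i in range(prompt_len, len(ids)):
        # Re-decoding the whole prefix on every step is what the patch avoids.
        with_new = tokenizer.decode(ids[:i + 1])
        without_new = tokenizer.decode(ids[:i])
        decoded += with_new[len(without_new):]

    assert decoded == truth[len(tokenizer.decode(ids[:prompt_len])):]
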
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7247828418da5..283b5d9ac44c1 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,5 +1,5 @@
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple, Type, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -15,11 +15,11 @@
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
-from vllm.transformers_utils.tokenizer import detokenize_incrementally
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter
 
 logger = init_logger(__name__)
@@ -97,6 +97,7 @@ def __init__(
 
         self._verify_args()
         self._init_tokenizer()
+        self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
 
         self.model_executor = executor_class(model_config, cache_config,
@@ -153,7 +154,7 @@ def __reduce__(self):
         raise RuntimeError("LLMEngine should not be pickled!")
 
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer()
+        return self.tokenizer.get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
@@ -370,13 +371,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
-            # We can pick any sequence for the prompt.
-            seq = next(iter(seq_group.seqs_dict.values()))
-            all_token_ids = seq.get_token_ids()
-            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
-                self._decode_logprobs(seq, seq_group.sampling_params,
-                                      prompt_logprobs_for_token,
-                                      all_token_ids[:i])
+            self.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
 
         # Process samples
@@ -420,7 +416,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self._decode_sequence(seq, seq_group.sampling_params)
+            self.detokenizer.decode_sequence_inplace(seq,
+                                                     seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -713,51 +710,6 @@ def _get_stats(self,
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
-                         logprobs: Dict[int, Logprob],
-                         all_input_ids: List[int]) -> None:
-        if not logprobs:
-            return
-        for token_id, sample_logprob in logprobs.items():
-            if (sample_logprob.decoded_token is None and token_id != -1):
-                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                (_, new_text, prefix_offset,
-                 read_offset) = detokenize_incrementally(
-                     self.get_tokenizer_for_seq(seq),
-                     all_input_ids=all_input_ids_with_logprob,
-                     prev_tokens=seq.tokens,
-                     prefix_offset=seq.prefix_offset,
-                     read_offset=seq.read_offset,
-                     skip_special_tokens=prms.skip_special_tokens,
-                     spaces_between_special_tokens=prms.
-                     spaces_between_special_tokens,
-                 )
-                sample_logprob.decoded_token = new_text
-
-    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
-        """Decodes the new token for a sequence."""
-        all_input_ids = seq.get_token_ids()
-        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
-                              all_input_ids)
-
-        (new_tokens, new_output_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             self.get_tokenizer_for_seq(seq),
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
-        seq.prefix_offset = prefix_offset
-        seq.read_offset = read_offset
-        seq.output_text += new_output_text
-
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
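
With the private helpers removed from LLMEngine above, all detokenization is delegated to the Detokenizer introduced in the new file below. A minimal standalone sketch of that call pattern (not part of the patch), mirroring the fixtures in tests/tokenization/test_detokenize.py; the model name and token ids are arbitrary placeholders:

    from vllm.sequence import Logprob, SamplingParams, Sequence
    from vllm.transformers_utils.detokenizer import Detokenizer
    from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

    tokenizer_group = get_tokenizer_group(
        None,  # no tokenizer pool config
        tokenizer_id="facebook/opt-125m",
        enable_lora=False,
        max_num_seqs=8,
        max_input_length=None,
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
    )
    detokenizer = Detokenizer(tokenizer_group)

    params = SamplingParams(skip_special_tokens=True, logprobs=1)
    seq = Sequence(seq_id=0, prompt="", prompt_token_ids=[2, 100, 200],
                   block_size=16)
    for token_id in (300, 400):  # stand-ins for sampled token ids
        seq.append_token_id(token_id, {token_id: Logprob(logprob=0.0)})
        # Appends the newly decoded text to seq.output_text and fills in the
        # decoded_token field of the Logprob entries for this step.
        detokenizer.decode_sequence_inplace(seq, params)
    print(seq.output_text)
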
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
new file mode 100644
index 0000000000000..1f322b3675d02
--- /dev/null
+++ b/vllm/transformers_utils/detokenizer.py
@@ -0,0 +1,155 @@
+from typing import List, Dict, Optional
+from transformers import PreTrainedTokenizer
+from vllm.sequence import Sequence, Logprob, SequenceGroup, SamplingParams
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               convert_prompt_ids_to_tokens)
+from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
+    BaseTokenizerGroup)
+
+# Used e.g. for marking rejected tokens in spec decoding.
+INVALID_TOKEN_ID = -1
+
+
+class Detokenizer:
+    """Provides methods to decode the output of a model into text."""
+
+    def __init__(self, tokenizer_group: BaseTokenizerGroup):
+        self.tokenizer_group = tokenizer_group
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
+        """Returns the HF tokenizer to use for a given sequence."""
+        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
+
+    def decode_prompt_logprobs_inplace(
+            self, seq_group: SequenceGroup,
+            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+        """Decodes the logprobs for the prompt of a sequence group.
+
+        Args:
+            seq_group: The sequence group to decode.
+            prompt_logprobs: The logprobs to decode.
+
+        Returns:
+            None. The Logprob objects in prompt_logprobs are updated in place.
+        """
+        prms = seq_group.sampling_params
+        # We can pick any sequence for the prompt.
+        seq = next(iter(seq_group.seqs_dict.values()))
+        # Only prompt, without the generated token.
+        all_token_ids = seq.get_token_ids()
+        prompt_token_ids = all_token_ids[:-1]
+        tokenizer = self.get_tokenizer_for_seq(seq)
+        prefix_offset = 0
+        read_offset = 0
+        next_iter_prefix_offset = 0
+        next_iter_read_offset = 0
+        next_iter_tokens = []
+        prev_tokens = None
+
+        for token_position, prompt_logprobs_for_token in enumerate(
+                prompt_logprobs):
+            if not prompt_logprobs_for_token:
+                continue
+            for token_id, sample_logprob in prompt_logprobs_for_token.items():
+                if (sample_logprob.decoded_token is None
+                        and token_id != INVALID_TOKEN_ID):
+                    prompt_token_ids_with_token = (
+                        prompt_token_ids[:token_position] + [token_id])
+                    (new_tokens, new_text, new_prefix_offset,
+                     new_read_offset) = detokenize_incrementally(
+                         tokenizer=tokenizer,
+                         all_input_ids=prompt_token_ids_with_token,
+                         prev_tokens=prev_tokens,
+                         prefix_offset=prefix_offset,
+                         read_offset=read_offset,
+                         skip_special_tokens=prms.skip_special_tokens,
+                         spaces_between_special_tokens=prms.
+                         spaces_between_special_tokens,
+                     )
+
+                    sample_logprob.decoded_token = new_text
+
+                    # Use the offsets & prev tokens corresponding to
+                    # real tokens to ensure detokenization is consistent
+                    # with the actual prompt.
+                    if token_id == all_token_ids[token_position]:
+                        next_iter_prefix_offset = new_prefix_offset
+                        next_iter_read_offset = new_read_offset
+                        next_iter_tokens = new_tokens
+
+            # Advance to the next token position.
+            prefix_offset = next_iter_prefix_offset
+            read_offset = next_iter_read_offset
+            if prev_tokens is None:
+                prev_tokens = next_iter_tokens
+            else:
+                prev_tokens.extend(next_iter_tokens)
+
+    def decode_sequence_inplace(self, seq: Sequence,
+                                prms: SamplingParams) -> None:
+        """Decodes the new token for a sequence. In-place operation.
+
+        Args:
+            seq: The sequence to decode.
+            prms: The sampling parameters used to generate the sequence.
+        """
+        all_input_ids = seq.get_token_ids()
+        token_id_generated_this_iteration = all_input_ids[-1]
+        tokenizer = self.get_tokenizer_for_seq(seq)
+
+        # Convert prompt token IDs to tokens if necessary.
+        # Do it here so that we don't have to repeat this
+        # computation for each logprob.
+        if seq.tokens is None:
+            (seq.tokens, seq.prefix_offset,
+             seq.read_offset) = convert_prompt_ids_to_tokens(
+                 tokenizer=tokenizer,
+                 prompt_ids=all_input_ids[:-1],
+                 skip_special_tokens=prms.skip_special_tokens,
+             )
+
+        (new_tokens, new_decoded_token_text, prefix_offset,
+         read_offset) = detokenize_incrementally(
+             tokenizer=tokenizer,
+             all_input_ids=all_input_ids,
+             prev_tokens=seq.tokens,
+             prefix_offset=seq.prefix_offset,
+             read_offset=seq.read_offset,
+             skip_special_tokens=prms.skip_special_tokens,
+             spaces_between_special_tokens=prms.spaces_between_special_tokens,
+         )
+
+        # Decode logprobs
+        logprobs = seq.output_logprobs[-1]
+        if logprobs:
+            previous_tokens = all_input_ids[:-1]
+            for token_id, sample_logprob in logprobs.items():
+                # If the token was generated this iteration,
+                # use the provided text.
+                if token_id == token_id_generated_this_iteration:
+                    sample_logprob.decoded_token = new_decoded_token_text
+                    continue
+
+                if (sample_logprob.decoded_token is None
+                        and token_id != INVALID_TOKEN_ID):
+                    all_input_ids_with_logprob = previous_tokens + [token_id]
+                    (_, new_text, _, _) = detokenize_incrementally(
+                        tokenizer=tokenizer,
+                        all_input_ids=all_input_ids_with_logprob,
+                        prev_tokens=seq.tokens,
+                        prefix_offset=seq.prefix_offset,
+                        read_offset=seq.read_offset,
+                        skip_special_tokens=prms.skip_special_tokens,
+                        spaces_between_special_tokens=prms.
+                        spaces_between_special_tokens,
+                    )
+                    sample_logprob.decoded_token = new_text
+
+        if seq.tokens is None:
+            seq.tokens = new_tokens
+        else:
+            seq.tokens.extend(new_tokens)
+        seq.prefix_offset = prefix_offset
+        seq.read_offset = read_offset
+        seq.output_text += new_decoded_token_text
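
For prompt logprobs the contract is the same: decode_prompt_logprobs_inplace returns nothing and instead fills the decoded_token field of every Logprob it is handed. A small sketch (not part of the patch) along the lines of test_decode_prompt_logprobs above; the model name and token ids are again arbitrary placeholders:

    from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
    from vllm.transformers_utils.detokenizer import Detokenizer
    from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

    detokenizer = Detokenizer(
        get_tokenizer_group(None,
                            tokenizer_id="facebook/opt-125m",
                            enable_lora=False,
                            max_num_seqs=8,
                            max_input_length=None,
                            tokenizer_mode="auto",
                            trust_remote_code=False,
                            revision=None))

    prompt_ids = [2, 100, 200, 300]  # arbitrary placeholder ids
    params = SamplingParams(skip_special_tokens=True, prompt_logprobs=1)
    seq = Sequence(seq_id=0, prompt="", prompt_token_ids=prompt_ids,
                   block_size=16)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=params,
                              arrival_time=0.0)
    # One {token_id: Logprob} dict per prompt position; a leading None entry
    # (as the engine produces for the first prompt token) is skipped.
    prompt_logprobs = [None] + [{t: Logprob(logprob=0.0)}
                                for t in prompt_ids[1:]]
    detokenizer.decode_prompt_logprobs_inplace(seq_group, prompt_logprobs)
    print([lp[t].decoded_token
           for t, lp in zip(prompt_ids[1:], prompt_logprobs[1:])])
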
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index f7a1a19a89bcf..eebdacc4903ca 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
     return "".join(sub_texts)
 
 
+# 5 is an arbitrary value that should work for all
+# tokenizers (bigger = more conservative).
+INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+def convert_prompt_ids_to_tokens(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt_ids: List[int],
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], int, int]:
+    """Converts the prompt ids to tokens and returns the tokens and offsets
+    for incremental detokenization.
+
+    Note that not all tokens are converted to strings. Only the tokens that
+    are necessary for incremental detokenization are converted to strings.
+    """
+    # Offset a little more in case we have special tokens.
+    prefix_offset = max(
+        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
+    # We do not need to convert the whole prompt to tokens.
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
+    prefix_offset = max(
+        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    read_offset = len(new_tokens)
+    return new_tokens, prefix_offset, read_offset
+
+
 # Based on
 # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
 # under Apache 2.0 license
@@ -165,31 +193,53 @@ def detokenize_incrementally(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     all_input_ids: List[int],
     prev_tokens: Optional[List[str]],
-    prefix_offset: int = 0,
-    read_offset: int = 0,
+    prefix_offset: int,
+    read_offset: int,
     skip_special_tokens: bool = False,
     spaces_between_special_tokens: bool = True,
 ) -> Tuple[List[str], str, int, int]:
+    """Detokenizes the input ids incrementally and returns the new tokens
+    and the new text.
+
+    If `prev_tokens` is None, this is the first call for the sequence: all of
+    the input ids are converted to tokens and returned together with the new
+    text. Otherwise, only the newly generated token is converted.
+
+    This function will also return the new prefix offset and the new read
+    offset to be used in the next iteration.
+
+    The offsets are necessary to defeat cleanup algorithms in the decode which
+    decide to add a space or not depending on the surrounding ids.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        all_input_ids: The input ids. The last id is the new token id.
+        prev_tokens: The tokens from previous iterations, or None if this is
+            the first call for the sequence.
+        prefix_offset: The prefix offset.
+        read_offset: The read offset.
+        skip_special_tokens: Whether to skip special tokens.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens.
+    """
     new_token_id = all_input_ids[-1]
     # This is the first iteration for this sequence
-    if prev_tokens is None:
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            all_input_ids, skip_special_tokens=skip_special_tokens)
-        output_tokens = new_tokens
-        # 5 is an arbitrary value that should work for all
-        # tokenizers (bigger = more conservative).
-        # Subtract 1 extra to account for the generated token.
-        prefix_offset = max(len(output_tokens) - 6, 0)
-        # If the first new token is a special token, we can't skip 1 extra token
-        if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
-            read_offset = max(len(output_tokens), 0)
-        else:
-            read_offset = max(len(output_tokens) - 1, 0)
-    else:
-        # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            [new_token_id], skip_special_tokens=skip_special_tokens)
-        output_tokens = prev_tokens + new_tokens
+    is_first_iter = prev_tokens is None
+    if is_first_iter:
+        (prev_tokens, prefix_offset,
+         read_offset) = convert_prompt_ids_to_tokens(
+             tokenizer,
+             all_input_ids[:-1],
+             skip_special_tokens=skip_special_tokens)
+
+    # Put new_token_id in a list so skip_special_tokens is respected
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        [new_token_id], skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_tokens + new_tokens
+
+    # If this is the first iteration, return all tokens.
+    if is_first_iter:
+        new_tokens = output_tokens
 
     # The prefix text is necessary only to defeat cleanup algorithms in
     # the decode which decide to add a space or not depending on the
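
The two helpers touched in this last file are what both the Detokenizer and the test's _run_incremental_decode drive: convert_prompt_ids_to_tokens seeds the token list and offsets from just the tail of the prompt, so prefill never converts the whole prompt, and detokenize_incrementally then advances the offsets one generated token at a time. A short sketch of driving them by hand (not part of the patch); "gpt2" and the example strings are arbitrary, and the printed result is indicative rather than guaranteed for every tokenizer:

    from transformers import AutoTokenizer

    from vllm.transformers_utils.tokenizer import (
        convert_prompt_ids_to_tokens, detokenize_incrementally)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt_ids = tokenizer("vLLM is a high-throughput engine")["input_ids"]
    generated_ids = tokenizer(" for LLMs")["input_ids"]

    # Seed tokens and offsets from the last few prompt tokens only.
    tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
        tokenizer, prompt_ids, skip_special_tokens=True)

    all_ids = list(prompt_ids)
    text = ""
    for new_id in generated_ids:
        all_ids.append(new_id)
        new_tokens, new_text, prefix_offset, read_offset = (
            detokenize_incrementally(tokenizer,
                                     all_input_ids=all_ids,
                                     prev_tokens=tokens,
                                     prefix_offset=prefix_offset,
                                     read_offset=read_offset,
                                     skip_special_tokens=True))
        tokens.extend(new_tokens)
        text += new_text
    print(text)  # expected: " for LLMs"
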