From 49782fcb769eb4f04a3cf5179c1e6c13ab633ce1 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 1 Apr 2024 13:22:06 -0700
Subject: [PATCH] [Misc] Some minor simplifications to detokenization logic
 (#3670)

Some simplifications made for clarity. Also moves detokenization-related
functions from tokenizer.py to detokenizer.py.
---
 tests/tokenization/test_detokenize.py  |   4 +-
 vllm/transformers_utils/detokenizer.py | 164 +++++++++++++++++++++++--
 vllm/transformers_utils/tokenizer.py   | 156 +----------------------
 3 files changed, 159 insertions(+), 165 deletions(-)

diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index 92587b40dd45a..9bc9becb2a6f1 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -4,8 +4,8 @@
 from transformers import AutoTokenizer
 
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer import (Detokenizer,
+                                                 detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 
 TRUTH = [
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 419687e23b718..486c1938e1e10 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -1,10 +1,8 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.tokenizer import (convert_prompt_ids_to_tokens,
-                                               detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
     BaseTokenizerGroup)
 
@@ -148,10 +146,160 @@ def decode_sequence_inplace(self, seq: Sequence,
                     )
                     sample_logprob.decoded_token = new_text
 
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
+        seq.tokens.extend(new_tokens)
         seq.prefix_offset = prefix_offset
         seq.read_offset = read_offset
         seq.output_text += new_decoded_token_text
+
+
+def _convert_tokens_to_string_with_added_encoders(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    output_tokens: List[str],
+    skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
+) -> str:
+    # Adapted from
+    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    all_special_tokens = set(tokenizer.all_special_tokens)
+    for token in output_tokens:
+        if skip_special_tokens and token in all_special_tokens:
+            continue
+        if token in tokenizer.get_added_vocab():
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
+
+
+# 5 is an arbitrary value that should work for all
+# tokenizers (bigger = more conservative).
+INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+def convert_prompt_ids_to_tokens(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt_ids: List[int],
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], int, int]:
+    """Converts the prompt ids to tokens and returns the tokens and offsets
+    for incremental detokenization.
+
+    Note that not all tokens are converted to strings. Only the tokens that
+    are necessary for incremental detokenization are converted to strings.
+    """
+    # We do not need to convert the whole prompt to tokens.
+    # Offset a little more in case we have special tokens.
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
+        skip_special_tokens=skip_special_tokens)
+    read_offset = len(new_tokens)
+    prefix_offset = max(
+        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    return new_tokens, prefix_offset, read_offset
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int,
+    read_offset: int,
+    skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
+) -> Tuple[List[str], str, int, int]:
+    """Detokenizes the input ids incrementally and returns the new tokens
+    and the new text.
+
+    If `prev_tokens` is None, this function will convert the input ids to
+    tokens and return the tokens and the new text. Otherwise, it will return the
+    new tokens and the new text.
+
+    This function will also return the new prefix offset and the new read
+    offset to be used in the next iteration.
+
+    The offsets are necessary to defeat cleanup algorithms in the decode which
+    decide to add a space or not depending on the surrounding ids.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        all_input_ids: The input ids. The last id is the new token id.
+        prev_tokens: The previous tokens. If None, this function will convert
+            the input ids to tokens and return the tokens and the new text.
+        prefix_offset: The prefix offset.
+        read_offset: The read offset.
+        skip_special_tokens: Whether to skip special tokens.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens.
+ """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + + # If the new token id is out of bounds, return an empty string. + if new_token_id >= len(tokenizer): + new_tokens = [""] + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 3bda3f419d8a2..e216a99af91f9 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -126,157 +126,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) - - -def _convert_tokens_to_string_with_added_encoders( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - output_tokens: List[str], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, -) -> str: - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 - # NOTE(woosuk): The following code is slow because it runs a for loop over - # the output_tokens. In Python, running a for loop over a list can be slow - # even when the loop body is very simple. 
-    sub_texts = []
-    current_sub_text = []
-    all_special_tokens = set(tokenizer.all_special_tokens)
-    for token in output_tokens:
-        if skip_special_tokens and token in all_special_tokens:
-            continue
-        if token in tokenizer.get_added_vocab():
-            if current_sub_text:
-                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-                sub_texts.append(sub_text)
-                current_sub_text = []
-            sub_texts.append(token)
-        else:
-            current_sub_text.append(token)
-    if current_sub_text:
-        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-        sub_texts.append(sub_text)
-    if spaces_between_special_tokens:
-        return " ".join(sub_texts)
-    else:
-        return "".join(sub_texts)
-
-
-# 5 is an arbitrary value that should work for all
-# tokenizers (bigger = more conservative).
-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-
-
-def convert_prompt_ids_to_tokens(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prompt_ids: List[int],
-    skip_special_tokens: bool = False,
-) -> Tuple[List[str], int, int]:
-    """Converts the prompt ids to tokens and returns the tokens and offsets
-    for incremental detokenization.
-
-    Note that not all tokens are converted to strings. Only the tokens that
-    are necessary for incremental detokenization are converted to strings.
-    """
-    # Offset a little more in case we have special tokens.
-    prefix_offset = max(
-        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
-    # We do not need to convert the whole prompt to tokens.
-    new_tokens = tokenizer.convert_ids_to_tokens(
-        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
-    prefix_offset = max(
-        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
-    read_offset = len(new_tokens)
-    return new_tokens, prefix_offset, read_offset
-
-
-# Based on
-# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
-# under Apache 2.0 license
-def detokenize_incrementally(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    all_input_ids: List[int],
-    prev_tokens: Optional[List[str]],
-    prefix_offset: int,
-    read_offset: int,
-    skip_special_tokens: bool = False,
-    spaces_between_special_tokens: bool = True,
-) -> Tuple[List[str], str, int, int]:
-    """Detokenizes the input ids incrementally and returns the new tokens
-    and the new text.
-
-    If `prev_tokens` is None, this function will convert the input ids to
-    tokens and return the tokens and the new text. Otherwise, it will return the
-    new tokens and the new text.
-
-    This function will also return the new prefix offset and the new read
-    offset to be used in the next iteration.
-
-    The offsets are necessary to defeat cleanup algorithms in the decode which
-    decide to add a space or not depending on the surrounding ids.
-
-    Args:
-        tokenizer: The tokenizer to use.
-        all_input_ids: The input ids. The last id is the new token id.
-        prev_tokens: The previous tokens. If None, this function will convert
-            the input ids to tokens and return the tokens and the new text.
-        prefix_offset: The prefix offset.
-        read_offset: The read offset.
-        skip_special_tokens: Whether to skip special tokens.
-        spaces_between_special_tokens: Whether to add spaces between special
-            tokens.
- """ - new_token_id = all_input_ids[-1] - # This is the first iteration for this sequence - is_first_iter = prev_tokens is None - if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) - - # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: - # Put new_token_id in a list so skip_special_tokens is respected - new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) - output_tokens = prev_tokens + new_tokens - - # If this is the first iteration, return all tokens. - if is_first_iter: - new_tokens = output_tokens - - # The prefix text is necessary only to defeat cleanup algorithms in - # the decode which decide to add a space or not depending on the - # surrounding ids. - if tokenizer.is_fast or not tokenizer.get_added_vocab(): - prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) - else: - prefix_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:read_offset], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - new_text = _convert_tokens_to_string_with_added_encoders( - tokenizer, - output_tokens[prefix_offset:], - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - if len(new_text) > len(prefix_text) and not new_text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. - # If it's in the middle, it's probably a real invalid id generated - # by the model - new_text = new_text[len(prefix_text):] - return new_tokens, new_text, read_offset, len(output_tokens) - else: - return new_tokens, "", prefix_offset, read_offset