[Core] Improve detokenization performance for prefill (vllm-project#3469)

Co-authored-by: MeloYang <[email protected]>
Yard1 and MeloYang05 authored Mar 22, 2024
1 parent cf2f084 commit bfdb1ba
Showing 4 changed files with 385 additions and 89 deletions.
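
The performance idea, as exercised by the updated tests below: incremental detokenization can begin at the first generated token (the new starting_index argument) instead of replaying the entire prompt token by token, which is what makes long prefills cheap. A minimal sketch of that loop, assuming only the detokenize_incrementally signature visible in this diff; the helper name decode_generated_only is illustrative, not part of the commit:

    from vllm.transformers_utils.tokenizer import detokenize_incrementally

    def decode_generated_only(tokenizer, all_input_ids, starting_index,
                              skip_special_tokens=False):
        # Start incremental decoding at the first generated token rather
        # than at token 0, so the prompt is never re-detokenized.
        decoded_text = ""
        prefix_offset = 0
        read_offset = 0
        prev_tokens = None
        for i in range(starting_index, len(all_input_ids)):
            new_tokens, text, prefix_offset, read_offset = (
                detokenize_incrementally(
                    tokenizer,
                    all_input_ids[:i + 1],
                    prev_tokens=prev_tokens,
                    prefix_offset=prefix_offset,
                    read_offset=read_offset,
                    skip_special_tokens=skip_special_tokens))
            decoded_text += text
            prev_tokens = (new_tokens if prev_tokens is None else
                           prev_tokens + new_tokens)
        return decoded_text

With a prompt of length P and G generated tokens, the loop now runs G times instead of P + G.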
163 changes: 151 additions & 12 deletions tests/tokenization/test_detokenize.py
@@ -1,13 +1,17 @@
 import pytest
 
 from transformers import AutoTokenizer
+from typing import List, Dict
 
+from vllm.sequence import Sequence, Logprob, SamplingParams, SequenceGroup
+from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer import Detokenizer
 
 TRUTH = [
-    "Hello here, this is a simple test",  # noqa: E501
-    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa: E501
-    "我很感谢你的热情"  # noqa: E501
+    "Hello here, this is a simple test",
+    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
+    "我很感谢你的热情"
 ]
 TOKENIZERS = [
     "facebook/opt-125m",
@@ -24,12 +28,12 @@


 def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool):
+                            skip_special_tokens: bool, starting_index: int):
     decoded_text = ""
     offset = 0
     token_offset = 0
     prev_tokens = None
-    for i in range(len(all_input_ids)):
+    for i in range(starting_index, len(all_input_ids)):
         new_tokens, text, offset, token_offset = detokenize_incrementally(
             tokenizer,
             all_input_ids[:i + 1],
@@ -46,17 +50,152 @@ def _run_incremental_decode(tokenizer, all_input_ids,


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
def test_decode_streaming(tokenizer_id, truth, with_prompt,
skip_special_tokens):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
if skip_special_tokens:
all_input_ids = ([tokenizer.bos_token_id]
if tokenizer.bos_token_id is not None else
[]) + all_input_ids + [tokenizer.eos_token_id]
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]

decoded_text = _run_incremental_decode(
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)

assert decoded_text == truth
assert decoded_text == generated


@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
init_kwargs = dict(
tokenizer_id=tokenizer_name,
enable_lora=False,
max_num_seqs=100,
max_input_length=None,
tokenizer_mode="auto",
trust_remote_code=False,
revision=None,
)

tokenizer_group = get_tokenizer_group(
None,
**init_kwargs,
)

return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer_name: str) -> List[int]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
return complete_sequence_token_ids


def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
prompt="<s>",
prompt_token_ids=prompt_token_ids,
block_size=16,
)


def create_dummy_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
} for token_id in complete_sequence_token_ids]


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
logprobs=2)

# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token = []
sequential_logprobs_text_other_token = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
detokenizer.decode_sequence_inplace(seq, sampling_params)
sequential_logprobs_text_chosen_token.append(
seq.output_logprobs[-1][new_token].decoded_token)
sequential_logprobs_text_other_token.append(
seq.output_logprobs[-1][new_token + 1].decoded_token)
sequential_result = seq.output_text

assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)

if skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True])
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)

# Run sequentially.
seq = create_sequence(complete_sequence_token_ids)
seq_group = SequenceGroup(request_id="1",
seqs=[seq],
sampling_params=sampling_params,
arrival_time=0.0)
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
decoded_prompt_logprobs = dummy_logprobs

if skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# prompt text. Note that this will only be true if we skip
# special tokens.
assert complete_sequence == "".join([
logprobs[token_id].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])
assert complete_sequence != "".join([
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
complete_sequence_token_ids, decoded_prompt_logprobs)
])
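
The tests above double as a compact usage reference for the new Detokenizer API. A condensed sketch using only the constructors and calls exercised by these tests; argument values are illustrative:

    from vllm.sequence import Sequence, SamplingParams
    from vllm.transformers_utils.detokenizer import Detokenizer
    from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

    # A Detokenizer wraps a tokenizer group (first argument as in the
    # fixture above).
    tokenizer_group = get_tokenizer_group(
        None,
        tokenizer_id="facebook/opt-125m",
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
    )
    detokenizer = Detokenizer(tokenizer_group)

    seq = Sequence(seq_id=0, prompt="<s>", prompt_token_ids=[1],
                   block_size=16)
    params = SamplingParams(skip_special_tokens=True, logprobs=2)

    # Per decoding step: append the sampled token with its logprobs dict,
    # then decode in place; this extends seq.output_text and fills in the
    # decoded_token field of each Logprob for that step.
    #   seq.append_token_id(new_token, logprobs)
    #   detokenizer.decode_sequence_inplace(seq, params)
    # Prompt logprobs for a whole sequence group are decoded in one call:
    #   detokenizer.decode_prompt_logprobs_inplace(seq_group, prompt_logprobs)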
66 changes: 9 additions & 57 deletions vllm/engine/llm_engine.py
@@ -1,5 +1,5 @@
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple, Type, Union
 
 from transformers import PreTrainedTokenizer

@@ -15,11 +15,11 @@
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
-from vllm.transformers_utils.tokenizer import detokenize_incrementally
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter

logger = init_logger(__name__)
@@ -97,6 +97,7 @@ def __init__(
         self._verify_args()
 
         self._init_tokenizer()
+        self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
 
         self.model_executor = executor_class(model_config, cache_config,
@@ -153,7 +154,7 @@ def __reduce__(self):
         raise RuntimeError("LLMEngine should not be pickled!")
 
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer()
+        return self.tokenizer.get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
@@ -370,13 +371,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
-            # We can pick any sequence for the prompt.
-            seq = next(iter(seq_group.seqs_dict.values()))
-            all_token_ids = seq.get_token_ids()
-            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
-                self._decode_logprobs(seq, seq_group.sampling_params,
-                                      prompt_logprobs_for_token,
-                                      all_token_ids[:i])
+            self.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
 
         # Process samples
@@ -420,7 +416,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            self._decode_sequence(seq, seq_group.sampling_params)
+            self.detokenizer.decode_sequence_inplace(seq,
+                                                     seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
 
         # Non-beam search case
@@ -713,51 +710,6 @@ def _get_stats(self,
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
-                         logprobs: Dict[int, Logprob],
-                         all_input_ids: List[int]) -> None:
-        if not logprobs:
-            return
-        for token_id, sample_logprob in logprobs.items():
-            if (sample_logprob.decoded_token is None and token_id != -1):
-                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                (_, new_text, prefix_offset,
-                 read_offset) = detokenize_incrementally(
-                     self.get_tokenizer_for_seq(seq),
-                     all_input_ids=all_input_ids_with_logprob,
-                     prev_tokens=seq.tokens,
-                     prefix_offset=seq.prefix_offset,
-                     read_offset=seq.read_offset,
-                     skip_special_tokens=prms.skip_special_tokens,
-                     spaces_between_special_tokens=prms.
-                     spaces_between_special_tokens,
-                 )
-                sample_logprob.decoded_token = new_text
-
-    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
-        """Decodes the new token for a sequence."""
-        all_input_ids = seq.get_token_ids()
-        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
-                              all_input_ids)
-
-        (new_tokens, new_output_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             self.get_tokenizer_for_seq(seq),
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
-        seq.prefix_offset = prefix_offset
-        seq.read_offset = read_offset
-        seq.output_text += new_output_text
-
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
(Diffs for the remaining changed files did not load on this page.)
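
The removed _decode_logprobs above is where the prefill cost was concentrated: decoding prompt logprobs issued one detokenize_incrementally call per candidate token at every prompt position, each over a growing token list. Its replacement, Detokenizer.decode_prompt_logprobs_inplace, is imported from vllm.transformers_utils.detokenizer, whose diff did not load on this page, so the sketch below is a hypothetical single-pass equivalent for intuition only, not the actual implementation:

    from typing import Dict, List, Optional

    from vllm.sequence import Logprob
    from vllm.transformers_utils.tokenizer import detokenize_incrementally

    def decode_prompt_logprobs_single_pass(
            tokenizer,
            prompt_token_ids: List[int],
            prompt_logprobs: List[Optional[Dict[int, Logprob]]],
            skip_special_tokens: bool = True) -> None:
        # Shared incremental state advances once per prompt position
        # instead of being rebuilt for every logprob entry.
        prev_tokens: Optional[List[str]] = None
        prefix_offset = 0
        read_offset = 0
        for pos, logprobs in enumerate(prompt_logprobs):
            if logprobs:
                for token_id, sample_logprob in logprobs.items():
                    if (sample_logprob.decoded_token is None
                            and token_id != -1):
                        # Decode as if token_id had appeared at this
                        # position, reusing the accumulated state.
                        ids = prompt_token_ids[:pos] + [token_id]
                        _, text, _, _ = detokenize_incrementally(
                            tokenizer,
                            ids,
                            prev_tokens=prev_tokens,
                            prefix_offset=prefix_offset,
                            read_offset=read_offset,
                            skip_special_tokens=skip_special_tokens)
                        sample_logprob.decoded_token = text
            # Advance the shared state by the real prompt token at pos.
            new_tokens, _, prefix_offset, read_offset = (
                detokenize_incrementally(
                    tokenizer,
                    prompt_token_ids[:pos + 1],
                    prev_tokens=prev_tokens,
                    prefix_offset=prefix_offset,
                    read_offset=read_offset,
                    skip_special_tokens=skip_special_tokens))
            prev_tokens = (new_tokens if prev_tokens is None else
                           prev_tokens + new_tokens)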
