[Core] [Frontend] Make detokenization optional #3749

Merged
32 changes: 32 additions & 0 deletions tests/engine/test_detokenization.py
@@ -0,0 +1,32 @@
import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_detokenize_option(model: str):
    # This test checks that the engine generates completions both with and
    # without optional detokenization, that only the detokenized completion
    # contains text while the non-detokenized one does not, and that both
    # completions have the same token_ids.
    prompt = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")

    llm = LLM(model=model)
    sampling_params = SamplingParams(max_tokens=10,
                                     temperature=0.0,
                                     detokenize=False)

    outputs_no_detokenization = llm.generate(prompt,
                                             sampling_params)[0].outputs[0]
    sampling_params.detokenize = True
    outputs_with_detokenization = llm.generate(prompt,
                                               sampling_params)[0].outputs[0]

    assert outputs_no_detokenization.text == ''
    assert outputs_with_detokenization.text != ''
    assert outputs_no_detokenization.token_ids == \
        outputs_with_detokenization.token_ids
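As a usage sketch (not part of the diff), the same flag can be exercised directly through the engine-level LLM entrypoint. Assuming the facebook/opt-125m model used in the test above, only token_ids are populated when detokenize=False; if text is needed later, it can be recovered outside the engine with an external tokenizer (the Hugging Face AutoTokenizer below is an illustrative assumption, not part of this PR):

from transformers import AutoTokenizer

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

# Generate token IDs only; the engine skips detokenization entirely.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False)
output = llm.generate("What is the capital of France?", params)[0].outputs[0]

assert output.text == ''      # no text is produced when detokenize=False
token_ids = output.token_ids  # raw completion token IDs are still returned

# Recover text outside the engine, if and when it is actually needed.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print(tokenizer.decode(token_ids, skip_special_tokens=True))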
20 changes: 11 additions & 9 deletions vllm/engine/llm_engine.py
@@ -415,7 +415,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -461,8 +461,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self.detokenizer.decode_sequence_inplace(seq,
-                                                     seq_group.sampling_params)
+            if seq_group.sampling_params.detokenize:
+                self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)

         # Non-beam search case
@@ -774,12 +775,13 @@ def _check_stop(self, seq: Sequence,
         if seq.get_output_len() < sampling_params.min_tokens:
             return

-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                seq.stop_reason = stop_str
-                return
+        if sampling_params.detokenize:
+            for stop_str in sampling_params.stop:
+                if seq.output_text.endswith(stop_str):
+                    self._finalize_sequence(seq, sampling_params, stop_str)
+                    seq.status = SequenceStatus.FINISHED_STOPPED
+                    seq.stop_reason = stop_str
+                    return
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
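To summarize the engine-side change: decode_prompt_logprobs_inplace, decode_sequence_inplace, and the stop-string check in _check_stop now run only when sampling_params.detokenize is set, while stopping on stop_token_ids keeps working because it compares raw token IDs rather than output text. A minimal sketch of the two configurations (the EOS token ID of 2 is an illustrative assumption for facebook/opt-125m):

from vllm.sampling_params import SamplingParams

# Stopping on raw token IDs still works with detokenization disabled,
# because _check_stop compares seq.get_last_token_id() against
# stop_token_ids rather than inspecting the output text.
params_token_stop = SamplingParams(
    max_tokens=64,
    detokenize=False,
    stop_token_ids=[2],  # assumed EOS token ID
)

# Stop strings match against seq.output_text, which is only produced when
# the sequence is detokenized, so they require detokenize=True.
params_string_stop = SamplingParams(
    max_tokens=64,
    detokenize=True,
    stop=["\n\n"],
)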
10 changes: 10 additions & 0 deletions vllm/sampling_params.py
@@ -88,6 +88,7 @@ class SamplingParams:
             log probability of the sampled token, so there may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
@@ -118,6 +119,7 @@ def __init__(
         min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
@@ -150,6 +152,10 @@ def __init__(
         self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors
@@ -210,6 +216,10 @@ def _verify_args(self) -> None:
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")

     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
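The added check in _verify_args rejects the incompatible combination of stop strings and detokenize=False at construction time. A small sketch of the expected behavior, reusing pytest as in the test file above:

import pytest

from vllm.sampling_params import SamplingParams

# Stop strings need the detokenized output text, so combining them with
# detokenize=False is rejected when the parameters are constructed.
with pytest.raises(ValueError):
    SamplingParams(stop=["###"], detokenize=False)

# Either keep detokenization on or stop on token IDs instead.
SamplingParams(stop=["###"], detokenize=True)
SamplingParams(stop_token_ids=[2], detokenize=False)  # assumed EOS ID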