[BugFix] Fix handling of stop strings and stop token ids (#3672)
njhill authored Apr 11, 2024

1 parent 1e96c33 commit e46a60a
Showing 8 changed files with 202 additions and 37 deletions.
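
For context, a minimal, hedged sketch (not part of the commit) of the user-facing behaviour this fix targets, using the offline LLM API with the same model, prompt, and parameters as the new engine tests below:

from vllm import LLM, SamplingParams

# Sketch only: model, prompt and parameters mirror the new tests in this commit.
llm = LLM(model="meta-llama/llama-2-7b-hf")

params = SamplingParams(
    temperature=0.0,
    max_tokens=200,
    stop=["."],                        # generation should stop at the first "."
    include_stop_str_in_output=False,  # and the "." should be trimmed from the text
)

(request_output, ) = llm.generate(["A story about vLLM:\n"], params)
completion = request_output.outputs[0]
print(completion.text)         # expected: text truncated just before the stop string
print(completion.stop_reason)  # expected: the matched stop string "."
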
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -401,7 +401,7 @@ def __del__(self):
cleanup()


-@pytest.fixture
+@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner

2 changes: 1 addition & 1 deletion tests/samplers/test_stop_reason.py → tests/engine/test_stop_reason.py
@@ -3,7 +3,7 @@
2. One of the provided stop tokens
3. The EOS token
-Run `pytest tests/samplers/test_stop_reason.py`.
+Run `pytest tests/engine/test_stop_reason.py`.
"""

import pytest
111 changes: 111 additions & 0 deletions tests/engine/test_stop_strings.py
@@ -0,0 +1,111 @@
from typing import Any, List, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200


@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
return vllm_runner(MODEL)


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")

_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")

_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")

_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)


def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)

output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs

# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason

assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
98 changes: 65 additions & 33 deletions vllm/engine/llm_engine.py
@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize:
-                self.detokenizer.decode_sequence_inplace(
+                new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
-            self._check_stop(seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self._check_stop(seq, new_char_count, seq_group.sampling_params)

# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
@@ -798,56 +800,86 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)

-    def _check_stop(self, seq: Sequence,
+    def _check_stop(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
"""Stop the finished sequences.
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return

-        if sampling_params.detokenize:
-            for stop_str in sampling_params.stop:
-                if seq.output_text.endswith(stop_str):
-                    self._finalize_sequence(seq, sampling_params, stop_str)
-                    seq.status = SequenceStatus.FINISHED_STOPPED
-                    seq.stop_reason = stop_str
-                    return
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return

# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                last_token_id)
-            self._finalize_sequence(seq, sampling_params, stop_str)
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return

-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
return

-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return

+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
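
To make the new control flow easier to follow, here is a standalone, hedged sketch (plain Python, not the engine code; the helper name and worked values are illustrative) of what _check_stop_strings does with new_char_count:

from typing import List, Optional, Tuple

def check_stop_strings(
        output_text: str, new_char_count: int, stop: List[str],
        include_stop_str_in_output: bool) -> Tuple[str, Optional[str]]:
    """Return (possibly truncated output text, matched stop string or None)."""
    if not new_char_count:
        return output_text, None
    for stop_str in stop:
        # Only the characters added by the newest token, plus a stop-string-length
        # window before them, can contain a new match; a negative start index to
        # str.find() restricts the search to that tail.
        stop_index = output_text.find(stop_str, -new_char_count - len(stop_str))
        if stop_index == -1:
            continue
        if include_stop_str_in_output:
            stop_index += len(stop_str)  # keep the stop string itself
        return output_text[:stop_index], stop_str
    return output_text, None

# E.g. the newest token decoded to ".", adding one character:
text, reason = check_stop_strings("VLLM is a 100% volunteer organization.",
                                  new_char_count=1, stop=["."],
                                  include_stop_str_in_output=False)
assert (text, reason) == ("VLLM is a 100% volunteer organization", ".")
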
4 changes: 3 additions & 1 deletion vllm/outputs.py
@@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
+        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
-            CompletionOutput(seqs.index(seq), seq.output_text,
+            CompletionOutput(seqs.index(seq),
+                             seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
9 changes: 9 additions & 0 deletions vllm/sampling_params.py
@@ -166,6 +166,13 @@ def __init__(
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0

self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -226,6 +233,8 @@ def _verify_args(self) -> None:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
6 changes: 6 additions & 0 deletions vllm/sequence.py
@@ -235,6 +235,12 @@ def __init__(
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0

+    def get_output_text_to_return(self, buffer_length: int):
+        # We return the full output text if the sequence is finished.
+        truncate = buffer_length and not self.is_finished()
+        return self.output_text[:-buffer_length] if truncate else (
+            self.output_text)

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
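
A short, hedged sketch of the hold-back behaviour added here (standalone function with illustrative values; with stop=["###"] the buffer length computed above would be 2):

def output_text_to_return(output_text: str, buffer_length: int,
                          finished: bool) -> str:
    # While the sequence is still running, withhold the last buffer_length
    # characters in case they are the start of a stop string.
    truncate = buffer_length and not finished
    return output_text[:-buffer_length] if truncate else output_text

assert output_text_to_return("The answer is ##", 2, finished=False) == "The answer is "
assert output_text_to_return("The answer is ##", 2, finished=True) == "The answer is ##"
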

7 changes: 6 additions & 1 deletion vllm/transformers_utils/detokenizer.py
@@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace(
prev_tokens.extend(next_iter_tokens)

def decode_sequence_inplace(self, seq: Sequence,
-                                prms: SamplingParams) -> None:
+                                prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
+        Returns:
+            The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
@@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence,
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text

+        return len(new_decoded_token_text)


def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
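
Finally, a small illustration (plain Python; the token and text values come from the new test above, the rest is assumed) of why the detokenizer now reports how many characters the latest token added:

output_text = "VLLM is a 100% volunteer"
new_decoded_token_text = " organization"  # e.g. token id 13013 in the test above

output_text += new_decoded_token_text
new_char_count = len(new_decoded_token_text)
assert new_char_count == 13

# If 13013 is in stop_token_ids and include_stop_str_in_output is False,
# _check_stop can strip exactly the characters that token contributed:
assert output_text[:-new_char_count] == "VLLM is a 100% volunteer"
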
