diff --git a/tests/conftest.py b/tests/conftest.py
index a7e8963af0eda..5c50fc2d1bab6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -401,7 +401,7 @@ def __del__(self):
         cleanup()
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def vllm_runner():
     return VllmRunner
 
diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py
similarity index 97%
rename from tests/samplers/test_stop_reason.py
rename to tests/engine/test_stop_reason.py
index b242c405a4fb6..b2f521a8ae4ce 100644
--- a/tests/samplers/test_stop_reason.py
+++ b/tests/engine/test_stop_reason.py
@@ -3,7 +3,7 @@
 2. One of the provided stop tokens
 3. The EOS token
 
-Run `pytest tests/samplers/test_stop_reason.py`.
+Run `pytest tests/engine/test_stop_reason.py`.
 """
 
 import pytest
diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py
new file mode 100644
index 0000000000000..6b747beb4b543
--- /dev/null
+++ b/tests/engine/test_stop_strings.py
@@ -0,0 +1,111 @@
+from typing import Any, List, Optional
+
+import pytest
+
+from vllm import CompletionOutput, LLMEngine, SamplingParams
+
+MODEL = "meta-llama/llama-2-7b-hf"
+MAX_TOKENS = 200
+
+
+@pytest.fixture(scope="session")
+def vllm_model(vllm_runner):
+    return vllm_runner(MODEL)
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_basic(vllm_model):
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer organization",
+                   expected_reason=".")
+
+    _test_stopping(vllm_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organization.",
+                   expected_reason=".")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_multi_tokens(vllm_model):
+    _test_stopping(
+        vllm_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=False,
+        expected_output="VLLM is a 100% volunteer organization. We are a ",
+        expected_reason="group of peo")
+
+    _test_stopping(
+        vllm_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=True,
+        expected_output=
+        "VLLM is a 100% volunteer organization. We are a group of peo",
We are a group of peo", + expected_reason="group of peo") + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani") + + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani") + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + # token id 13013 => " organization" + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013) + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013) + + +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ddfdda898a5c6..a91629a630591 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for seq, _ in child_seqs: if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( + new_char_count = self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) + else: + new_char_count = 0 + self._check_stop(seq, new_char_count, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: @@ -798,56 +800,86 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, + def _check_stop(self, seq: Sequence, new_char_count: int, sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + """Stop the finished sequences. - # Check if the sequence has reached max_tokens. 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ddfdda898a5c6..a91629a630591 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
 
         for seq, _ in child_seqs:
             if seq_group.sampling_params.detokenize:
-                self.detokenizer.decode_sequence_inplace(
+                new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
-            self._check_stop(seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self._check_stop(seq, new_char_count, seq_group.sampling_params)
 
         # Non-beam search case
         if not seq_group.sampling_params.use_beam_search:
@@ -798,56 +800,86 @@ def _get_stats(self,
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _check_stop(self, seq: Sequence,
+    def _check_stop(self, seq: Sequence, new_char_count: int,
                     sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences."""
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        """Stop the finished sequences.
 
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token
+        """
 
         # Check if the minimum number of tokens has been generated yet;
         # skip the stop string/token checks if not
         if seq.get_output_len() < sampling_params.min_tokens:
             return
 
-        if sampling_params.detokenize:
-            for stop_str in sampling_params.stop:
-                if seq.output_text.endswith(stop_str):
-                    self._finalize_sequence(seq, sampling_params, stop_str)
-                    seq.status = SequenceStatus.FINISHED_STOPPED
-                    seq.stop_reason = stop_str
-                    return
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
         if last_token_id in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                last_token_id)
-            self._finalize_sequence(seq, sampling_params, stop_str)
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = last_token_id
             return
 
-        # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos)
-                and seq.get_last_token_id() == seq.eos_token_id):
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
             return
 
-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 61fe20bfc2744..d01be0eb0efd2 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.
         include_logprobs = seq_group.sampling_params.logprobs is not None
+        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq), seq.output_text,
+            CompletionOutput(seqs.index(seq),
+                             seq.get_output_text_to_return(text_buffer_length),
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 4fdc3c6dedaef..0b9787608798c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -166,6 +166,13 @@ def __init__(
         self.logits_processors = logits_processors
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
+
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -226,6 +233,8 @@ def _verify_args(self) -> None:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 77029908c2218..cdb6cce6f0255 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -235,6 +235,12 @@ def __init__(
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    def get_output_text_to_return(self, buffer_length: int):
+        # We return the full output text if the sequence is finished.
+        truncate = buffer_length and not self.is_finished()
+        return self.output_text[:-buffer_length] if truncate else (
+            self.output_text)
+
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
 
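
The buffer added in `SamplingParams.__init__` is sized at `max(len(s) for s in stop) - 1` because that is the most characters a still-incomplete stop string can occupy; had a full stop string already appeared, the sequence would have been stopped on that step. A standalone sketch of how the sizing interacts with `Sequence.get_output_text_to_return`, not part of the patch; the values are illustrative:

```python
stop = ["group of peo", "short"]
include_stop_str_in_output = False

# Same sizing rule as SamplingParams.__init__: one less than the longest stop
# string, or zero when nothing needs to be hidden.
if stop and not include_stop_str_in_output:
    output_text_buffer_length = max(len(s) for s in stop) - 1
else:
    output_text_buffer_length = 0
assert output_text_buffer_length == 11

# Mid-generation, the tail that might still grow into "group of peo" is
# withheld from RequestOutput, mirroring get_output_text_to_return.
output_text = "We are a group of pe"   # sequence not finished yet
visible = output_text[:-output_text_buffer_length]
assert visible == "We are a "

# Once the sequence finishes, the full (already truncated) text is returned.
```
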
""" all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] @@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence, seq.read_offset = read_offset seq.output_text += new_decoded_token_text + return len(new_decoded_token_text) + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],