diff --git a/tests/samplers/test_stop_reason.py b/tests/samplers/test_stop_reason.py
new file mode 100644
index 0000000000000..b242c405a4fb6
--- /dev/null
+++ b/tests/samplers/test_stop_reason.py
@@ -0,0 +1,59 @@
+"""Test the different finish_reason="stop" situations during generation:
+    1. One of the provided stop strings
+    2. One of the provided stop tokens
+    3. The EOS token
+
+Run `pytest tests/samplers/test_stop_reason.py`.
+"""
+
+import pytest
+import transformers
+
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-350m"
+STOP_STR = "."
+SEED = 42
+MAX_TOKENS = 1024
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL)
+    yield vllm_model
+    del vllm_model
+
+
+def test_stop_reason(vllm_model, example_prompts):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
+    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
+    llm = vllm_model.model
+
+    # test stop token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED,
+                               max_tokens=MAX_TOKENS,
+                               stop_token_ids=[stop_token_id]))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == stop_token_id
+
+    # test stop string
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS, stop="."))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == STOP_STR
+
+    # test EOS token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "length" or (
+            output.finish_reason == "stop" and output.stop_reason is None)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 144829739f681..1c688397b1f4d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -740,12 +740,15 @@ def _check_stop(self, seq: Sequence,
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return
 
         # Check if the sequence has generated the EOS token.
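
With the engine change above, a finished sequence records what stopped it, and the value is surfaced on each CompletionOutput. A minimal offline sketch, not part of the patch; the model, prompt, and stop settings are illustrative assumptions:

# Illustrative sketch only, not part of the patch; model, prompt, and stop
# settings are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-350m")
params = SamplingParams(max_tokens=64, stop=["."])
outputs = llm.generate(["The capital of France is"], params)

completion = outputs[0].outputs[0]
print(completion.finish_reason)  # "stop" if a stop condition fired, else "length"
print(completion.stop_reason)    # "." if the stop string matched; a token id for
                                 # stop_token_ids; None for EOS or max_tokens
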
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 965313e29f8d4..af52f543d5411 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionResponse(BaseModel):
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionStreamResponse(BaseModel):
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 0de80f04e51f3..0980c3d3cb614 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -220,7 +220,8 @@ async def chat_completion_stream_generator(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason)
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
@@ -278,6 +279,7 @@ async def chat_completion_full_generator(
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 9d5319c857109..ff435e4385885 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -266,6 +266,7 @@ async def completion_stream_generator(
                     previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
                     if output.finish_reason is not None:  # return final usage
                         prompt_tokens = len(res.prompt_token_ids)
                         completion_tokens = len(output.token_ids)
@@ -286,6 +287,7 @@ async def completion_stream_generator(
                                 text=delta_text,
                                 logprobs=logprobs,
                                 finish_reason=finish_reason,
+                                stop_reason=stop_reason,
                             )
                         ],
                         usage=final_usage,
@@ -342,6 +344,7 @@ def request_output_to_completion_response(
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
 
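
On the OpenAI-compatible server, the new field is returned next to finish_reason in every choice, for both chat and text completions. A hedged client sketch over plain HTTP; the host, port, and model name are assumptions about a locally started server, not mandated by the patch:

# Assumes a local server, e.g.:
#   python -m vllm.entrypoints.openai.api_server --model facebook/opt-350m
# Host, port, prompt, and model are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-350m",
        "prompt": "Hello, my name is",
        "max_tokens": 64,
        "stop": ["."],
    },
)
choice = resp.json()["choices"][0]
print(choice["finish_reason"])  # "stop" or "length"
print(choice["stop_reason"])    # matched stop string, stop token id, or None (EOS)
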
diff --git a/vllm/outputs.py b/vllm/outputs.py
index accc18ad41aa8..2a955419352a8 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """
 
@@ -29,6 +32,7 @@ def __init__(
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ def __init__(
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request
 
     def finished(self) -> bool:
@@ -48,7 +53,8 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")
 
 
 class RequestOutput:
@@ -111,8 +117,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status))
-            for seq in top_n_seqs
+                             SequenceStatus.get_finished_reason(seq.status),
+                             seq.stop_reason) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
diff --git a/vllm/sequence.py b/vllm/sequence.py
index bf3679c312ddc..b019b5bf5802c 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -183,6 +183,7 @@ def __init__(
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None
 
         # Used for incremental detokenization
         self.prefix_offset = 0
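
Because stop_reason is typed Union[int, str, None], callers can branch on the value's type to tell the three cases apart. A small helper sketch; the function name and wording are illustrative, not part of the patch:

# Illustrative helper, not part of the patch.
from vllm import CompletionOutput


def describe_stop(output: CompletionOutput) -> str:
    # str  -> a provided stop string matched the generated text.
    # int  -> one of the provided stop_token_ids was generated.
    # None -> the EOS token, or a non-stop finish such as max_tokens.
    if output.finish_reason == "length":
        return "stopped by max_tokens"
    if output.stop_reason is None:
        return "stopped by the EOS token"
    if isinstance(output.stop_reason, int):
        return f"stopped by stop token id {output.stop_reason}"
    return f"stopped by stop string {output.stop_reason!r}"
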