From 24dd6daedb9de34f076ea77a3ae803bb17674d23 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 21 Feb 2024 17:53:03 -0800 Subject: [PATCH 1/3] Include matched stop string/token in responses Currently a finish_reason of "stop" is returned if any of the following are encountered: - One of the provided stop strings - One of the provided stop tokens - The EOS token It can be useful to know specifically which of these caused the sequence generation to stop, especially since by default the stop strings/tokens are omitted from the output text (and output token_ids?). This PR adds a "stop_reason" field to the CompletionOutput class which will contain the matched stop string or integer token id. It will be None otherwise, including the EOS token case. This means in particular that EOS can be inferred by (finish_reason=="stop" and stop_reason=None). I've also added to the openai server responses but not sure whether or not this should be included since it isn't part of the official API. --- vllm/engine/llm_engine.py | 7 +++++-- vllm/entrypoints/openai/protocol.py | 4 ++++ vllm/entrypoints/openai/serving_chat.py | 4 +++- vllm/entrypoints/openai/serving_completion.py | 3 +++ vllm/outputs.py | 14 ++++++++++---- vllm/sequence.py | 1 + 6 files changed, 26 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1984b94024a16..5e37ee2ff5a74 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -735,12 +735,15 @@ def _check_stop(self, seq: Sequence, if seq.output_text.endswith(stop_str): self._finalize_sequence(seq, sampling_params, stop_str) seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str return - if seq.get_last_token_id() in sampling_params.stop_token_ids: + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - seq.get_last_token_id()) + last_token_id) self._finalize_sequence(seq, sampling_params, stop_str) seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id return # Check if the sequence has generated the EOS token. diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 965313e29f8d4..4061c2eab33b0 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -338,6 +338,7 @@ class CompletionResponseChoice(BaseModel): text: str logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None + stop_reason: Union[None, int, str] = None class CompletionResponse(BaseModel): @@ -354,6 +355,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None + stop_reason: Union[None, int, str] = None class CompletionStreamResponse(BaseModel): @@ -375,6 +377,7 @@ class ChatCompletionResponseChoice(BaseModel): message: ChatMessage logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None + stop_reason: Union[None, int, str] = None class ChatCompletionResponse(BaseModel): @@ -396,6 +399,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): delta: DeltaMessage logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None + stop_reason: Union[None, int, str] = None class ChatCompletionStreamResponse(BaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0de80f04e51f3..0980c3d3cb614 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -220,7 +220,8 @@ async def chat_completion_stream_generator( index=i, delta=DeltaMessage(content=delta_text), logprobs=logprobs, - finish_reason=output.finish_reason) + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, @@ -278,6 +279,7 @@ async def chat_completion_full_generator( message=ChatMessage(role=role, content=output.text), logprobs=logprobs, finish_reason=output.finish_reason, + stop_reason=output.stop_reason, ) choices.append(choice_data) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9d5319c857109..ff435e4385885 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -266,6 +266,7 @@ async def completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) finish_reason = output.finish_reason + stop_reason = output.stop_reason if output.finish_reason is not None: # return final usage prompt_tokens = len(res.prompt_token_ids) completion_tokens = len(output.token_ids) @@ -286,6 +287,7 @@ async def completion_stream_generator( text=delta_text, logprobs=logprobs, finish_reason=finish_reason, + stop_reason=stop_reason, ) ], usage=final_usage, @@ -342,6 +344,7 @@ def request_output_to_completion_response( text=output_text, logprobs=logprobs, finish_reason=output.finish_reason, + stop_reason=output.stop_reason, ) choices.append(choice_data) diff --git a/vllm/outputs.py b/vllm/outputs.py index accc18ad41aa8..2a955419352a8 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,5 +1,5 @@ import time -from typing import List, Optional +from typing import List, Optional, Union from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, @@ -18,6 +18,9 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + stop_reason: The stop string or token id that caused the completion + to stop, None if the completion finished for some other reason + including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. """ @@ -29,6 +32,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + stop_reason: Union[int, str, None] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index @@ -37,6 +41,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.stop_reason = stop_reason self.lora_request = lora_request def finished(self) -> bool: @@ -48,7 +53,8 @@ def __repr__(self) -> str: f"token_ids={self.token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason})") + f"finish_reason={self.finish_reason}, " + f"stop_reason={self.stop_reason})") class RequestOutput: @@ -111,8 +117,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status)) - for seq in top_n_seqs + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason) for seq in top_n_seqs ] # Every sequence in the sequence group should have the same prompt. diff --git a/vllm/sequence.py b/vllm/sequence.py index 8b2855daa5525..4b4162928d9b9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -183,6 +183,7 @@ def __init__( # Initialize the logical token blocks with the prompt token ids. self._append_tokens_to_blocks(prompt_token_ids) self.status = SequenceStatus.WAITING + self.stop_reason: Union[int, str, None] = None # Used for incremental detokenization self.prefix_offset = 0 From 59073e32178677292165fadcef21c5821e687c4b Mon Sep 17 00:00:00 2001 From: Sahil Suneja Date: Mon, 26 Feb 2024 23:57:15 +0000 Subject: [PATCH 2/3] test for stop_reason --- tests/samplers/test_stop_reason.py | 59 ++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/samplers/test_stop_reason.py diff --git a/tests/samplers/test_stop_reason.py b/tests/samplers/test_stop_reason.py new file mode 100644 index 0000000000000..b242c405a4fb6 --- /dev/null +++ b/tests/samplers/test_stop_reason.py @@ -0,0 +1,59 @@ +"""Test the different finish_reason="stop" situations during generation: + 1. One of the provided stop strings + 2. One of the provided stop tokens + 3. The EOS token + +Run `pytest tests/samplers/test_stop_reason.py`. +""" + +import pytest +import transformers + +from vllm import SamplingParams + +MODEL = "facebook/opt-350m" +STOP_STR = "." +SEED = 42 +MAX_TOKENS = 1024 + + +@pytest.fixture +def vllm_model(vllm_runner): + vllm_model = vllm_runner(MODEL) + yield vllm_model + del vllm_model + + +def test_stop_reason(vllm_model, example_prompts): + tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL) + stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR) + llm = vllm_model.model + + # test stop token + outputs = llm.generate(example_prompts, + sampling_params=SamplingParams( + seed=SEED, + max_tokens=MAX_TOKENS, + stop_token_ids=[stop_token_id])) + for output in outputs: + output = output.outputs[0] + assert output.finish_reason == "stop" + assert output.stop_reason == stop_token_id + + # test stop string + outputs = llm.generate(example_prompts, + sampling_params=SamplingParams( + seed=SEED, max_tokens=MAX_TOKENS, stop=".")) + for output in outputs: + output = output.outputs[0] + assert output.finish_reason == "stop" + assert output.stop_reason == STOP_STR + + # test EOS token + outputs = llm.generate(example_prompts, + sampling_params=SamplingParams( + seed=SEED, max_tokens=MAX_TOKENS)) + for output in outputs: + output = output.outputs[0] + assert output.finish_reason == "length" or ( + output.finish_reason == "stop" and output.stop_reason is None) From 00a8f71418eacff9ce16f6e75fe5c72cc9bb0798 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 25 Mar 2024 16:38:45 -0700 Subject: [PATCH 3/3] Add pydantic descriptions to the new openai response fields --- vllm/entrypoints/openai/protocol.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4061c2eab33b0..af52f543d5411 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -338,7 +338,13 @@ class CompletionResponseChoice(BaseModel): text: str logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + stop_reason: Union[None, int, str] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) class CompletionResponse(BaseModel): @@ -355,7 +361,13 @@ class CompletionResponseStreamChoice(BaseModel): text: str logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + stop_reason: Union[None, int, str] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) class CompletionStreamResponse(BaseModel):