Include matched stop string/token in responses
Currently, a finish_reason of "stop" is returned if any of the following is encountered:
- One of the provided stop strings
- One of the provided stop tokens
- The EOS token

It can be useful to know specifically which of these caused generation to stop, especially since the stop strings/tokens are omitted from the output text by default (and from the output token_ids?).

This PR adds a "stop_reason" field to the CompletionOutput class, which will contain the matched stop string or integer token id. It will be None otherwise, including in the EOS token case, so the EOS case can be identified by (finish_reason == "stop" and stop_reason is None).
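
For illustration, here is roughly how a caller could tell the three cases apart once this lands. This is only a sketch: the model name, prompt, stop string, and stop token id below are arbitrary placeholders, not part of this change.

    from vllm import LLM, SamplingParams

    # Placeholder model and stop values, chosen only for illustration.
    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(max_tokens=64, stop=["\n\n"], stop_token_ids=[128])

    for request_output in llm.generate(["Hello, my name is"], params):
        for completion in request_output.outputs:
            if completion.finish_reason != "stop":
                # Generation hit max_tokens (finish_reason == "length").
                print("stopped by length limit")
            elif completion.stop_reason is None:
                # finish_reason == "stop" with no stop_reason means the EOS token.
                print("stopped on the EOS token")
            elif isinstance(completion.stop_reason, str):
                print(f"stopped on stop string {completion.stop_reason!r}")
            else:
                print(f"stopped on stop token id {completion.stop_reason}")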

I've also added the field to the OpenAI server responses, but I'm not sure whether it should be included there since it isn't part of the official OpenAI API.
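
For reference, a sketch of how the extra field would surface to a client of the OpenAI-compatible server. It assumes a server already running locally on the default port; the model name and prompt are placeholders.

    import requests

    # Assumes the server was started with something like:
    #   python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            "prompt": "Hello, my name is",
            "max_tokens": 64,
            "stop": ["\n\n"],
        },
    )
    choice = resp.json()["choices"][0]
    print(choice["finish_reason"])    # "stop" or "length", as before
    print(choice.get("stop_reason"))  # matched stop string / token id, or None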
njhill committed Feb 26, 2024
1 parent cfc15a1 commit bb6f831
Showing 6 changed files with 26 additions and 6 deletions.
7 changes: 5 additions & 2 deletions vllm/engine/llm_engine.py
@@ -933,12 +933,15 @@ def _check_stop(self, seq: Sequence,
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return

         # Check if the sequence has reached max_model_len.
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -187,6 +187,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class CompletionResponse(BaseModel):
@@ -203,6 +204,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class CompletionStreamResponse(BaseModel):
@@ -224,6 +226,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionResponse(BaseModel):
@@ -245,6 +248,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionStreamResponse(BaseModel):
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -194,7 +194,8 @@ async def chat_completion_stream_generator(
                     index=i,
                     delta=DeltaMessage(content=delta_text),
                     logprobs=logprobs,
-                    finish_reason=output.finish_reason)
+                    finish_reason=output.finish_reason,
+                    stop_reason=output.stop_reason)
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
                     object=chunk_object_type,
@@ -248,6 +249,7 @@ async def chat_completion_full_generator(
             message=ChatMessage(role=role, content=output.text),
             logprobs=logprobs,
             finish_reason=output.finish_reason,
+            stop_reason=output.stop_reason,
         )
         choices.append(choice_data)

4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -84,6 +84,7 @@ async def completion_stream_generator(
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 finish_reason = output.finish_reason
+                stop_reason = output.stop_reason
                 response_json = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
@@ -94,6 +95,7 @@
                             text=delta_text,
                             logprobs=logprobs,
                             finish_reason=finish_reason,
+                            stop_reason=stop_reason,
                         )
                     ]).model_dump_json(exclude_unset=True)
                 yield f"data: {response_json}\n\n"
@@ -117,6 +119,7 @@
                         text="",
                         logprobs=logprobs,
                         finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason,
                     )
                 ],
                 usage=final_usage,
@@ -195,6 +198,7 @@ def request_output_to_completion_response(
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)

12 changes: 9 additions & 3 deletions vllm/outputs.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 import time

 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """

@@ -29,6 +32,7 @@ def __init__(
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ def __init__(
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request

     def finished(self) -> bool:
@@ -48,7 +53,8 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")


 class RequestOutput:
@@ -111,7 +117,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
                                       seq.get_cumulative_logprob(), logprobs,
-                                      finshed_reason)
+                                      finshed_reason, seq.stop_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
1 change: 1 addition & 0 deletions vllm/sequence.py
@@ -150,6 +150,7 @@ def __init__(
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None

         # Used for incremental detokenization
         self.prefix_offset = 0
