[Misc] Include matched stop string/token in responses #2976

Merged · 3 commits · Mar 26, 2024
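
This change adds a `stop_reason` field next to `finish_reason` on every returned choice: the matched stop string (a str), the matched stop token id (an int), or None when generation ended for another reason such as the EOS token or the length limit. A minimal sketch of how this surfaces through the offline Python API; the prompt is a placeholder and the model is the same one used by the new test:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-350m")
params = SamplingParams(max_tokens=64, stop=["."])

# generate() returns one RequestOutput per prompt; each holds CompletionOutputs.
output = llm.generate(["The capital of France is"], params)[0].outputs[0]
print(output.finish_reason)  # "stop" if the stop string was matched
print(output.stop_reason)    # "." when the stop string ended generation, else None
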
59 changes: 59 additions & 0 deletions tests/samplers/test_stop_reason.py
@@ -0,0 +1,59 @@
"""Test the different finish_reason="stop" situations during generation:
1. One of the provided stop strings
2. One of the provided stop tokens
3. The EOS token

Run `pytest tests/samplers/test_stop_reason.py`.
"""

import pytest
import transformers

from vllm import SamplingParams

MODEL = "facebook/opt-350m"
STOP_STR = "."
SEED = 42
MAX_TOKENS = 1024


@pytest.fixture
def vllm_model(vllm_runner):
vllm_model = vllm_runner(MODEL)
yield vllm_model
del vllm_model


def test_stop_reason(vllm_model, example_prompts):
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
llm = vllm_model.model

# test stop token
outputs = llm.generate(example_prompts,
sampling_params=SamplingParams(
seed=SEED,
max_tokens=MAX_TOKENS,
stop_token_ids=[stop_token_id]))
for output in outputs:
output = output.outputs[0]
assert output.finish_reason == "stop"
assert output.stop_reason == stop_token_id

# test stop string
outputs = llm.generate(example_prompts,
sampling_params=SamplingParams(
seed=SEED, max_tokens=MAX_TOKENS, stop="."))
for output in outputs:
output = output.outputs[0]
assert output.finish_reason == "stop"
assert output.stop_reason == STOP_STR

# test EOS token
outputs = llm.generate(example_prompts,
sampling_params=SamplingParams(
seed=SEED, max_tokens=MAX_TOKENS))
for output in outputs:
output = output.outputs[0]
assert output.finish_reason == "length" or (
output.finish_reason == "stop" and output.stop_reason is None)
7 changes: 5 additions & 2 deletions vllm/engine/llm_engine.py
@@ -735,12 +735,15 @@ def _check_stop(self, seq: Sequence,
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return

         # Check if the sequence has generated the EOS token.
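
With the engine change above, seq.stop_reason is an int when a stop token id fires and a str when a stop string matches, while EOS-terminated and length-capped sequences leave it as None. A small illustrative helper (hypothetical, not part of this PR) showing how a caller can tell the cases apart:

def describe_stop(output):
    """Summarize why a CompletionOutput finished, using the new stop_reason field."""
    if output.finish_reason != "stop":
        return f"finished: {output.finish_reason}"  # e.g. "length"
    if output.stop_reason is None:
        return "stopped on the EOS token"
    if isinstance(output.stop_reason, int):
        return f"stopped on stop token id {output.stop_reason}"
    return f"stopped on stop string {output.stop_reason!r}"
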
16 changes: 16 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionResponse(BaseModel):
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionStreamResponse(BaseModel):
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionResponse(BaseModel):
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionStreamResponse(BaseModel):
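
The same field is exposed on each choice of the OpenAI-compatible completion and chat completion responses. A rough sketch of reading it over HTTP, assuming a vLLM OpenAI-compatible server is already running locally on the default port (the model name and prompt are placeholders):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-350m",  # whatever model the server was started with
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "stop": ["."],
    },
)
choice = resp.json()["choices"][0]
print(choice["finish_reason"])  # "stop" or "length"
print(choice["stop_reason"])    # "." here; an int for stop_token_ids; None (JSON null) otherwise
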
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -220,7 +220,8 @@ async def chat_completion_stream_generator(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason)
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
@@ -278,6 +279,7 @@ async def chat_completion_full_generator(
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)

3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -266,6 +266,7 @@ async def completion_stream_generator(
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                     finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
                     if output.finish_reason is not None:  # return final usage
                         prompt_tokens = len(res.prompt_token_ids)
                         completion_tokens = len(output.token_ids)
@@ -286,6 +287,7 @@
                             text=delta_text,
                             logprobs=logprobs,
                             finish_reason=finish_reason,
+                            stop_reason=stop_reason,
                         )
                     ],
                     usage=final_usage,
@@ -342,6 +344,7 @@ def request_output_to_completion_response(
             text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
+            stop_reason=output.stop_reason,
         )
         choices.append(choice_data)

14 changes: 10 additions & 4 deletions vllm/outputs.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union

 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """

@@ -29,6 +32,7 @@ def __init__(
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ def __init__(
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request

     def finished(self) -> bool:
@@ -48,7 +53,8 @@ def __repr__(self) -> str:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")


 class RequestOutput:
@@ -111,8 +117,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
                 seq.get_output_token_ids(),
                 seq.get_cumulative_logprob(),
                 seq.output_logprobs if include_logprobs else None,
-                SequenceStatus.get_finished_reason(seq.status))
-            for seq in top_n_seqs
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason) for seq in top_n_seqs
         ]

         # Every sequence in the sequence group should have the same prompt.
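
For completeness, the token-id case as seen offline, mirroring the new test: the matched id is reported as an int rather than the decoded string. A short sketch under the same placeholder model and prompt as above:

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
period_id = tokenizer.convert_tokens_to_ids(".")

llm = LLM(model="facebook/opt-350m")
out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=64,
                                  stop_token_ids=[period_id]))[0].outputs[0]
print(out)              # the repr now includes finish_reason and stop_reason
print(out.stop_reason)  # the matched token id (an int), e.g. period_id
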
1 change: 1 addition & 0 deletions vllm/sequence.py
@@ -183,6 +183,7 @@ def __init__(
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None

         # Used for incremental detokenization
         self.prefix_offset = 0