forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Misc] Include matched stop string/token in responses (vllm-project#2976
) Co-authored-by: Sahil Suneja <[email protected]>
- Loading branch information
1 parent
19d7628
commit bc3ea46
Showing
7 changed files
with
97 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Test the different finish_reason="stop" situations during generation: | ||
1. One of the provided stop strings | ||
2. One of the provided stop tokens | ||
3. The EOS token | ||
Run `pytest tests/samplers/test_stop_reason.py`. | ||
""" | ||
|
||
import pytest | ||
import transformers | ||
|
||
from vllm import SamplingParams | ||
|
||
MODEL = "facebook/opt-350m" | ||
STOP_STR = "." | ||
SEED = 42 | ||
MAX_TOKENS = 1024 | ||
|
||
|
||
@pytest.fixture | ||
def vllm_model(vllm_runner): | ||
vllm_model = vllm_runner(MODEL) | ||
yield vllm_model | ||
del vllm_model | ||
|
||
|
||
def test_stop_reason(vllm_model, example_prompts): | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL) | ||
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR) | ||
llm = vllm_model.model | ||
|
||
# test stop token | ||
outputs = llm.generate(example_prompts, | ||
sampling_params=SamplingParams( | ||
seed=SEED, | ||
max_tokens=MAX_TOKENS, | ||
stop_token_ids=[stop_token_id])) | ||
for output in outputs: | ||
output = output.outputs[0] | ||
assert output.finish_reason == "stop" | ||
assert output.stop_reason == stop_token_id | ||
|
||
# test stop string | ||
outputs = llm.generate(example_prompts, | ||
sampling_params=SamplingParams( | ||
seed=SEED, max_tokens=MAX_TOKENS, stop=".")) | ||
for output in outputs: | ||
output = output.outputs[0] | ||
assert output.finish_reason == "stop" | ||
assert output.stop_reason == STOP_STR | ||
|
||
# test EOS token | ||
outputs = llm.generate(example_prompts, | ||
sampling_params=SamplingParams( | ||
seed=SEED, max_tokens=MAX_TOKENS)) | ||
for output in outputs: | ||
output = output.outputs[0] | ||
assert output.finish_reason == "length" or ( | ||
output.finish_reason == "stop" and output.stop_reason is None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters