forked from mesolitica/vllm-whisper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Adding token ranks along with logprobs (vllm-project#3516)
Co-authored-by: Swapnil Parekh <[email protected]>
- Loading branch information
1 parent
ae5c800
commit ca906f8
Showing
3 changed files
with
98 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
from vllm import SamplingParams | ||
|
||
MODELS = ["facebook/opt-125m"] | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
@pytest.mark.parametrize("dtype", ["half"]) | ||
def test_ranks( | ||
vllm_runner, | ||
model, | ||
dtype, | ||
example_prompts, | ||
): | ||
max_tokens = 5 | ||
num_top_logprobs = 5 | ||
num_prompt_logprobs = 5 | ||
|
||
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) | ||
|
||
## Test greedy logprobs ranks | ||
vllm_sampling_params = SamplingParams(temperature=0.0, | ||
top_p=1.0, | ||
max_tokens=max_tokens, | ||
logprobs=num_top_logprobs, | ||
prompt_logprobs=num_prompt_logprobs) | ||
vllm_results = vllm_model.generate_w_logprobs(example_prompts, | ||
vllm_sampling_params) | ||
for result in vllm_results: | ||
assert result[2] is not None | ||
assert len(result[2]) == len(result[0]) | ||
# check whether all chosen tokens have ranks = 1 | ||
for token, logprobs in zip(result[0], result[2]): | ||
assert token in logprobs | ||
assert logprobs[token].rank == 1 | ||
|
||
## Test non-greedy logprobs ranks | ||
sampling_params = SamplingParams(temperature=1.0, | ||
top_p=1.0, | ||
max_tokens=max_tokens, | ||
logprobs=num_top_logprobs, | ||
prompt_logprobs=num_prompt_logprobs) | ||
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) | ||
for result in res: | ||
assert result[2] is not None | ||
assert len(result[2]) == len(result[0]) | ||
# check whether all chosen tokens have ranks | ||
for token, logprobs in zip(result[0], result[2]): | ||
assert logprobs[token].rank >= 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters