Skip to content

Commit

Permalink
[Core] Adding token ranks along with logprobs (vllm-project#3516)
Browse files Browse the repository at this point in the history
Co-authored-by: Swapnil Parekh <[email protected]>
  • Loading branch information
SwapnilDreams100 and Swapnil Parekh authored Mar 25, 2024
1 parent ae5c800 commit ca906f8
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 12 deletions.
49 changes: 49 additions & 0 deletions tests/samplers/test_ranks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
vllm_runner,
model,
dtype,
example_prompts,
):
max_tokens = 5
num_top_logprobs = 5
num_prompt_logprobs = 5

vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)

## Test greedy logprobs ranks
vllm_sampling_params = SamplingParams(temperature=0.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
vllm_results = vllm_model.generate_w_logprobs(example_prompts,
vllm_sampling_params)
for result in vllm_results:
assert result[2] is not None
assert len(result[2]) == len(result[0])
# check whether all chosen tokens have ranks = 1
for token, logprobs in zip(result[0], result[2]):
assert token in logprobs
assert logprobs[token].rank == 1

## Test non-greedy logprobs ranks
sampling_params = SamplingParams(temperature=1.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
for result in res:
assert result[2] is not None
assert len(result[2]) == len(result[0])
# check whether all chosen tokens have ranks
for token, logprobs in zip(result[0], result[2]):
assert logprobs[token].rank >= 1
50 changes: 40 additions & 10 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,24 @@ def _sample(
# sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
"""
This function calculates the ranks of the chosen tokens in a logprob tensor.
Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim.
indices (List[int]): List of chosen token indices.
Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
"""
vals = x[range(len(x)), indices]
return (x > vals[:, None]).long().sum(1) + 1


def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
Expand Down Expand Up @@ -520,6 +538,10 @@ def _get_logprobs(

batched_logprobs_query_result = batched_logprobs_query_result.cpu()

batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices],
batched_logprobs_query_token_indices)

# Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
result_sample_logprobs: List[SampleLogprobs] = []
Expand All @@ -540,15 +562,20 @@ def _get_logprobs(
for token_id in prompt_tokens[1:]:
prompt_logprobs_dict = {
token_id:
batched_logprobs_query_result[query_result_idx].item()
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
}
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
top_logprobs[sample_idx, :num_logprobs].tolist()))
zip(
top_token_ids[sample_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
sample_idx, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_prompt_logprobs.append({
token_id: Logprob(logprob)
for token_id, logprob in prompt_logprobs_dict.items()
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in prompt_logprobs_dict.items()
})
sample_idx += 1
query_result_idx += 1
Expand All @@ -564,19 +591,22 @@ def _get_logprobs(
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
sample_logprobs_dict = {
next_token_id:
batched_logprobs_query_result[query_result_idx].item()
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
}
query_result_idx += 1
if num_logprobs > 0:
sample_logprobs_dict.update(
zip(
top_token_ids[sample_idx +
parent_id, :num_logprobs].tolist(),
top_logprobs[sample_idx +
parent_id, :num_logprobs].tolist()))
zip(
top_logprobs[sample_idx +
parent_id, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_sample_logprobs.append({
token_id: Logprob(logprob)
for token_id, logprob in sample_logprobs_dict.items()
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in sample_logprobs_dict.items()
})
result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids)
Expand Down
11 changes: 9 additions & 2 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@

@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs."""
"""Infos for supporting OpenAI compatible logprobs and token ranks.
Attributes:
logprob: The logprob of chosen token
rank: The vocab rank of chosen token (>=1)
decoded_token: The decoded chosen token index
"""
logprob: float
rank: Optional[int] = None
decoded_token: Optional[str] = None


Expand Down Expand Up @@ -66,7 +73,7 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
class RequestMetrics:
"""Metrics associated with a request.
Args:
Attributes:
arrival_time: The time when the request arrived.
first_scheduled_time: The time when the request was first scheduled.
first_token_time: The time when the first token was generated.
Expand Down

0 comments on commit ca906f8

Please sign in to comment.