From ca906f8620ec26467f3ea8022d9077ffe28d1dcf Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Mon, 25 Mar 2024 13:13:10 -0400 Subject: [PATCH] [Core] Adding token ranks along with logprobs (#3516) Co-authored-by: Swapnil Parekh --- tests/samplers/test_ranks.py | 49 ++++++++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 50 +++++++++++++++++++++------ vllm/sequence.py | 11 ++++-- 3 files changed, 98 insertions(+), 12 deletions(-) create mode 100644 tests/samplers/test_ranks.py diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py new file mode 100644 index 0000000000000..7f6f1c0093154 --- /dev/null +++ b/tests/samplers/test_ranks.py @@ -0,0 +1,49 @@ +import pytest +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_ranks( + vllm_runner, + model, + dtype, + example_prompts, +): + max_tokens = 5 + num_top_logprobs = 5 + num_prompt_logprobs = 5 + + vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) + + ## Test greedy logprobs ranks + vllm_sampling_params = SamplingParams(temperature=0.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) + vllm_results = vllm_model.generate_w_logprobs(example_prompts, + vllm_sampling_params) + for result in vllm_results: + assert result[2] is not None + assert len(result[2]) == len(result[0]) + # check whether all chosen tokens have ranks = 1 + for token, logprobs in zip(result[0], result[2]): + assert token in logprobs + assert logprobs[token].rank == 1 + + ## Test non-greedy logprobs ranks + sampling_params = SamplingParams(temperature=1.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) + res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) + for result in res: + assert result[2] is not None + assert len(result[2]) == len(result[0]) + # check whether all chosen tokens have ranks + for token, logprobs in zip(result[0], result[2]): + assert logprobs[token].rank >= 1 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 84b2125c0b09c..162d2abb292aa 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -465,6 +465,24 @@ def _sample( # sampling_tensors) +def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor: + """ + This function calculates the ranks of the chosen tokens in a logprob tensor. + + Args: + x (torch.Tensor): 2D logprob tensor of shape (N, M) + where N is the no. of tokens and M is the vocab dim. + indices (List[int]): List of chosen token indices. + + Returns: + torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. + Each element in the returned tensor represents the rank + of the chosen token in the input logprob tensor. 
+    """
+    vals = x[range(len(x)), indices]
+    return (x > vals[:, None]).long().sum(1) + 1
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -520,6 +538,10 @@ def _get_logprobs(
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices],
+        batched_logprobs_query_token_indices)
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
@@ -540,15 +562,20 @@ def _get_logprobs(
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
                     token_id:
-                    batched_logprobs_query_result[query_result_idx].item()
+                    (batched_logprobs_query_result[query_result_idx].item(),
+                     batched_ranks_query_result[query_result_idx].item())
                 }
                 if num_logprobs > 0:
                     prompt_logprobs_dict.update(
-                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                        zip(
+                            top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            zip(
+                                top_logprobs[
+                                    sample_idx, :num_logprobs].tolist(),
+                                range(1, num_logprobs + 1))))
                 group_prompt_logprobs.append({
-                    token_id: Logprob(logprob)
-                    for token_id, logprob in prompt_logprobs_dict.items()
+                    token_id: Logprob(*logprob_rank)
+                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                 })
                 sample_idx += 1
                 query_result_idx += 1
@@ -564,7 +591,8 @@ def _get_logprobs(
         for next_token_id, parent_id in zip(next_token_ids, parent_ids):
             sample_logprobs_dict = {
                 next_token_id:
-                batched_logprobs_query_result[query_result_idx].item()
+                (batched_logprobs_query_result[query_result_idx].item(),
+                 batched_ranks_query_result[query_result_idx].item())
             }
             query_result_idx += 1
             if num_logprobs > 0:
@@ -572,11 +600,13 @@
                     zip(
                         top_token_ids[sample_idx +
                                       parent_id, :num_logprobs].tolist(),
-                        top_logprobs[sample_idx +
-                                     parent_id, :num_logprobs].tolist()))
+                        zip(
+                            top_logprobs[sample_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            range(1, num_logprobs + 1))))
             group_sample_logprobs.append({
-                token_id: Logprob(logprob)
-                for token_id, logprob in sample_logprobs_dict.items()
+                token_id: Logprob(*logprob_rank)
+                for token_id, logprob_rank in sample_logprobs_dict.items()
             })
         result_sample_logprobs.append(group_sample_logprobs)
         sample_idx += len(seq_ids)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 72f16579c83c6..8b2855daa5525 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -16,8 +16,15 @@
 
 @dataclass
 class Logprob:
-    """Infos for supporting OpenAI compatible logprobs."""
+    """Info for supporting OpenAI-compatible logprobs and token ranks.
+
+    Attributes:
+        logprob: The logprob of the chosen token.
+        rank: The vocab rank of the chosen token (>=1).
+        decoded_token: The decoded chosen token, if available.
+    """
     logprob: float
+    rank: Optional[int] = None
     decoded_token: Optional[str] = None
 
 
@@ -66,7 +73,7 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
 class RequestMetrics:
     """Metrics associated with a request.
 
-    Args:
+    Attributes:
         arrival_time: The time when the request arrived.
         first_scheduled_time: The time when the request was first scheduled.
         first_token_time: The time when the first token was generated.
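
Note (illustrative, not part of the patch): the rank produced by _get_ranks is 1 plus the number of vocabulary entries whose logprob is strictly greater than the chosen token's, so the argmax token always gets rank 1, which is exactly what test_ranks asserts for temperature=0.0. The sketch below checks that arithmetic in isolation; it assumes only torch is installed, and the local get_ranks copy and the toy logprob values are made up for illustration.

    import torch
    from typing import List


    def get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
        # Same arithmetic as the patched _get_ranks: count how many vocab
        # entries beat the chosen token's logprob, then add 1.
        vals = x[range(len(x)), indices]
        return (x > vals[:, None]).long().sum(1) + 1


    # Two positions over a toy 4-token vocab (values chosen for illustration).
    logprobs = torch.log_softmax(
        torch.tensor([[2.0, 0.5, 0.1, -1.0],
                      [0.0, 3.0, 1.0, 2.0]]), dim=-1)
    print(get_ranks(logprobs, [0, 3]))  # tensor([1, 2]) -> argmax gets rank 1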
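
Note (illustrative, not part of the patch): in _get_logprobs each per-token value becomes a (logprob, rank) tuple, and Logprob(*logprob_rank) unpacks it positionally, which is why the new rank field sits before decoded_token in the dataclass. A minimal sketch of that behaviour, using a local stand-in for the dataclass so it runs without importing vllm; the example tuple value is made up.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Logprob:  # stand-in mirroring vllm/sequence.py after this patch
        logprob: float
        rank: Optional[int] = None
        decoded_token: Optional[str] = None


    # A (logprob, rank) pair, as built from the batched query results above.
    logprob_rank = (-0.31, 1)
    entry = Logprob(*logprob_rank)
    print(entry)  # Logprob(logprob=-0.31, rank=1, decoded_token=None)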