Skip to content

Commit

Permalink
[Core] *Prompt* logprobs support in Multi-step (vllm-project#8199)
Browse files Browse the repository at this point in the history
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
afeldman-nm authored and sumitd2 committed Nov 14, 2024
1 parent 25c1fc7 commit 9640ac5
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 59 deletions.
84 changes: 52 additions & 32 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from tests.models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
Expand All @@ -33,7 +35,6 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_cpu)

Expand Down Expand Up @@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit(
audios: Optional[PromptAudioInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
) -> List[TokensTextLogprobs]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
Expand Down Expand Up @@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
) -> List[TokensTextLogprobs]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
Expand Down Expand Up @@ -653,14 +654,16 @@ def generate(
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
) -> List[TokensTextLogprobsPromptLogprobs]:
outputs: List[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
outputs.append((output_ids, output_str, output_logprobs,
req_output.prompt_logprobs))
return outputs

def generate_w_logprobs(
Expand All @@ -670,7 +673,8 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
assert sampling_params.logprobs is not None

if images is not None:
Expand All @@ -695,21 +699,33 @@ def generate_w_logprobs(

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)

toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)

def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
'''
Logprobs generation for vLLM encoder/decoder models
'''

assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)

def generate_greedy(
self,
Expand All @@ -727,44 +743,48 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
stop_token_ids=stop_token_ids)

return self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)

def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs)
num_prompt_logprobs: Optional[int] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
)
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''

outputs = self.generate_encoder_decoder_w_logprobs(
return self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]

def generate_beam_search(
self,
prompts: List[str],
Expand Down
108 changes: 102 additions & 6 deletions tests/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.sequence import Logprob, SampleLogprobs
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs

TokensText = Tuple[List[int], str]

Expand Down Expand Up @@ -34,20 +34,47 @@ def check_outputs_equal(
assert output_ids_0 == output_ids_1, fail_msg


# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]

# Allow for tokens to be represented as str's rather than IDs
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]

# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]


def check_logprobs_close(
*,
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
Expand All @@ -57,6 +84,18 @@ def check_logprobs_close(
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
How sample logprobs are compared:
* `always_check_logprobs == True`: set of highest-logprob token ids
must match between seq0 and seq1 at all sampled token offsets
* `always_check_logprobs == False`: highest-logprob token ids are
only compared at sampled token offsets for which generated token
ids don't match
Prompt logprobs must be provided either for both input sequences, or
for neither. If prompt logprobs are provided, then highest-logprob
prompt token ids must match between seq0 and seq1 at all prompt token
offsets.
Args:
outputs_0_lst: First sequence to compare
outputs_0_lst: Second sequence to compare
Expand All @@ -78,8 +117,65 @@ def check_logprobs_close(
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
assert len(outputs_0) == len(outputs_1)
if len(outputs_0) == 3:
assert len(outputs_1) == 3
# Break out tokens, text & sample logprobs
# (prompt logprobs were not provided)
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
elif len(outputs_0) == 4:
assert len(outputs_1) == 4
# Break out tokens, text, sample logprobs & prompt logprobs
(
output_ids_0,
output_str_0,
logprobs_0,
prompt_logprobs_0,
) = outputs_0
(
output_ids_1,
output_str_1,
logprobs_1,
prompt_logprobs_1,
) = outputs_1

# Test prompt logprobs closeness
if (prompt_logprobs_0 is not None
and prompt_logprobs_1 is not None):
# Both sequences' prompt logprobs lists are not `None``
# (although individual list elements may be `None`);
# for each token's logprobs:
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
zip(prompt_logprobs_0, prompt_logprobs_1)):
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")

if logprobs_elem_0 is None:
# If the seq 0 token's logprobs are `None`,
# the seq 1 token's logprobs must be `None`
assert logprobs_elem_1 is None, fail_msg
else:
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
else:
# Both sequence logprobs lists must be `None`
fail_msg = (f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")

assert (prompt_logprobs_0 is None
and prompt_logprobs_1 is None), fail_msg
else:
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}")

if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
Expand Down
Loading

0 comments on commit 9640ac5

Please sign in to comment.