
[Core][Bugfix] Support prompt_logprobs returned with speculative decoding #8047

Merged
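For context, here is a minimal sketch of the user-facing behavior this PR enables: requesting prompt logprobs together with speculative decoding. Model names and parameter values are illustrative, and the snippet assumes a vLLM build that includes this change.

```python
from vllm import LLM, SamplingParams

# Illustrative main/draft pair; any models supported by spec decode work.
llm = LLM(
    model="JackFram/llama-160m",             # main model (illustrative)
    speculative_model="JackFram/llama-68m",  # draft model (illustrative)
    num_speculative_tokens=3,
    use_v2_block_manager=True,               # required for spec decode
)

# Before this change, requesting prompt_logprobs alongside speculative
# decoding was unsupported; now both sampled and prompt logprobs come back.
params = SamplingParams(temperature=0.0,
                        max_tokens=16,
                        logprobs=1,
                        prompt_logprobs=1)

for out in llm.generate(["Hello, my name is"], params):
    print(out.prompt_logprobs)      # one dict per prompt token (first entry is None)
    print(out.outputs[0].logprobs)  # one dict per generated token
```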
4 changes: 1 addition & 3 deletions tests/conftest.py
@@ -675,8 +675,6 @@ def generate_w_logprobs(
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
-assert sampling_params.logprobs is not None
-
if images is not None:
assert len(prompts) == len(images)

@@ -754,7 +752,7 @@ def generate_greedy_logprobs(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
-prompt_logprobs=(num_prompt_logprobs),
+prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids)

return self.generate_w_logprobs(prompts,
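The deleted assertion had forced every caller of `generate_w_logprobs` to request sampled logprobs; with it removed, prompt-logprobs-only runs are accepted. A small sketch of the now-valid sampling parameters, under the assumption that the test runner forwards them unchanged:

```python
from vllm import SamplingParams

# logprobs may stay None while prompt_logprobs is set; the removed assert
# would previously have rejected this combination in the test runner.
params = SamplingParams(temperature=0.0, max_tokens=8, prompt_logprobs=5)
assert params.logprobs is None and params.prompt_logprobs == 5
```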
139 changes: 92 additions & 47 deletions tests/spec_decode/e2e/conftest.py
@@ -1,13 +1,16 @@
from itertools import cycle
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union

import pytest

from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...conftest import cleanup
-from ...models.utils import check_logprobs_close, check_outputs_equal
+from ...models.utils import (TokensTextLogprobs,
+                             TokensTextLogprobsPromptLogprobs,
+                             check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
return tokens, token_ids, acceptance_rate


-def run_logprob_correctness_test(vllm_runner,
-                                 common_llm_kwargs,
-                                 per_test_common_llm_kwargs,
-                                 baseline_llm_kwargs,
-                                 test_llm_kwargs,
-                                 batch_size: int,
-                                 max_output_len: int,
-                                 seed: Optional[int] = 0,
-                                 temperature: float = 0.0,
-                                 logprobs: int = 1):
-    org_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **baseline_llm_kwargs,
-    }
-
-    sd_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **test_llm_kwargs,
-    }
-
-    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-    sampling_params = SamplingParams(temperature=temperature,
-                                     max_tokens=max_output_len,
-                                     seed=seed,
-                                     logprobs=logprobs)
-
-    with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    with vllm_runner(**sd_args) as vllm_model:
-        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    check_logprobs_close(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
-                         name_0="org",
-                         name_1="sd")
+def check_logprobs_correctness(
+    spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                 TokensTextLogprobsPromptLogprobs]],
+    baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                     TokensTextLogprobsPromptLogprobs]],
+    disable_logprobs: bool = False,
+):
+    """Compare sampled and prompt logprobs between baseline and spec decoding
+    """
+    if not disable_logprobs:
+        return check_logprobs_close(
+            outputs_0_lst=baseline_outputs,
+            outputs_1_lst=spec_outputs,
+            name_0="org",
+            name_1="sd",
+        )
+
+    # Check correctness when disable_logprobs == True
+    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+        # Check generated token logprobs.
+        spec_logprobs = spec_output[2]
+        baseline_logprobs = baseline_output[2]
+        _check_logprobs_when_output_disabled(spec_logprobs,
+                                             baseline_logprobs,
+                                             is_prompt_logprobs=False)
+
+        # Check prompt logprobs too, if they exist
+        if len(baseline_output) == 4:
+            assert len(spec_output) == 4
+            spec_prompt_logprobs = spec_output[3]
+            baseline_prompt_logprobs = baseline_output[3]
+            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                 baseline_prompt_logprobs,
+                                                 is_prompt_logprobs=True)
+
+
+def _check_logprobs_when_output_disabled(
+    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    is_prompt_logprobs: bool = False,
+):
+    # Prompt logprobs are optional
+    if is_prompt_logprobs and baseline_logprobs is None:
+        assert spec_logprobs is None
+        return
+
+    assert spec_logprobs is not None
+    assert baseline_logprobs is not None
+    assert len(spec_logprobs) == len(baseline_logprobs)
+
+    # For each generated position of the sequence.
+    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+            zip(spec_logprobs, baseline_logprobs)):
+
+        # First prompt logprob is expected to be None
+        if is_prompt_logprobs and baseline_pos_logprobs is None:
+            assert spec_pos_logprobs is None
+            assert pos == 0
+            continue
+
+        assert spec_pos_logprobs is not None
+        assert baseline_pos_logprobs is not None
+
+        # When disabled, exactly one logprob is returned per position, with
+        # dummy values for the score and rank, but the token id should match
+        # the baseline model
+        assert len(spec_pos_logprobs) == 1
+        (spec_pos_logprob_token_id,
+         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+        assert spec_pos_logprob.rank == -1
+        assert spec_pos_logprob.logprob == 0.0
+        assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
-expected_acceptance_rate: Optional[float] = None):
+expected_acceptance_rate: Optional[float] = None,
+logprobs: Optional[int] = None,
+prompt_logprobs: Optional[int] = None,
+disable_logprobs: bool = False):

org_args = {
**common_llm_kwargs,
Expand All @@ -157,10 +195,12 @@ def run_equality_correctness_test(
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
-ignore_eos=ignore_eos)
+ignore_eos=ignore_eos,
+logprobs=logprobs,
+prompt_logprobs=prompt_logprobs)

with vllm_runner(**org_args) as vllm_model:
-org_outputs = vllm_model.generate(prompts, sampling_params)
+org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@
'prometheus']
stat_logger.local_interval = -100

-sd_outputs = vllm_model.generate(prompts, sampling_params)
+sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2

-check_outputs_equal(outputs_0_lst=org_outputs,
-                    outputs_1_lst=sd_outputs,
+# Only pass token entries, not the logprobs
+check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                    outputs_1_lst=[out[0:2] for out in sd_outputs],
                     name_0="org",
                     name_1="sd")

+# Check logprobs if requested
+if logprobs is not None or prompt_logprobs is not None:
+    check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)


def run_equality_correctness_test_tp(model,
common_llm_kwargs,
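To make the disabled-logprobs contract concrete, here is a hedged illustration of the per-position shape that `_check_logprobs_when_output_disabled` accepts, built by hand from `Logprob` in `vllm.sequence` (the module the diff imports from); the token id 42 is synthetic, not real model output.

```python
from vllm.sequence import Logprob

# With disable_logprobs_during_spec_decoding=True, each position carries
# exactly one entry whose score and rank are dummies; only the token id
# is meaningful and must match a token the baseline model also returned.
disabled_pos_logprobs = {42: Logprob(logprob=0.0, rank=-1, decoded_token=None)}

(token_id, lp), = disabled_pos_logprobs.items()
assert lp.rank == -1 and lp.logprob == 0.0  # dummy values, as the helper asserts
```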
58 changes: 58 additions & 0 deletions tests/spec_decode/e2e/test_eagle_correctness.py
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size, output_len, seed)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                   per_test_common_llm_kwargs,
+                                   baseline_llm_kwargs, test_llm_kwargs,
+                                   batch_size: int, output_len: int, seed: int,
+                                   logprobs: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
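As a hypothetical quick check, the new test can be selected through pytest's Python API (equivalent to passing `-k` on the command line); this assumes a GPU environment with the EAGLE draft and main models available.

```python
import pytest

# Run only the new logprobs test from this file; exits nonzero on failure.
raise SystemExit(pytest.main([
    "tests/spec_decode/e2e/test_eagle_correctness.py",
    "-k", "test_eagle_e2e_greedy_logprobs",
]))
```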