test: update spec_decode e2e tests
Include cases where disable_logprobs is True.

Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 committed Sep 16, 2024
1 parent f54e8bc commit 5a4a674
Showing 3 changed files with 207 additions and 92 deletions.
141 changes: 94 additions & 47 deletions tests/spec_decode/e2e/conftest.py
@@ -1,13 +1,16 @@
from itertools import cycle
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple, Union

import pytest

from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...conftest import cleanup
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
@@ -81,45 +84,79 @@ def get_output_from_llm_generator(
return tokens, token_ids, acceptance_rate


def run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
logprobs: int = 1):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}

sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
logprobs=logprobs)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

check_logprobs_close(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
name_0="org",
name_1="sd")
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)

# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# assert len(baseline_output) >= 3
# assert len(spec_output) >= 3
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)

# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)


def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return

assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)

# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):

# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue

assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None

# When disabled, a single logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
assert spec_pos_logprob_token_id in baseline_pos_logprobs
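
For illustration, the contract enforced here is: with logprobs disabled, every position carries exactly one entry whose token id matches the baseline but whose score and rank are placeholders (0.0 and -1). The sketch below builds such inputs by hand and runs the check; it assumes vllm.sequence.Logprob accepts logprob and rank keyword arguments, and the token ids are made up.

from vllm.sequence import Logprob

# Baseline positions with real scores and ranks (token ids are arbitrary).
baseline = [{11: Logprob(logprob=-0.3, rank=1)},
            {42: Logprob(logprob=-1.2, rank=2)}]
# Spec-decode positions with logprobs disabled: same token ids, dummy values.
spec = [{11: Logprob(logprob=0.0, rank=-1)},
        {42: Logprob(logprob=0.0, rank=-1)}]

_check_logprobs_when_output_disabled(spec, baseline, is_prompt_logprobs=False)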


def run_equality_correctness_test(
@@ -135,7 +172,10 @@ def run_equality_correctness_test(
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):

org_args = {
**common_llm_kwargs,
@@ -157,10 +197,12 @@ def run_equality_correctness_test(
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate(prompts, sampling_params)
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +211,7 @@ def run_equality_correctness_test(
'prometheus']
stat_logger.local_interval = -100

sd_outputs = vllm_model.generate(prompts, sampling_params)
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
@@ -185,11 +227,16 @@ def run_equality_correctness_test(
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2

check_outputs_equal(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")

# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
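
Taken together, one call to the extended helper now exercises both the token-equality check and the logprob check. A hypothetical invocation is sketched below; the model names and kwargs are placeholders that mirror the tests in test_logprobs.py rather than values prescribed by this change.

run_equality_correctness_test(
    vllm_runner,
    common_llm_kwargs={"model_name": "JackFram/llama-160m",
                       "enforce_eager": True},
    per_test_common_llm_kwargs={},
    baseline_llm_kwargs={},
    test_llm_kwargs={"speculative_model": "JackFram/llama-68m",
                     "num_speculative_tokens": 3,
                     "disable_logprobs_during_spec_decoding": True},
    batch_size=8,
    max_output_len=32,
    seed=0,
    temperature=0.0,
    logprobs=1,
    prompt_logprobs=1,
    disable_logprobs=True,
)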


def run_equality_correctness_test_tp(model,
common_llm_kwargs,
101 changes: 57 additions & 44 deletions tests/spec_decode/e2e/test_logprobs.py
@@ -4,13 +4,13 @@

from vllm import SamplingParams

from .conftest import run_logprob_correctness_test
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -22,9 +22,13 @@
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@@ -41,22 +45,25 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
@@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
@@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])