[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
cadedaniel authored Apr 23, 2024
1 parent 050f285 commit 62b8aeb
Showing 22 changed files with 1,164 additions and 175 deletions.
8 changes: 6 additions & 2 deletions tests/samplers/test_rejection_sampler.py
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)

# Bonus tokens are currently disabled. Verify they're set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1

if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, bonus_token_ids), dim=-1)
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),
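
For context, a minimal sketch (not part of this diff) of what the new expected_bonus_token_ids expression evaluates to: cloning the integer tensor, multiplying by zero, and subtracting one yields a tensor of the same shape filled with -1, which is the disabled-bonus-token value the assertions above check for. Shapes below are illustrative only.

import torch

# Illustrative shapes; the real test derives these from its parametrization.
batch_size = 4
bonus_token_ids = torch.randint(low=0, high=100, size=(batch_size, 1))

# Bonus tokens are disabled, so every bonus slot is expected to hold -1.
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
assert torch.equal(expected_bonus_token_ids,
                   torch.full_like(bonus_token_ids, -1))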
3 changes: 2 additions & 1 deletion tests/samplers/test_sampler.py
@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)

with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
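
For context, a hedged sketch (not part of this diff) of the shape change above: the patched _sample is now expected to return a 2-tuple, with the per-sequence results first and a second element that this mock simply leaves as None. The standalone function below only mirrors the mock's structure and makes no claim about what vLLM places in the second slot.

import torch

def mock_sample(probs, *args, **kwargs):
    # One entry per sequence, mirroring the list-of-two-lists structure
    # used by the mock above; the second tuple element is None here.
    sample_results = [[prob.topk(1, dim=-1).indices.tolist(), [0]]
                      for prob in probs]
    return sample_results, None

fake_probs = torch.softmax(torch.randn(2, 8), dim=-1)
results, extra = mock_sample(fake_probs)
assert extra is None and len(results) == 2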
Empty file.
45 changes: 35 additions & 10 deletions tests/spec_decode/e2e/conftest.py
@@ -1,3 +1,5 @@
from typing import List, Tuple

import pytest

from tests.conftest import cleanup
@@ -6,28 +8,34 @@


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def baseline_llm_generator(request, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
seed):
return create_llm_generator("baseline", request, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed)
return create_llm_generator("test", request, common_llm_kwargs,
per_test_common_llm_kwargs, test_llm_kwargs,
seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
distinct_llm_kwargs, seed):
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
}
test_name = request.node.name

def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs)

set_random_seed(seed)
@@ -36,6 +44,23 @@ def generator_inner():
del llm
cleanup()

for llm in generator_inner():
yield llm
def generator_outer():
for llm in generator_inner():
yield llm
del llm

return generator_outer


def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
tokens = []
token_ids = []
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
del llm

return tokens, token_ids
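
As a usage note (not part of this diff), a hypothetical test body showing how the fixtures and helper above compose; the test name is invented for illustration, and greedy sampling (temperature=0.0) is assumed so the speculative run can be compared token-for-token against the baseline.

from vllm import SamplingParams

def test_spec_matches_baseline(baseline_llm_generator, test_llm_generator):
    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(max_tokens=32,
                                     ignore_eos=True,
                                     temperature=0.0)

    _, baseline_token_ids = get_output_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)
    _, spec_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # With temperature=0.0 both runs should be deterministic, so the
    # speculative token ids are expected to match the baseline exactly.
    assert spec_token_ids == baseline_token_ids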
169 changes: 169 additions & 0 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -0,0 +1,169 @@
import pytest

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)


@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
"""Verify that speculative decoding with block manager v1 fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

with pytest.raises(ValueError,
match="Speculative decoding requires usage of the V2"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
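
For contrast with the xfail cases above, a hedged sketch (not part of this diff) of a configuration these tests imply should be accepted: V2 block manager enabled, no Ray workers, no chunked prefill, and no oversized speculative_max_model_len. Model names and flags are copied from the kwargs used in the tests above.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    # Required for spec decode, per the tests above.
    use_v2_block_manager=True,
)

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=128, ignore_eos=True, temperature=0.0),
)
print(outputs[0].outputs[0].text)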