Commit 756d33f
fix ngram tests; remove sd+cp outdated compat check
NickLucche committed Nov 4, 2024
1 parent ca2691e commit 756d33f
Showing 3 changed files with 18 additions and 42 deletions.
34 changes: 0 additions & 34 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -5,40 +5,6 @@
 from .conftest import get_output_from_llm_generator


-@pytest.mark.parametrize("common_llm_kwargs", [{
-    "model": "JackFram/llama-68m",
-    "speculative_model": "JackFram/llama-68m",
-    "num_speculative_tokens": 5,
-}])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "enable_chunked_prefill": True,
-    },
-])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
-    """Verify that speculative decoding with chunked prefill fails.
-    """
-    output_len = 128
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-    ]
-
-    sampling_params = SamplingParams(
-        max_tokens=output_len,
-        ignore_eos=True,
-        temperature=temperature,
-    )
-
-    with pytest.raises(ValueError,
-                       match="Speculative decoding and chunked prefill"):
-        get_output_from_llm_generator(test_llm_generator, prompts,
-                                      sampling_params)
-
-
 @pytest.mark.parametrize("common_llm_kwargs", [{
     "model": "meta-llama/Llama-2-7b-chat-hf",
     "speculative_model": "JackFram/llama-68m",
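The deleted test pinned down the old incompatibility: constructing an engine with both speculative decoding and chunked prefill had to raise ValueError. As a minimal sketch (not part of this commit, and assuming a vLLM build that includes it), the same configuration is now expected to construct and generate normally; the kwargs are taken verbatim from the deleted test:

# Sketch: the configuration the deleted test asserted would fail,
# now expected to work on a build that includes this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=128, ignore_eos=True, temperature=0.0),
)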
17 changes: 11 additions & 6 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -49,29 +49,34 @@
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
-        "enable_chunked_prefill": False,
     },
     {
         "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
         "ngram_prompt_lookup_max": 3,
-        "enable_chunked_prefill": True,
         "speculative_disable_mqa_scorer": True,
-        "max_num_batched_tokens": 4,
-        "max_num_seqs": 4
     },
 ])
 @pytest.mark.parametrize("output_len", [
     256,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
 @pytest.mark.parametrize("seed", [1])
 def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                       per_test_common_llm_kwargs,
                                       baseline_llm_kwargs, test_llm_kwargs,
                                       batch_size: int, output_len: int,
-                                      seed: int):
+                                      prefill_chunk_size: int, seed: int):
     """Verify greedy equality on a tiny model with different batch size."""
+    if prefill_chunk_size > 0:
+        common_llm_kwargs.update(
+            **{
+                "enable_chunked_prefill": True,
+                "max_num_batched_tokens": prefill_chunk_size,
+                "max_num_seqs": prefill_chunk_size
+            })
+    else:
+        common_llm_kwargs["enable_chunked_prefill"] = False
 run_equality_correctness_test(vllm_runner,
                               common_llm_kwargs,
                               per_test_common_llm_kwargs,
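The new branch is the crux of this file's change: a positive prefill_chunk_size enables chunked prefill and caps both the per-step token budget and the sequence count at that size, while the sentinel -1 leaves it disabled. A sketch of the same toggle factored into a reusable helper (hypothetical name, not part of this commit):

# Hypothetical helper (name assumed) factoring out the branch above so
# other spec-decode tests could reuse the same chunked-prefill toggle.
def maybe_enable_chunked_prefill(prefill_chunk_size: int,
                                 llm_kwargs: dict) -> None:
    if prefill_chunk_size > 0:
        # The chunk size doubles as the per-step token budget and the
        # maximum number of concurrently scheduled sequences.
        llm_kwargs.update(enable_chunked_prefill=True,
                          max_num_batched_tokens=prefill_chunk_size,
                          max_num_seqs=prefill_chunk_size)
    else:
        # -1 is the sentinel for "run without chunked prefill".
        llm_kwargs["enable_chunked_prefill"] = False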
9 changes: 7 additions & 2 deletions tests/spec_decode/test_ngram_worker.py
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
         num_gpu_blocks,
         block_size,
         final_prompt_lens=final_prompt_lens)
-
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
 def test_ngram_algo_correctness_for_batches_match_all():
     """Verify our ngram algo find the right candidate in the prompt
-    For the scenario find candidate in all batchs
+    For the scenario find candidate in all batches
     """

     block_size = 32
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
         block_size,
         final_prompt_lens=final_prompt_lens)

+    # Normally the drafter is run on decode requests only; here we check the
+    # output of the ngram worker directly, as it is the only proposer that
+    # needs no forward pass.
+    for sg in seq_group_metadata_list:
+        sg.is_prompt = False
     proposals = proposer.get_spec_proposals(
         execute_model_req=ExecuteModelRequest(
             seq_group_metadata_list=seq_group_metadata_list,
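Both tests flip is_prompt because, as the added comment notes, proposers draft only for decode-phase sequences. A toy sketch of that gating (standalone illustration with made-up names, not vLLM code):

from dataclasses import dataclass


@dataclass
class FakeSeqGroupMetadata:
    # True while the prompt is still being prefilled, False once the
    # sequence is producing new tokens (decode phase).
    is_prompt: bool


def should_draft(sg: FakeSeqGroupMetadata) -> bool:
    # Drafting speculative tokens only makes sense during decode.
    return not sg.is_prompt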
