[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker #5348

Merged

51 commits
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
bbf1484
Integrate Typical Acceptance Sampler into spec decode worker
sroy745 Jun 7, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
3495673
Fixing tests
sroy745 Jun 9, 2024
26c7c57
adding missing commit
sroy745 Jun 10, 2024
090f0bf
reverting changes to conftest
sroy745 Jun 10, 2024
733cc6e
reverting changes to conftest
sroy745 Jun 10, 2024
19ca0c9
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 10, 2024
acf8d2c
Dummy commit
sroy745 Jun 10, 2024
2d2b02b
Merge branch 'spec_decode_integrate_accpetance_sampler' of https://gi…
sroy745 Jun 10, 2024
2010b35
Revert unnecessary commits
sroy745 Jun 10, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
7fa64b6
Merge remote-tracking branch 'origin/main' into spec_decode_integrate…
sroy745 Jun 10, 2024
dea6fbd
Pass only one sampler which can either be the RejectionSampler of the…
sroy745 Jun 10, 2024
c3383db
Fix test scripture
sroy745 Jun 10, 2024
b15abba
Fix tests
sroy745 Jun 11, 2024
6ca731c
Fix tests
sroy745 Jun 11, 2024
483c671
Pass only 1 verification_sampler which can either be rejectionSampler…
sroy745 Jun 11, 2024
2c6d06c
Update metrics.py to take the base sampler class
sroy745 Jun 11, 2024
027b485
Fix tests and comments
sroy745 Jun 11, 2024
ded92ac
Fix test fixture and default values of args
sroy745 Jun 11, 2024
738871e
Small misc fixes
sroy745 Jun 11, 2024
50e8771
Fix spec_decode/test_metrics.py
sroy745 Jun 11, 2024
101611e
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 11, 2024
5e6638b
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 25, 2024
cc760a0
Make rejection_sampler.py and typical_acceptance_sampler.py implement…
sroy745 Jun 25, 2024
360ce0b
Raise exception instead of returning None for invalid sampler name
sroy745 Jun 25, 2024
6572ba4
Adding log about type of sampler
sroy745 Jun 25, 2024
be85f07
Misc comment fixes
sroy745 Jun 26, 2024
6dc9efe
Misc fixes
sroy745 Jun 26, 2024
512fad9
Misc fixes
sroy745 Jun 26, 2024
b1d510c
Misc fixes
sroy745 Jun 26, 2024
f4b9e4d
Misc fixes
sroy745 Jun 26, 2024
0ea9408
Documentation
sroy745 Jun 26, 2024
5772d04
Fix comments
sroy745 Jun 26, 2024
b7254e7
Fix arg name
sroy745 Jun 26, 2024
ef93081
Fixing a test
sroy745 Jun 26, 2024
0165842
Fix comment
sroy745 Jun 26, 2024
510974b
Fix formatting
sroy745 Jun 26, 2024
396fa54
Fixing tests and lint failures
sroy745 Jun 26, 2024
f8cc895
Removing e2e test for TypicalAcceptanceSampler from test_ngram_correc…
sroy745 Jun 27, 2024
439117d
Fix a comment
sroy745 Jun 27, 2024
75f034f
Dummy commit
sroy745 Jun 27, 2024
a0f5ade
Merge pull request #2 from vllm-project/main
sroy745 Jun 27, 2024
3082255
Fix format error
sroy745 Jun 27, 2024
4e7f51a
Merge pull request #3 from vllm-project/main
sroy745 Jun 28, 2024
d26c624
Dummy fix
sroy745 Jun 29, 2024
98d5f92
Merge branch 'main' into spec_decode_integrate_accpetance_sampler
sroy745 Jun 29, 2024
f186844
Update test_multistep_correctness.py
sroy745 Jun 29, 2024
51 changes: 39 additions & 12 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -11,9 +11,15 @@
 numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality. This gives us good coverage of temp=0.
 
+At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
+highest probability in the target distribution are accepted. Therefore, we can
+expect greedy equality for the TypicalAcceptanceSampler at temp=0.
+
 For temp>0, we rely on unit tests on the rejection sampler to verify that the
 output distribution is the same with spec decode vs. no spec decode (this would
-be prohibitively expensive to run with a real model).
+be prohibitively expensive to run with a real model). Similarly, for the
+TypicalAcceptanceSampler, we rely on unit tests to validate the temp>0
+cases.
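For intuition, here is a minimal sketch of the Medusa-style typical-acceptance rule the docstring describes. The threshold form and the default values are assumptions for illustration, not a quote of vLLM's implementation:

```python
import torch

def accepts(target_probs: torch.Tensor, draft_token: int,
            posterior_threshold: float = 0.09,
            posterior_alpha: float = 0.3) -> bool:
    # Entropy of the target distribution at this position.
    entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5))
    # Accept the draft token if its probability under the target model
    # clears an entropy-scaled threshold.
    threshold = min(posterior_threshold,
                    posterior_alpha * torch.exp(-entropy).item())
    return target_probs[draft_token].item() > threshold

# At temp=0 the target distribution is (near) one-hot, so only the argmax
# token clears the threshold -- hence greedy equality.
one_hot = torch.tensor([0.0, 1.0, 0.0, 0.0])
assert accepts(one_hot, draft_token=1)
assert not accepts(one_hot, draft_token=2)
```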
Collaborator
For testing strategy:

I am concerned that we are adding many E2E tests that don't provide a lot of signal over what already exists. The tradeoff of more tests is that we can accidentally explode CI time. This is because we rely on E2E tests for spec decode correctness, and any small regression in model loading or vLLM initialization time can hurt us badly.

So, what I suggest:

  • E2E tests over the interaction between spec decode worker and typical acceptance
    • make sure it can handle different BS
    • make sure it can handle different K
  • E2E test over one other proposer method
    • just need one to make sure typical acceptance works beyond draft model
  • We don't need tests around preemption, disabling/skipping speculation, different block size, since all of these are no different for typical acceptance.

Collaborator Author
Added a single test in test_multistep_correctness.py to cover different batch-size and speculation-length values with the TypicalAcceptanceSampler, and a similar test in test_ngram_correctness.py.
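A hedged sketch of the shape such a test might take, reusing the fixtures and the run_greedy_equality_correctness_test helper this file already uses; the parametrize values are illustrative, and the common_llm_kwargs plumbing the surrounding tests carry is omitted:

```python
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": k,
        "speculative_draft_token_sampling_method": "typical_acceptance_sampler",
    }
    # Cover several speculation lengths in one test.
    for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("output_len", [32])
def test_typical_acceptance_bs_and_k(baseline_llm_generator,
                                     test_llm_generator, batch_size: int,
                                     output_len: int):
    """Verify greedy equality across batch sizes and speculation lengths."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
```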


NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
@@ -177,7 +183,9 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize(
"output_len",
@@ -230,8 +238,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 5,
-    },
+        "num_speculative_tokens": 3,
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize(
"output_len",
@@ -279,7 +289,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("max_output_len", [
256,
@@ -320,7 +332,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
@@ -364,7 +378,9 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
@@ -411,7 +427,9 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize(
"output_len",
@@ -465,7 +483,9 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
@@ -506,11 +526,13 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
 
         # Artificially limit the draft model max model len; this forces vLLM
         # to skip speculation once the sequences grow beyond 32-k tokens.
         "speculative_max_model_len": 32,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@@ -554,7 +576,9 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
         "speculative_disable_by_batch_size": 2,
-    },
+        "speculative_draft_token_sampling_method": method
+    }
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@@ -589,9 +613,12 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
     {
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": k,
+        "speculative_draft_token_sampling_method": method,
     }
     # Try a range of common k, as well as large speculation.
     for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    # Try both methods of sampling in the verifier.
+    for method in ["rejection_sampler", "typical_acceptance_sampler"]
 ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
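For context on the `speculative_draft_token_sampling_method` string used throughout this file: per the commit log ("Raise exception instead of returning None for invalid sampler name"), the worker resolves it to a sampler instance roughly like the sketch below. The function name and the no-argument constructors are assumptions, not vLLM's exact API.

```python
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

def get_spec_decode_sampler(name: str):
    """Resolve the configured sampling method to a verifier sampler."""
    if name == "rejection_sampler":
        return RejectionSampler()
    if name == "typical_acceptance_sampler":
        return TypicalAcceptanceSampler()
    # Per the commit log, an invalid name raises instead of returning None.
    raise ValueError(f"Invalid speculative sampling method: {name!r}")
```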
12 changes: 6 additions & 6 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -3,33 +3,33 @@
 import pytest
 import torch
 
-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
 from .utils import create_batch, mock_worker
+from .test_utils import mock_spec_decode_sampler
 
 @pytest.mark.parametrize('queue_size', [4])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('k', [1])
+@pytest.mark.parametrize("mock_spec_decode_sampler",
+    ["rejection_sampler", "typical_acceptance_sampler"], indirect=True)
 @torch.inference_mode()
-def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+def test_disable_spec_tokens(
+        queue_size: int, batch_size: int, k: int, mock_spec_decode_sampler):
     """Verify that speculative tokens are disabled when the batch size
     exceeds the threshold.
     """
     disable_by_batch_size = 3
 
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(proposer_worker=draft_worker,
                               scorer_worker=target_worker,
-                              rejection_sampler=rejection_sampler,
+                              spec_decode_sampler=mock_spec_decode_sampler,
                               metrics_collector=metrics_collector,
                               disable_by_batch_size=disable_by_batch_size)
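The mock_spec_decode_sampler fixture imported from .test_utils is driven by indirect parametrization. A plausible reconstruction of its shape (the real fixture lives in tests/spec_decode/test_utils.py and may differ):

```python
import pytest
from unittest.mock import MagicMock

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

@pytest.fixture
def mock_spec_decode_sampler(request):
    """Yield a mock of the sampler named by the indirect parameter."""
    if request.param == "rejection_sampler":
        return MagicMock(spec=RejectionSampler)
    if request.param == "typical_acceptance_sampler":
        return MagicMock(spec=TypicalAcceptanceSampler)
    raise ValueError(f"Unknown sampler: {request.param!r}")
```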
90 changes: 45 additions & 45 deletions tests/spec_decode/test_metrics.py
@@ -10,16 +10,16 @@
 def test_initial_call_returns_none():
     """Expect first call to get metrics to return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
-
-    collector = AsyncMetricsCollector(rej_sampler)
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
+
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=0)
     maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
     assert maybe_metrics is None
@@ -28,22 +28,22 @@ def test_initial_call_returns_none():
 def test_second_call_returns_metrics():
     """Expect second call to not return None.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
 
     collect_interval_s = 5.0
     timer = MagicMock()
     timer.side_effect = [
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
 def test_nonzero_rank_noop(rank):
     """Verify nonzero ranks don't collect metrics.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
-
-    collector = AsyncMetricsCollector(rej_sampler)
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
+
+    collector = AsyncMetricsCollector(spec_decode_sampler)
     collector.init_gpu_tensors(rank=rank)
     _ = collector.maybe_collect_rejsample_metrics(k=5)
     metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
 def test_noop_until_time():
     """Verify metrics aren't collected until enough time passes.
     """
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(0,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(0,
-                                                  dtype=torch.long,
-                                                  device='cuda')
-    rej_sampler.num_draft_tokens = 0
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
+                                                          dtype=torch.long,
+                                                          device='cuda')
+    spec_decode_sampler.num_draft_tokens = 0
 
     collect_interval_s = 5.0
     timer = MagicMock()
@@ -91,7 +91,7 @@ def test_noop_until_time():
         collect_interval_s + 0.1, collect_interval_s + 0.1
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
@@ -122,22 +122,22 @@ def test_initial_metrics_has_correct_values(has_data: bool):
     max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
         num_draft_tokens, k)
 
-    rej_sampler = MagicMock()
-    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
-                                                   dtype=torch.long,
-                                                   device='cuda')
-    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
+    spec_decode_sampler = MagicMock()
+    spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
+                                                           dtype=torch.long,
+                                                           device='cuda')
+    spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                           dtype=torch.long,
                                                           device='cuda')
-    rej_sampler.num_draft_tokens = num_draft_tokens
+    spec_decode_sampler.num_draft_tokens = num_draft_tokens
 
     collect_interval_s = 5.0
     timer = MagicMock()
     timer.side_effect = [
         0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
     ]
 
-    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
+    collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
                                       timer=timer,
                                       collect_interval_s=collect_interval_s)
     collector.init_gpu_tensors(rank=0)
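Taken together, these metrics tests suggest AsyncMetricsCollector relies only on a small duck-typed surface of whichever sampler it receives. A sketch of that assumed interface (attribute names read from the tests above; the class name here is invented):

```python
import torch

class SpecDecodeSamplerStats:
    """Assumed attributes AsyncMetricsCollector reads from the sampler it
    is given (RejectionSampler or TypicalAcceptanceSampler alike)."""
    num_accepted_tokens: torch.Tensor  # scalar long tensor on the GPU
    num_emitted_tokens: torch.Tensor   # scalar long tensor on the GPU
    num_draft_tokens: int              # plain Python int, tracked on CPU
```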