Commit 8437bae
[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (vllm-project#3103)
1 parent: f48c679
Showing 21 changed files with 2,786 additions and 215 deletions.
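For context, the accept/reject rule at the core of this worker's rejection sampling step follows standard speculative decoding: accept draft token i with probability min(1, p_target / p_draft) and stop at the first rejection. A minimal sketch of that rule; the function and its toy inputs are illustrative only, not vLLM APIs:

import torch

def rejection_sample_sketch(draft_token_ids, draft_probs, target_probs):
    """Accept draft token i with probability min(1, p_target / p_draft);
    stop at the first rejection."""
    accepted = []
    for i, token in enumerate(draft_token_ids.tolist()):
        ratio = (target_probs[i, token] / draft_probs[i, token]).item()
        if torch.rand(()).item() < min(1.0, ratio):
            accepted.append(token)
        else:
            # A full implementation resamples a recovery token from the
            # adjusted target distribution here.
            break
    return accepted

# Toy usage: identical draft/target distributions, so every token is accepted.
probs = torch.full((3, 4), 0.25)
assert rejection_sample_sketch(torch.tensor([0, 1, 2]), probs, probs) == [0, 1, 2]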
File renamed without changes.
@@ -0,0 +1,95 @@
import pytest
import torch

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import mock_worker, create_seq_group_metadata_from_prompts


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    expected_output = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
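The expected output constructed in test_get_token_ids_to_score encodes the batch-expansion invariant: k proposal tokens yield k + 1 token lists to score, one per prefix, including the empty prefix. A toy stand-in for the scorer method (not the vLLM implementation) that makes the rule explicit:

from typing import List

def token_ids_to_score_sketch(proposal_token_ids: List[int]) -> List[List[int]]:
    # Score every prefix of the proposal, from empty to full length.
    return [proposal_token_ids[:i] for i in range(len(proposal_token_ids) + 1)]

assert token_ids_to_score_sketch([0, 1, 2]) == [[], [0], [0, 1], [0, 1, 2]]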
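The metrics tests below make time deterministic by injecting a fake timer: a MagicMock whose side_effect list yields one timestamp per call. A self-contained demonstration of that mocking pattern:

from unittest.mock import MagicMock

timer = MagicMock()
timer.side_effect = [0.0, 5.1, 5.2]  # each call returns the next value
assert timer() == 0.0  # baseline timestamp
assert timer() == 5.1  # now past a 5.0 s collect interval
assert timer() == 5.2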
@@ -0,0 +1,157 @@
import math
from unittest.mock import MagicMock

import pytest
import torch

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)
    maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert maybe_metrics is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """Verify nonzero ranks don't collect metrics.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Test correctness of metrics data.
    """
    if has_data:
        num_accepted_tokens = 103
        num_emitted_tokens = 104
        num_draft_tokens = 105
    else:
        num_accepted_tokens = 0
        num_emitted_tokens = 0
        num_draft_tokens = 0
    k = 5

    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if has_data:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
    else:
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
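The final assertions pin down how the derived metrics relate to the raw counters: draft_acceptance_rate is accepted / drafted and system_efficiency is emitted / maximum-emittable, with NaN when no tokens have been drafted. A toy computation, assuming (this formula is an assumption, not read from vLLM) that get_max_num_accepted_tokens bounds each scoring step of k draft tokens at k + 1 emittable tokens:

import math

def derived_metrics_sketch(accepted: int, emitted: int, drafted: int, k: int):
    # Assumed bound: drafted // k scoring steps, each emitting at most k + 1
    # tokens (k accepted drafts plus one bonus/recovery token).
    max_emittable = drafted // k * (k + 1)
    acceptance_rate = accepted / drafted if drafted else math.nan
    efficiency = emitted / max_emittable if max_emittable else math.nan
    return acceptance_rate, efficiency

# Matches the has_data=True case: 103 accepted, 104 emitted, 105 drafted, k=5.
print(derived_metrics_sketch(103, 104, 105, 5))  # (~0.981, ~0.825)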