[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling (vllm-project#3103)
cadedaniel authored Mar 9, 2024
1 parent f48c679 commit 8437bae
Showing 21 changed files with 2,786 additions and 215 deletions.
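At a high level, the new worker chains three stages: a draft model speculates k tokens ahead, the target model scores those proposals in a single forward pass, and rejection sampling decides how many proposed tokens to keep. A minimal sketch of that loop, with hypothetical helper names rather than the vLLM API:

def spec_decode_step(draft_model, target_model, rejection_sampler,
                     context, k):
    # 1. Speculate: the cheap draft model proposes k candidate tokens.
    proposal = draft_model.propose(context, k)
    # 2. Score: the target model scores context + proposal in one pass.
    target_probs = target_model.score(context, proposal)
    # 3. Accept/reject: rejection sampling keeps a prefix of the proposal
    #    (plus up to one bonus token) while preserving the target model's
    #    output distribution.
    return rejection_sampler.sample(proposal, target_probs)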
5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -28,7 +28,7 @@ steps:
   num_gpus: 2 # only support 1 or 2 for now.
 
 - label: Engine Test
-  command: pytest -v -s engine
+  command: pytest -v -s engine test_sequence.py
 
 - label: Entrypoints Test
   command: pytest -v -s entrypoints
@@ -52,6 +52,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker
 
+- label: Speculative decoding tests
+  command: pytest -v -s spec_decode
+
 - label: LoRA Test
   command: pytest -v -s lora --forked
File renamed without changes.
95 changes: 95 additions & 0 deletions tests/spec_decode/test_batch_expansion.py
@@ -0,0 +1,95 @@
import pytest
import torch

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import create_seq_group_metadata_from_prompts, mock_worker


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Verify all new sequence ids are greater than all input
    seq ids.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    all_seq_ids = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in all_seq_ids:
        max_seq_id = max(seq_ids)
        iterator = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        for _ in range(num_target_seq_ids):
            assert next(iterator) > max_seq_id
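The iterator under test only has to hand out fresh ids. A minimal sketch of behavior satisfying this test, assuming a simple counter that starts past the largest existing id (an illustration, not necessarily the implementation):

import itertools
from typing import Iterator, List

def create_target_seq_id_iterator(seq_ids: List[int]) -> Iterator[int]:
    # Count upward from one past the largest input id, so every target
    # sequence id is guaranteed not to collide with an existing one.
    return itertools.count(start=max(seq_ids) + 1)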


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    expected_output = [
        [],
    ]
    for i in range(proposal_token_ids.shape[0]):
        expected_output.append(proposal_token_ids[:i + 1].tolist())

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    actual_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    actual_output = [
        x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
    ]

    assert actual_output == expected_output
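The expected output enumerates every prefix of the proposal, starting from the empty prefix (scoring the empty prefix yields the bonus token for the case where no draft tokens are accepted). A pure-Python sketch of that expansion, for illustration:

from typing import List

def get_token_ids_to_score(proposal_token_ids: List[int]) -> List[List[int]]:
    # [], [t0], [t0, t1], ..., [t0, ..., tk-1]: one entry per possible
    # accepted-prefix length.
    return [
        proposal_token_ids[:i] for i in range(len(proposal_token_ids) + 1)
    ]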


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """

    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]

    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1

    final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
        token_ids)

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

    assert len(output.block_tables) == 1
    assert output.block_tables[
        target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]
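Putting the helpers together: batch expansion gives each proposal prefix its own single-sequence group that reuses the original block table, so the target model can score all continuations in one batched pass. A compact sketch of the expansion for one sequence, using hypothetical names:

from typing import Dict, List

def expand_sequence(seq_tokens: List[int], proposal_token_ids: List[int],
                    first_target_seq_id: int) -> Dict[int, List[int]]:
    expanded = {}
    for i in range(len(proposal_token_ids) + 1):
        # Target sequence i continues the original tokens with the first
        # i proposal tokens.
        expanded[first_target_seq_id + i] = (seq_tokens +
                                             proposal_token_ids[:i])
    return expanded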
157 changes: 157 additions & 0 deletions tests/spec_decode/test_metrics.py
@@ -0,0 +1,157 @@
import math
from unittest.mock import MagicMock

import pytest
import torch

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)
    maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert maybe_metrics is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None
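The mocked timer drives the collection gate: the first call only stamps a baseline, and a later call past collect_interval_s triggers collection. A minimal sketch of that gating pattern, as an illustration rather than the collector's actual code:

import time
from typing import Callable, Optional

class IntervalGate:

    def __init__(self,
                 interval_s: float,
                 timer: Callable[[], float] = time.time):
        self._interval_s = interval_s
        self._timer = timer
        self._last_time: Optional[float] = None

    def should_collect(self) -> bool:
        now = self._timer()
        if self._last_time is None:
            # The first call establishes the baseline timestamp only.
            self._last_time = now
            return False
        if now - self._last_time >= self._interval_s:
            self._last_time = now
            return True
        return False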


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics.
"""
rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = 0

collector = AsyncMetricsCollector(rej_sampler)
collector.init_gpu_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(0,
                                                   dtype=torch.long,
                                                   device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(0,
                                                  dtype=torch.long,
                                                  device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is None

    _ = collector.maybe_collect_rejsample_metrics(k=5)
    metrics = collector.maybe_collect_rejsample_metrics(k=5)
    assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data.
"""
if has_data:
num_accepted_tokens = 103
num_emitted_tokens = 104
num_draft_tokens = 105
else:
num_accepted_tokens = 0
num_emitted_tokens = 0
num_draft_tokens = 0
k = 5

num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
num_draft_tokens, k)

rej_sampler = MagicMock()
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long,
device='cuda')
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long,
device='cuda')
rej_sampler.num_draft_tokens = num_draft_tokens

collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]

collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)

assert metrics.num_spec_tokens == k
assert metrics.accepted_tokens == num_accepted_tokens
assert metrics.draft_tokens == num_draft_tokens
assert metrics.emitted_tokens == num_emitted_tokens

if has_data:
assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)
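The derived rates follow directly from the counters. Assuming the usual speculative-decoding accounting, in which each round of k draft tokens can emit at most k + 1 tokens (k accepted plus one bonus token), the has_data case works out as follows; the k + 1 bound is an assumption stated here for illustration:

num_accepted_tokens = 103
num_emitted_tokens = 104
num_draft_tokens = 105
k = 5

# Acceptance rate: fraction of draft tokens the target model accepted.
draft_acceptance_rate = num_accepted_tokens / num_draft_tokens  # ~0.981

# Upper bound on emitted tokens: 105 draft tokens / 5 per round
# = 21 rounds, each emitting at most k + 1 = 6 tokens -> 126.
num_possible_tokens = num_draft_tokens // k * (k + 1)  # 126

# System efficiency: emitted tokens relative to that upper bound.
system_efficiency = num_emitted_tokens / num_possible_tokens  # ~0.825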