[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine #3894
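This PR integrates the speculative decoding workers with the LLMEngine so that more than one token can be accepted per scheduler step. As a rough usage sketch (the engine-level flags are not shown on this page, so the `speculative_model` and `num_speculative_tokens` argument names below are assumptions for illustration only):

```python
from vllm import LLM, SamplingParams

# Hypothetical configuration for speculative decoding; the argument names
# `speculative_model` and `num_speculative_tokens` are assumed here for
# illustration and are not confirmed by this diff.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",  # small draft model proposes tokens
    num_speculative_tokens=5,               # draft tokens proposed per step
    use_v2_block_manager=True,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```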

Merged
merged 120 commits on Apr 16, 2024
Commits (120)
252a0c7
wip
cadedaniel Apr 3, 2024
dd629d4
Merge remote-tracking branch 'upstream/main' into executor_base
cadedaniel Apr 3, 2024
a34800f
wip
cadedaniel Apr 3, 2024
09f30bd
wip
cadedaniel Apr 3, 2024
8b5bb8b
clean
cadedaniel Apr 4, 2024
6fd424f
wip
cadedaniel Apr 4, 2024
2a347bb
wip
cadedaniel Apr 4, 2024
658ff9b
wip
cadedaniel Apr 4, 2024
acee7be
wip
cadedaniel Apr 4, 2024
85760d6
wip
cadedaniel Apr 4, 2024
408b29d
wip
cadedaniel Apr 4, 2024
9d8fd69
Merge remote-tracking branch 'upstream/main' into executor_base
cadedaniel Apr 4, 2024
3149a03
wip
cadedaniel Apr 4, 2024
0c32e0a
wip
cadedaniel Apr 4, 2024
f64d5b1
wip
cadedaniel Apr 4, 2024
7207f0c
wip
cadedaniel Apr 4, 2024
0c4df0b
wip
cadedaniel Apr 4, 2024
2e355e7
wip
cadedaniel Apr 4, 2024
edb7f62
wip
cadedaniel Apr 4, 2024
48bb3e9
wip
cadedaniel Apr 4, 2024
7b39044
fix test
cadedaniel Apr 4, 2024
9e5f2fb
fix test
cadedaniel Apr 5, 2024
1a3e26e
fix test
cadedaniel Apr 5, 2024
cd2015c
fix test
cadedaniel Apr 5, 2024
d926034
fix
cadedaniel Apr 5, 2024
607f7e2
fix
cadedaniel Apr 5, 2024
e127bb7
fix
cadedaniel Apr 5, 2024
deaa8b0
fix
cadedaniel Apr 5, 2024
7817d61
clean
cadedaniel Apr 5, 2024
99823a3
clean
cadedaniel Apr 5, 2024
849bfe9
fix
cadedaniel Apr 5, 2024
951ba85
fix
cadedaniel Apr 5, 2024
38948df
speed up cpu test
cadedaniel Apr 5, 2024
397ec77
wip
cadedaniel Apr 5, 2024
23382b9
wip
cadedaniel Apr 5, 2024
7a0294c
clean
cadedaniel Apr 5, 2024
dcdca68
wip
cadedaniel Apr 5, 2024
ed58af2
remove
cadedaniel Apr 5, 2024
df8688e
Revert "more test speedup"
cadedaniel Apr 5, 2024
55a5203
wip
cadedaniel Apr 5, 2024
55d083b
wip
cadedaniel Apr 5, 2024
0814d24
wip
cadedaniel Apr 5, 2024
b18d00c
rename profile_num_available_blocks to get_max_allowed_kv_blocks
cadedaniel Apr 5, 2024
8fb7b9a
rename again
cadedaniel Apr 5, 2024
3bb9e6f
rename
cadedaniel Apr 5, 2024
edad09c
wip
cadedaniel Apr 5, 2024
f93c845
wip
cadedaniel Apr 5, 2024
d2d2218
wip
cadedaniel Apr 5, 2024
2f960e7
lint
cadedaniel Apr 5, 2024
68552e1
wip
cadedaniel Apr 5, 2024
42983ba
import order
cadedaniel Apr 5, 2024
2d5dbb8
fix
cadedaniel Apr 5, 2024
ae2f7e6
docstrings
cadedaniel Apr 5, 2024
c89bb75
Merge branch 'main' into executor_base
cadedaniel Apr 5, 2024
bf041d9
Merge remote-tracking branch 'upstream/main' into llm-engine-spec
cadedaniel Apr 5, 2024
fa8705d
wip
cadedaniel Apr 7, 2024
8495321
wip
cadedaniel Apr 7, 2024
b63975b
wip
cadedaniel Apr 7, 2024
cb23e8c
wip
cadedaniel Apr 7, 2024
143ca28
wip
cadedaniel Apr 7, 2024
d8d4725
fix
cadedaniel Apr 7, 2024
b2728e0
wip
cadedaniel Apr 7, 2024
6250f6c
assertion
cadedaniel Apr 7, 2024
a930755
fix
cadedaniel Apr 7, 2024
5b896a3
fix
cadedaniel Apr 7, 2024
bb43b53
lint
cadedaniel Apr 7, 2024
cde3160
fix
cadedaniel Apr 7, 2024
dd8aeff
fix
cadedaniel Apr 7, 2024
46e4847
test
cadedaniel Apr 7, 2024
8454edc
test fixes
cadedaniel Apr 7, 2024
819e656
lint
cadedaniel Apr 7, 2024
2b0d787
Merge remote-tracking branch 'upstream/main' into executor_base
cadedaniel Apr 7, 2024
67fd287
Merge remote-tracking branch 'upstream/main' into llm-engine-spec
cadedaniel Apr 7, 2024
c3449ba
Merge branch 'executor_base' into llm-engine-spec
cadedaniel Apr 7, 2024
d0fbe47
clean
cadedaniel Apr 7, 2024
5445af6
refactor out beam search model processor
cadedaniel Apr 7, 2024
632b439
fix
cadedaniel Apr 7, 2024
26e7368
dedup stop check
cadedaniel Apr 7, 2024
06e7c01
wip
cadedaniel Apr 7, 2024
184a52c
del
cadedaniel Apr 7, 2024
34468fe
rename
cadedaniel Apr 7, 2024
208c467
wip
cadedaniel Apr 8, 2024
3c6abcc
wip
cadedaniel Apr 8, 2024
bbbcef7
wip
cadedaniel Apr 8, 2024
b58762d
fix
cadedaniel Apr 8, 2024
8b500d4
wip
cadedaniel Apr 8, 2024
782ce22
unit tests for block decode
cadedaniel Apr 8, 2024
3062e1c
stop token ids
cadedaniel Apr 8, 2024
fba3b30
format
cadedaniel Apr 8, 2024
bda141f
fixing spec tests
cadedaniel Apr 8, 2024
49865fb
lint
cadedaniel Apr 8, 2024
1a17ed1
clean up gpu executor
cadedaniel Apr 8, 2024
dea67bb
wip
cadedaniel Apr 8, 2024
189d7eb
fix
cadedaniel Apr 8, 2024
a70a040
wip
cadedaniel Apr 8, 2024
3e1b8f5
detokenization
cadedaniel Apr 8, 2024
b9777a6
lint
cadedaniel Apr 8, 2024
29b4f12
docstrings
cadedaniel Apr 8, 2024
42aa0bc
fix
cadedaniel Apr 8, 2024
0ebd93b
more spec test
cadedaniel Apr 8, 2024
33a3d72
remove
cadedaniel Apr 8, 2024
15c942d
wip
cadedaniel Apr 8, 2024
063e34b
strip
cadedaniel Apr 8, 2024
672a855
print
cadedaniel Apr 8, 2024
8021b38
fix flaky test
cadedaniel Apr 8, 2024
8e93fff
reduce output len
cadedaniel Apr 8, 2024
d06e9a4
strip
cadedaniel Apr 8, 2024
ca516aa
pr feedback
cadedaniel Apr 9, 2024
91cf0fc
Merge branch 'executor_base' into llm-engine-spec
cadedaniel Apr 9, 2024
f6c7b2e
Zhuohan offline pr feedback
cadedaniel Apr 9, 2024
0283fae
Merge remote-tracking branch 'upstream/main' into llm-engine-spec
cadedaniel Apr 9, 2024
96f81c4
lint
cadedaniel Apr 9, 2024
de16919
pr feedback
cadedaniel Apr 10, 2024
d933e50
Merge branch 'main' into llm-engine-spec
cadedaniel Apr 11, 2024
2a19f5e
allow append empty tokens in block table
cadedaniel Apr 16, 2024
79325d3
Merge remote-tracking branch 'upstream/main' into llm-engine-spec
cadedaniel Apr 16, 2024
b6e9e82
rebase on stop string fixes
cadedaniel Apr 16, 2024
bf0c37c
test spec
cadedaniel Apr 16, 2024
a158256
lint & mypy
cadedaniel Apr 16, 2024
5a69f6c
doc
cadedaniel Apr 16, 2024
70 changes: 70 additions & 0 deletions tests/core/block/e2e/test_correctness.py
@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
"common_llm_kwargs",
[
{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 2,
"max_num_seqs": 2,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
{
"use_v2_block_manager": False,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"use_v2_block_manager": True,
"num_lookahead_slots": 0,
},
{
"use_v2_block_manager": True,
"num_lookahead_slots": 5,
},
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify that chunked prefill works with BlockManagerV2, with and without
lookahead scheduling.
"""
output_len = 32
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

print('Getting token ids with BlockManagerV1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids with BlockManagerV2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
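For reference, the BlockManagerV2 branch of the chunked prefill test above runs the LLM with a configuration roughly equivalent to the standalone sketch below. The keyword arguments are copied from the parametrization; passing them directly to `LLM` is an assumption about how the test fixtures forward them to the engine:

```python
from vllm import LLM, SamplingParams

# Sketch of the configuration exercised by the test's BlockManagerV2 branch;
# values mirror common_llm_kwargs + test_llm_kwargs above.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,              # skip CUDA graph capture for a fast test
    enable_chunked_prefill=True,
    max_num_batched_tokens=2,
    max_num_seqs=2,
    use_v2_block_manager=True,
    num_lookahead_slots=5,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, ignore_eos=True,
                                      temperature=0.0))
```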
17 changes: 10 additions & 7 deletions tests/core/utils.py
@@ -1,5 +1,5 @@
 import time
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

 from vllm import SamplingParams
 from vllm.lora.request import LoRARequest
@@ -31,14 +31,17 @@ def create_dummy_prompt(


 def create_seq_group(
-        seq_prompt_len=1024,
-        seq_output_lens=(128, ),
-        request_id='0',
-        seq_id_start=0,
-) -> SequenceGroup:
+        seq_prompt_len: int = 1024,
+        seq_output_lens: Iterable[int] = (128, ),
+        request_id: str = '0',
+        seq_id_start: int = 0,
+        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:

     assert len(seq_output_lens) > 0

+    if sampling_params is None:
+        sampling_params = SamplingParams()
+
     prompt_token_ids = [0] * seq_prompt_len

     seqs = []
@@ -60,7 +63,7 @@ def create_seq_group(
     seq_group = SequenceGroup(
         request_id=request_id,
         seqs=seqs,
-        sampling_params=SamplingParams(),
+        sampling_params=sampling_params,
         arrival_time=time.time(),
     )
270 changes: 270 additions & 0 deletions tests/engine/output_processor/test_multi_step.py
@@ -0,0 +1,270 @@
import random
from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
"""Verify multi-step decoding appends token ids correctly.
We append token ids and verify all the token ids were appended correctly.
Note that ignore_eos=True.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=1024,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=seq_output_len +
num_new_tokens,
ignore_eos=True),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
output_processor.process_outputs(seq_group, outputs)
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, max_tokens: int):
"""Verify tokens after max_tokens are dropped and not appended to the
sequence.
"""
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(max_tokens=max_tokens, ),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to not go over max tokens in len.
assert seq.get_len() == seq_prompt_len + max_tokens

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""Verify the eos token id is included in the sequence, but subsequent
tokens are dropped (not appended to sequence).
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

eos_token_id = 100

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens, ),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to not go beyond provided eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:eos_index + 1]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
seq_output_len: int, seed: int):
"""When sampling parameters dictate that we should ignore the eos token id,
ensure all token ids are appended even if the eos token id is emitted.
"""
random.seed(seed)
detokenizer = MagicMock(spec=Detokenizer)
scheduler = MagicMock(spec=Scheduler)
stop_checker = MagicMock(spec=StopChecker)
seq_counter = Counter()

eos_token_id = 100

output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
)

seq_group = create_seq_group(
seq_prompt_len=seq_prompt_len,
seq_output_lens=[seq_output_len],
sampling_params=SamplingParams(
# Ensure enough space.
max_tokens=seq_output_len + num_new_tokens,
ignore_eos=True,
),
)

seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

new_token_ids = list(range(num_new_tokens))
assert eos_token_id not in new_token_ids
eos_index = random.randint(0, len(new_token_ids) - 1)
new_token_ids[eos_index] = eos_token_id

outputs = [
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq.seq_id,
output_token=output_token,
logprobs={output_token: Logprob(0.0)},
)
],
prompt_logprobs=None,
) for output_token in new_token_ids
]

assert seq.get_len() == seq_prompt_len + seq_output_len
output_processor.process_outputs(seq_group, outputs)

# Expect the processed sequence to go beyond eos.
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

# Expect the correct tokens were appended.
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
seq_output_len]
assert seq.get_token_ids(
)[-len(expected_appended_tokens):] == expected_appended_tokens


def mock_tokenizer(eos_token_id=1000):
tokenizer = MagicMock(spec=PreTrainedTokenizer)
tokenizer.eos_token_id = eos_token_id
return tokenizer
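The acceptance rule these tests pin down (never exceed max_tokens, and keep tokens only up to and including the first EOS token unless ignore_eos is set) can be summarized by a small standalone helper. The sketch below illustrates the expected behavior only; it is not the MultiStepOutputProcessor implementation:

```python
from typing import List, Optional


def truncate_new_tokens(new_token_ids: List[int],
                        current_output_len: int,
                        max_tokens: int,
                        eos_token_id: Optional[int] = None,
                        ignore_eos: bool = False) -> List[int]:
    """Return the slice of proposed tokens that should be appended.

    Mirrors the rules the tests above verify: never exceed max_tokens,
    and keep tokens only up to (and including) the first EOS token
    unless ignore_eos is set.
    """
    # Respect the max_tokens budget for the output.
    remaining = max(max_tokens - current_output_len, 0)
    accepted = new_token_ids[:remaining]

    # Truncate after the first EOS token unless EOS is ignored.
    if not ignore_eos and eos_token_id is not None and eos_token_id in accepted:
        accepted = accepted[:accepted.index(eos_token_id) + 1]

    return accepted


# Mirrors test_respects_eos_token_id: an EOS at index 2 keeps three tokens.
assert truncate_new_tokens([5, 6, 100, 7], current_output_len=128,
                           max_tokens=140, eos_token_id=100) == [5, 6, 100]
```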