
[Speculative decoding] Add ngram prompt lookup decoding #4237

Merged: 2 commits merged into vllm-project:main on May 1, 2024

Conversation

leiwen83 (Contributor)

Algorithm details can be found in this blog post:
https://huggingface.co/blog/assisted-generation

The code directly follows transformers' current implementation: huggingface/transformers#27775

Since we get the draft directly from the prompt, no separate or modified model is needed to produce the proposal, making this the most convenient way to enjoy the speedup of speculation.
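For readers new to the idea, the matching step can be sketched in a few lines of standalone Python (an illustration only, not this PR's code; the actual implementation also tries window sizes from ngram_prompt_lookup_max down to ngram_prompt_lookup_min and runs the search per sequence):

```python
from typing import List, Optional

def propose_from_prompt(token_ids: List[int],
                        ngram_size: int,
                        num_speculative_tokens: int) -> Optional[List[int]]:
    """Match the trailing ngram against earlier context and return up to
    num_speculative_tokens draft tokens, or None when nothing matches."""
    if ngram_size <= 0 or len(token_ids) <= ngram_size:
        return None
    ngram = token_ids[-ngram_size:]
    # Scan right-to-left so the most recent earlier occurrence wins,
    # skipping the trailing ngram itself.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == ngram:
            draft = token_ids[start + ngram_size:
                              start + ngram_size + num_speculative_tokens]
            return draft or None
    return None
```

The draft tokens are then verified by the target model in a single forward pass, exactly as with a draft-model proposer, so the output distribution of the target model is preserved.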

leiwen83 (Contributor, Author)

Implementation of the ngram prompt lookup feature mentioned in #2188.

leiwen83 (Contributor, Author)

@cadedaniel
Could you help review this PR, especially on whether we should add more abstraction to accommodate other speculation types such as Medusa?

cadedaniel (Collaborator) commented Apr 22, 2024

Thanks for the PR @leiwen83, we'll take a look!

@comaniac do you have bandwidth to shepherd this PR?

comaniac (Collaborator) commented Apr 22, 2024

Thanks for the PR! I'll review it.

[4 review comments on vllm/spec_decode/multi_step_worker.py, now outdated/resolved]
leiwen83 marked this pull request as ready for review on April 22, 2024 10:42
comaniac (Collaborator) left a comment

Another batch of comments. Also, we'll need unit tests for this PR.

[11 review comments, now outdated/resolved: 10 on vllm/spec_decode/ngram_worker.py, 1 on tests/spec_decode/e2e/test_correctness.py]
leiwen83 (Contributor, Author)

> Another batch of comments. Also, we'll need unit tests for this PR.

I currently have one e2e unit test in tests/spec_decode/e2e/test_correctness.py. Do you mean we may need some other unit tests?

leiwen83 (Contributor, Author) commented Apr 23, 2024

@cadedaniel @comaniac,
Since #3951 is merged and some changes have been made to the spec-infer subsystem, shall we squash this commit and rebase over the latest code to ease further review?
And if a squash/rebase is needed, should I force-push in this thread or open another PR?

comaniac (Collaborator)

> I currently have one e2e unit test in tests/spec_decode/e2e/test_correctness.py. Do you mean we may need some other unit tests?

Sorry, let me make it clearer. The unit tests should cover as many cases as possible: for example, batch size > 1 where some sequences find a match and others don't; n-gram sizes from 1 to 3; long/short speculative sizes; etc.

cadedaniel (Collaborator)

> @cadedaniel @comaniac,
> Since #3951 is merged and some changes have been made to the spec-infer subsystem, shall we squash this commit and rebase over the latest code to ease further review?
> And if a squash/rebase is needed, should I force-push in this thread or open another PR?

You're free to use whatever strategy you prefer to resolve the merge conflicts. I often merge main into my dev branch, but there are tradeoffs. I'd keep this PR open; you can push/force-push as necessary!

And sorry for the conflicts -- PR 7/9 was the last large one to the subsystem. I expect fewer merge conflicts going forward.

> I currently have one e2e unit test in tests/spec_decode/e2e/test_correctness.py. Do you mean we may need some other unit tests?

> Sorry, let me make it clearer. The unit tests should cover as many cases as possible: for example, batch size > 1 where some sequences find a match and others don't; n-gram sizes from 1 to 3; long/short speculative sizes; etc.

+1. We'll want the following tests:

  • Batch size 1 greedy equality
  • Batch size >1 greedy equality
  • A test covering the case where there's no ngram match for any sequence
  • A test covering the case where there are ngram matches for some sequences in a batch, but not all
  • A test covering the case where there are ngram matches for all seqs in the batch (can be omitted, since the bs>1 greedy-equality test will likely cover it)
  • Tests over various ngram sizes / speculative sizes
  • A greedy-equality test under preemption

Most of these can be copied from the existing tests and refitted for ngram speculation; I suggest making a new file for the ngram correctness tests (so the current one can stay dedicated to the draft-model case). A sketch of one slice of this matrix follows below.
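For illustration, here is a hypothetical pytest skeleton for part of that matrix, exercising the standalone propose_from_prompt sketch from the PR description above (the real tests in this PR drive the full vLLM engine and assert equality against a non-speculative baseline; all names here are assumptions):

```python
import pytest

# Assumes the propose_from_prompt sketch shown earlier is in scope.
@pytest.mark.parametrize("ngram_size", [1, 2, 3])
@pytest.mark.parametrize("num_speculative_tokens", [1, 5])
def test_ngram_proposal_match_and_miss(ngram_size, num_speculative_tokens):
    repeated = [5, 6, 7, 8] * 3  # every trailing ngram reappears earlier
    draft = propose_from_prompt(repeated, ngram_size, num_speculative_tokens)
    assert draft is not None
    assert len(draft) <= num_speculative_tokens

    no_repeat = list(range(32))  # no ngram ever repeats -> no proposal
    assert propose_from_prompt(no_repeat, ngram_size,
                               num_speculative_tokens) is None
```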

leiwen83 (Contributor, Author) commented Apr 25, 2024

@cadedaniel @comaniac
I squashed and rebased over the latest main branch, fixed all the issues raised so far, and added the necessary test cases. Would you mind taking a look at the refreshed PR?

Thx~

comaniac (Collaborator) left a comment

Otherwise LGTM (the n-gram worker). I'll leave the rest to @cadedaniel.

[4 review comments, now outdated/resolved: tests/spec_decode/e2e/test_compatibility.py (1), tests/spec_decode/e2e/test_ngram_correctness.py (2), vllm/spec_decode/ngram_worker.py (1)]
cadedaniel (Collaborator)

Will take another pass today.

cadedaniel (Collaborator) left a comment

Just minor comments, otherwise LGTM! Can you address/respond and then we'll merge.

@@ -0,0 +1,241 @@
"""The tests in this file verify end-to-end speculative decoding correctness.
cadedaniel (Collaborator): Remove this note.

])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model(baseline_llm_generator,
cadedaniel (Collaborator): Rename to test_ngram_e2e_greedy_correctness.

])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
cadedaniel (Collaborator): Rename to test_ngram_e2e_greedy_correctness_with_preemption.

Comment on lines 75 to 79
"""Verify greedy equality on a tiny model with batch size of one.

Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
"""
cadedaniel (Collaborator): Update comment -- this does more than bs=1.

Comment on lines 151 to 165
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5, 7, 10, 63]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5, 7, 10, 63]
cadedaniel (Collaborator): To improve test time, we can reduce the space covered -- I suggest k=[1, 3, 5] x ngram_prompt_lookup_max=[1, 3].
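Concretely, the reduced matrix could be expressed like this (a sketch of the suggestion; the variable name is an assumption about the surrounding parametrize list):

```python
# 3 values of k crossed with 2 lookup sizes: 6 configurations instead of 12.
test_llm_kwargs = [{
    "speculative_model": "[ngram]",
    "num_speculative_tokens": k,
    "ngram_prompt_lookup_max": ngram_max,
} for k in [1, 3, 5] for ngram_max in [1, 3]]
```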

parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=None,
cadedaniel (Collaborator): nit: default=EngineArgs.ngram_prompt_lookup_max

parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=None,
cadedaniel (Collaborator): nit: default=EngineArgs.ngram_prompt_lookup_min
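Applied to the first flag, the nit would look like this (a sketch; the help text is an assumption, and the same change applies to --ngram-prompt-lookup-min):

```python
parser.add_argument(
    '--ngram-prompt-lookup-max',
    type=int,
    # Read the default from EngineArgs so CLI and engine defaults stay in sync.
    default=EngineArgs.ngram_prompt_lookup_max,
    help='Max ngram window size used for prompt lookup in speculative decoding.')
```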

Comment on lines +734 to +762
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
cadedaniel (Collaborator): Can you add a TODO here to set these to None?
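That is, something along these lines (a sketch of the requested TODO; the rationale comment is an assumption):

```python
# TODO: set these to None -- the ngram proposer runs no draft model, so
# reusing the target configs here is only a temporary stand-in.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
```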

Comment on lines +52 to +53
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
cadedaniel (Collaborator): Add docs on the new return value?

-        sampler_output_list: List[SamplerOutput],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+        sampler_output_list: List[SamplerOutput],
+        sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
     """Utility function which converts a list of SamplerOutput to tensors.
cadedaniel (Collaborator): Add a docstring with the new arg.
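One possible docstring update (a sketch; the enclosing function name, import paths, and wording are assumptions based on the visible diff):

```python
from typing import List, Tuple

import torch

from vllm.sequence import SamplerOutput  # import path is an assumption

def sampler_output_to_torch(
        sampler_output_list: List[SamplerOutput],
        sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Utility function which converts a list of SamplerOutput to tensors.

    Args:
        sampler_output_list: the list of SamplerOutput to convert.
        sampler_transposed: whether the outputs are already laid out in
            transposed (per-sequence) order; controls whether a transpose
            is applied before returning the stacked tensors.
    """
    ...  # body elided
```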

leiwen83 force-pushed the ngram_lookahead_specinfer branch from 748c687 to 8adaf38 on May 1, 2024 13:06
leiwen83 (Contributor, Author) commented May 1, 2024

@cadedaniel All notes have been addressed, and the PR has been rebased against the latest code. Could you take another look?

cadedaniel (Collaborator)

Looks good, thanks @leiwen83 ! Thanks for contributing to vLLM 😃

cadedaniel merged commit b38e42f into vllm-project:main on May 1, 2024. 48 checks passed.
leiwen83 deleted the ngram_lookahead_specinfer branch on May 2, 2024 08:58.
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024