[Speculative decoding] Add ngram prompt lookup decoding #4237
Conversation
Implementation of the ngram prompt lookup feature mentioned in #2188.
@cadedaniel
Thanks for the PR! I'll review it
Another batch of comments. Also we would need unit tests for this PR.
I currently have one e2e unit test in tests/spec_decode/e2e/test_correctness.py. Do you mean we may need some other unit tests?
@cadedaniel @comaniac,
Sorry, let me make it clearer. The unit tests should cover as many cases as possible: for example, batch size > 1 where some seqs find a match but others don't; n-gram sizes from 1 to 3; long/short speculative sizes, etc.
You're free to use whatever strategy you prefer to resolve the merge conflicts. I often merge main into my dev branch, but there are tradeoffs. I'd keep this PR open; you can push/force-push as necessary! And sorry for the conflicts -- PR 7/9 was the last large one to the subsystem, so I expect fewer merge conflicts going forward.
+1. We'll want the following tests:
Most of these can be copied from the existing tests and refitted for ngram speculation; I suggest making a new file for the ngram spec correctness tests (so the current one can stay draft-model only), as sketched below.
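
For illustration, the new file might start like this -- a hedged sketch only: the model name, fixtures, and the run_greedy_equality_correctness_test helper are assumed to mirror the existing draft-model test file:

# tests/spec_decode/e2e/test_ngram_correctness.py (hypothetical sketch)
import pytest

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",  # placeholder tiny model
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,  # vary 1..3 across cases
    }])
@pytest.mark.parametrize("batch_size", [1, 64])  # bs>1 mixes hit/miss seqs
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size):
    # Assumed shared helper that compares greedy outputs token-by-token.
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)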
Force-pushed from 44a1530 to e870757.
@cadedaniel @comaniac Thanks!
Otherwise LGTM (the n-gram worker). I'll leave the rest to @cadedaniel
will take another pass today
Just minor comments, otherwise LGTM! Can you address/respond and then we'll merge.
@@ -0,0 +1,241 @@
"""The tests in this file verify end-to-end speculative decoding correctness. |
Remove this note
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_tiny_model(baseline_llm_generator,
rename test_ngram_e2e_greedy_correctness
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_greedy_correctness_with_preemption(
test_ngram_e2e_greedy_correctness_with_preemption
"""Verify greedy equality on a tiny model with batch size of one. | ||
|
||
Since this test is cheaper than other e2e correctness tests, we generate | ||
with a higher output_len. | ||
""" |
Update comment -- this does more than bs=1
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": k,
        "ngram_prompt_lookup_max": 3,
    }
    # Try a range of common k, as well as large speculation.
    for k in [1, 3, 5, 7, 10, 63]
] + [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": k,
        "ngram_prompt_lookup_max": 1,
    }
    # Try a range of common k, as well as large speculation.
    for k in [1, 3, 5, 7, 10, 63]
To improve test time, we can reduce the space covered -- I suggest k=[1, 3, 5] x ngram_prompt_lookup_max=[1, 3].
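
That suggestion could be written as a single cross product (a sketch reusing the config keys from the snippet above):

[
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": k,
        "ngram_prompt_lookup_max": n,
    }
    # 6 configs instead of 12: common k values crossed with both windows.
    for k in [1, 3, 5]
    for n in [1, 3]
]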
vllm/engine/arg_utils.py
parser.add_argument(
    '--ngram-prompt-lookup-max',
    type=int,
    default=None,
nit: default=EngineArgs.ngram_prompt_lookup_max
vllm/engine/arg_utils.py
parser.add_argument(
    '--ngram-prompt-lookup-min',
    type=int,
    default=None,
nit: default=EngineArgs.ngram_prompt_lookup_min
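
Both nits follow the same pattern; a sketch of the suggested form, assuming EngineArgs is a dataclass whose class attributes carry the defaults (help text is illustrative):

parser.add_argument(
    '--ngram-prompt-lookup-max',
    type=int,
    default=EngineArgs.ngram_prompt_lookup_max,
    help='Max window size for ngram prompt lookup in speculative decoding.')
parser.add_argument(
    '--ngram-prompt-lookup-min',
    type=int,
    default=EngineArgs.ngram_prompt_lookup_min,
    help='Min window size for ngram prompt lookup in speculative decoding.')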
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
Can you add a TODO here to set these to None?
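
Something like the following would capture it (comment wording is illustrative):

# TODO: the ngram case loads no separate draft model; set these to None
# once downstream code tolerates a missing draft model/parallel config.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config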
) -> Tuple[List[SamplerOutput], bool]:
    """Run the model forward pass sample_len times. Returns the list of
Add docs on the new return value?
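
A possible wording (a sketch; the meaning of the bool is inferred from the sampler_transposed parameter that appears below, so treat it as an assumption):

    """Run the model forward pass sample_len times.

    Returns:
        A tuple (sampler_output_list, transposed): one SamplerOutput per
        step, plus a bool indicating whether the outputs are laid out
        per-sequence rather than per-step.
    """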
-    sampler_output_list: List[SamplerOutput],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    sampler_output_list: List[SamplerOutput],
+    sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
     """Utility function which converts a list of SamplerOutput to tensors.
Add a docstring with new arg
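
For example (a sketch; the description of sampler_transposed is an assumption based on its name):

    """Utility function which converts a list of SamplerOutput to tensors.

    Args:
        sampler_output_list: One SamplerOutput per speculative step.
        sampler_transposed: Whether the outputs are already transposed,
            i.e. laid out per-sequence rather than per-step.
    """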
Algorithm details can be found in this blog post: https://huggingface.co/blog/assisted-generation. The code directly follows transformers' current implementation (huggingface/transformers#27775). Since we get the draft directly from the prompt, there is no need for another model or a modified model to produce the proposal, making this the most convenient way to enjoy the speedup of speculation.
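
For readers unfamiliar with the technique, here is a minimal self-contained sketch of prompt lookup decoding -- not this PR's actual code; the function name and the greedy match-from-the-right policy are illustrative:

from typing import List, Optional


def ngram_prompt_lookup(token_ids: List[int],
                        ngram_max: int,
                        ngram_min: int,
                        num_speculative_tokens: int) -> Optional[List[int]]:
    """Propose draft tokens by matching the most recent n-gram against
    earlier context and copying the tokens that followed the match."""
    for n in range(ngram_max, ngram_min - 1, -1):  # prefer longer n-grams
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # Scan backwards (most recent match first), excluding the suffix
        # itself, for an earlier occurrence of the same n-gram.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                draft = token_ids[start + n:start + n +
                                  num_speculative_tokens]
                if draft:
                    return draft
    return None  # No match: skip speculation for this step.


# Example: the n-gram [1, 2] last occurred at position 0, followed by [3, 4].
assert ngram_prompt_lookup([1, 2, 3, 4, 1, 2], 2, 1, 2) == [3, 4]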
Force-pushed from 748c687 to 8adaf38.
@cadedaniel all notes have been addressed, and the branch has been rebased against the latest code. Could you take another look?
Looks good, thanks @leiwen83! Thanks for contributing to vLLM 😃
…#4237) Co-authored-by: Lei Wen <[email protected]>