-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] API support for beam search for MQLLMEngine #9117
Changes from 1 commit
5add5d7
1375b59
b57bff2
8384451
f9843df
a33ea39
ac5520c
b3c5d05
b22e8bd
097479d
44334ee
4918f6a
e65ed31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,20 +151,16 @@ async def create_completion( | |
log_tracing_disabled_warning() | ||
|
||
if isinstance(sampling_params, BeamSearchParams): | ||
if isinstance(self.engine_client, AsyncLLMEngine): | ||
generator = self.engine_client.beam_search( | ||
prompt_inputs["prompt_token_ids"], request_id_item, | ||
sampling_params) | ||
elif isinstance(self.engine_client, MQLLMEngineClient): | ||
generator = self.engine_client.beam_search( | ||
prompt_inputs["prompt_token_ids"], request_id_item, | ||
sampling_params, lora_request) | ||
else: | ||
raise ValueError( | ||
"Beam search in the API server is only supported" | ||
" with AsyncLLMEngine and MQLLMEngineClient." | ||
" please add `--disable-frontend-multiprocessing`" | ||
" to use beam search.") | ||
assert isinstance(self.engine_client, | ||
(AsyncLLMEngine, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, you don't actually need the You can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think the base There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does not. But the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I cannot find There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay, it should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EngineClient is a protocol. MQLLMEngine should inherit from this. If it doesn’t, I’ll submit a PR to make it (since we support the full API). On train so AFK Either way, we are about to collapse MQLLMEngine and AsyncLLMEngine once we have PP working, so the concept of an EngineClient will be removed once this is done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good. I'll go ahead and merge this pr after it is ready. and after you make MQLLMEngine inherit from EngineClient, we can merge separate beam search implementation in one place. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that it's currently a I agree with @robertgshaw2-neuralmagic that this method should just be added to Not directly related to this PR but I also think we should consider renaming it to something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changes are welcome on this part! |
||
MQLLMEngineClient)), \ | ||
"Beam search is only supported with" \ | ||
"AsyncLLMEngine and MQLLMEngineClient." | ||
generator = self.engine_client.beam_search( | ||
prompt_inputs["prompt_token_ids"], | ||
request_id, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the difference between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh yeah it's supposed to be |
||
sampling_params, | ||
) | ||
else: | ||
generator = self.engine_client.generate( | ||
{ | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -17,6 +17,7 @@ | |||||
import warnings | ||||||
import weakref | ||||||
from asyncio import FIRST_COMPLETED, ensure_future | ||||||
from dataclasses import dataclass | ||||||
from functools import lru_cache, partial, wraps | ||||||
from platform import uname | ||||||
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, | ||||||
|
@@ -1363,6 +1364,37 @@ def value(self): | |||||
return self._value | ||||||
|
||||||
|
||||||
@dataclass | ||||||
class BeamSearchSequence: | ||||||
"""A sequence for beam search. | ||||||
It keeps track of the tokens and the log probability of the sequence. | ||||||
The text field is optional and will only be filled when the sequence is | ||||||
about to be returned to the user. | ||||||
""" | ||||||
# The tokens includes the prompt. | ||||||
tokens: List[int] | ||||||
cum_logprob: float = 0.0 | ||||||
text: Optional[str] = None | ||||||
|
||||||
|
||||||
@dataclass | ||||||
class BeamSearchOutput: | ||||||
"""The output of beam search. | ||||||
It contains the list of the best beam search sequences. | ||||||
The length of the list is equal to the beam width. | ||||||
""" | ||||||
sequences: List[BeamSearchSequence] | ||||||
|
||||||
|
||||||
class BeamSearchInstance: | ||||||
|
||||||
def __init__(self, prompt_tokens: List[int]): | ||||||
self.beams: List[BeamSearchSequence] = [ | ||||||
BeamSearchSequence(tokens=prompt_tokens) | ||||||
] | ||||||
self.completed: List[BeamSearchSequence] = [] | ||||||
|
||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move them to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a circular import error if I put these classes in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logically, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should i also move the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, go ahead! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure whether these should go in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
makes sense. how about |
||||||
def get_beam_search_score( | ||||||
tokens: List[int], | ||||||
cumulative_logprob: float, | ||||||
|
@@ -1380,3 +1412,12 @@ def get_beam_search_score( | |||||
seq_len -= 1 | ||||||
|
||||||
return cumulative_logprob / (seq_len**length_penalty) | ||||||
|
||||||
|
||||||
def create_sort_beams_key_function(tokenizer, length_penalty): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
def sort_beams_key(x: BeamSearchSequence) -> float: | ||||||
return get_beam_search_score(x.tokens, x.cum_logprob, | ||||||
tokenizer.eos_token_id, length_penalty) | ||||||
|
||||||
return sort_beams_key |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add some sort of kwarg to show why it is
None
here