[Frontend] API support for beam search for MQLLMEngine #9117

Merged (13 commits) on Oct 8, 2024
14 changes: 6 additions & 8 deletions vllm/engine/async_llm_engine.py

@@ -14,7 +14,6 @@
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.entrypoints.llm import BeamSearchSequence
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -32,8 +31,9 @@
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
get_beam_search_score, random_uuid, weak_bind)
from vllm.utils import (BeamSearchSequence, collect_from_async_generator,
create_sort_beams_key_function, deprecate_kwargs,
random_uuid, weak_bind)

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1052,16 +1052,14 @@ async def beam_search(
temperature = params.temperature
length_penalty = params.length_penalty

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)

tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer, length_penalty=length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
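For context on how this helper is consumed, here is a minimal, hedged sketch (assuming vllm.utils at this PR's state; pick_top_beams and all_beams are illustrative names, not code from the PR):

from typing import List

from vllm.utils import BeamSearchSequence, create_sort_beams_key_function


def pick_top_beams(all_beams: List[BeamSearchSequence], tokenizer,
                   length_penalty: float,
                   beam_width: int) -> List[BeamSearchSequence]:
    # Build the scoring key once, then keep the best `beam_width` beams
    # (higher length-penalized score is better).
    sort_beams_key = create_sort_beams_key_function(
        tokenizer, length_penalty=length_penalty)
    return sorted(all_beams, key=sort_beams_key, reverse=True)[:beam_width]
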
17 changes: 7 additions & 10 deletions vllm/engine/multiprocessing/client.py

@@ -26,7 +26,6 @@
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.entrypoints.llm import BeamSearchSequence
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
@@ -36,8 +35,9 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
get_beam_search_score, random_uuid)
from vllm.utils import (BeamSearchSequence, collect_from_async_generator,
create_sort_beams_key_function, deprecate_kwargs,
random_uuid)

logger = init_logger(__name__)

@@ -449,7 +449,6 @@ async def beam_search(
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
@@ -458,16 +457,14 @@
temperature = params.temperature
length_penalty = params.length_penalty

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)

tokenizer = await self.get_tokenizer(lora_request)
tokenizer = await self.get_tokenizer(None)

Collaborator: I would add some sort of kwarg to show why it is None here.

tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer, length_penalty=length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
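A minimal sketch of the reviewer's suggestion above, assuming the parameter is named lora_request (the old code passed a lora_request variable positionally); illustrative only, not part of the PR:

# Explicit keyword to document why None is passed: beam search via this
# frontend currently runs without a LoRA request here (assumption).
tokenizer = await self.get_tokenizer(lora_request=None)
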
37 changes: 3 additions & 34 deletions vllm/entrypoints/llm.py

@@ -1,7 +1,6 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)

@@ -28,43 +27,13 @@
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
from vllm.utils import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, Counter, deprecate_kwargs,
get_beam_search_score, is_list_of)

logger = init_logger(__name__)


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []


class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.

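For readers following the move, here is a simplified, self-contained sketch of how these relocated data structures cooperate during one beam-search expansion step (the token proposals are stubbed; the real engine derives them from model logprobs and ranks beams with create_sort_beams_key_function):

from typing import List, Tuple

from vllm.utils import BeamSearchInstance, BeamSearchSequence


def expand_beams(instance: BeamSearchInstance,
                 proposals: List[List[Tuple[int, float]]], beam_width: int,
                 eos_token_id: int) -> None:
    # proposals[i] holds (token_id, logprob) candidates for instance.beams[i].
    new_beams: List[BeamSearchSequence] = []
    for beam, candidates in zip(instance.beams, proposals):
        for token_id, logprob in candidates:
            seq = BeamSearchSequence(tokens=beam.tokens + [token_id],
                                     cum_logprob=beam.cum_logprob + logprob)
            if token_id == eos_token_id:
                instance.completed.append(seq)
            else:
                new_beams.append(seq)
    # Keep only the best `beam_width` live beams for the next step.
    new_beams.sort(key=lambda s: s.cum_logprob, reverse=True)
    instance.beams = new_beams[:beam_width]

Once the token budget is exhausted, the surviving and completed sequences are ranked and wrapped into a BeamSearchOutput (one entry per beam).
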
29 changes: 10 additions & 19 deletions vllm/entrypoints/openai/serving_chat.py

@@ -237,25 +237,16 @@ async def create_chat_completion(
log_tracing_disabled_warning()

if isinstance(sampling_params, BeamSearchParams):
if isinstance(self.engine_client, AsyncLLMEngine):
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
elif isinstance(self.engine_client, MQLLMEngineClient):
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
lora_request,
)
else:
raise ValueError(
"Beam search in the API server is only supported with"
" AsyncLLMEngine and MQLLMEngineClient. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
assert isinstance(self.engine_client,
(AsyncLLMEngine,
MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
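End to end, the user-facing effect is that beam search now works through the default multiprocessing frontend, so the `--disable-frontend-multiprocessing` workaround mentioned in the removed error message is no longer needed. A hedged usage sketch against the OpenAI-compatible server follows (the model name is a placeholder; the exact extra_body keys and the beam-width mapping depend on the vLLM version):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Assumption: with use_beam_search enabled, the server builds BeamSearchParams
# and (presumably) uses `n` as the beam width.
completion = client.completions.create(
    model="my-served-model",  # placeholder for whatever model the server hosts
    prompt="San Francisco is a",
    max_tokens=32,
    n=4,
    extra_body={"use_beam_search": True},
)
print(completion.choices[0].text)
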
24 changes: 10 additions & 14 deletions vllm/entrypoints/openai/serving_completion.py

@@ -151,20 +151,16 @@ async def create_completion(
log_tracing_disabled_warning()

if isinstance(sampling_params, BeamSearchParams):
if isinstance(self.engine_client, AsyncLLMEngine):
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"], request_id_item,
sampling_params)
elif isinstance(self.engine_client, MQLLMEngineClient):
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"], request_id_item,
sampling_params, lora_request)
else:
raise ValueError(
"Beam search in the API server is only supported"
" with AsyncLLMEngine and MQLLMEngineClient."
" please add `--disable-frontend-multiprocessing`"
" to use beam search.")
assert isinstance(self.engine_client,
(AsyncLLMEngine,

Collaborator: On second thought, you don't actually need the assert. You can add beam_search to the EngineClientProtocol.

Contributor Author (@LunrEclipse), Oct 7, 2024: I don't think the base EngineClientProtocol has a beam_search function. I can add it in, though.

Collaborator: It does not. But the EngineClientProtocol defines the behavior of AsyncLLMEngine and MQLLMEngine. So now that both support it, you can expand EngineClientProtocol to include the beam_search API.

Member: I cannot find EngineClientProtocol. Does it exist now? @robertgshaw2-neuralmagic

Contributor Author: I think it's this?

Member: Okay, it should be the EngineClient class. But MQLLMEngineClient does not inherit from EngineClient. We can make it a future step to absorb the beam search implementation into EngineClient.

Collaborator: EngineClient is a protocol; MQLLMEngine should inherit from it. If it doesn't, I'll submit a PR to make it so (since we support the full API). On a train, so AFK. Either way, we are about to collapse MQLLMEngine and AsyncLLMEngine once we have PP working, so the concept of an EngineClient will be removed once that is done.

Member: Sounds good. I'll go ahead and merge this PR after it is ready, and after you make MQLLMEngine inherit from EngineClient, we can merge the separate beam search implementations in one place.

Member: Given that it's currently a Protocol, MQLLMEngine technically doesn't need to subclass it directly, but it would probably be good to anyhow, and we could actually consider changing it to an ABC instead. I agree with @robertgshaw2-neuralmagic that this method should just be added to EngineClient and we should not need these type assertions. Not directly related to this PR, but I also think we should consider renaming it to something like AsyncEngineClient, and having a way to obtain an AsyncEngineClient instance that doesn't involve explicit construction; that would replace explicit use of AsyncLLMEngine.

Member: Changes are welcome on this part!

MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"],
request_id,

Member: What's the difference between request_id and request_id_item?

Contributor Author: Ahh, yeah, it's supposed to be request_id_item; I was a bit careless while doing some refactoring. Completions support multiple prompts, so each prompt has its own request_id_item under a general request_id.

sampling_params,
)
else:
generator = self.engine_client.generate(
{
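To make the thread above concrete, here is a hedged sketch of what adding beam_search to an engine-client Protocol could look like, so the serving layers would not need isinstance assertions (the actual protocol name, location, and signature in vLLM may differ):

from typing import AsyncGenerator, List, Protocol, Union

from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams


class EngineClient(Protocol):
    """Sketch of the client protocol discussed above (assumed shape)."""

    def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        ...

Both AsyncLLMEngine and MQLLMEngineClient already expose essentially this signature in the PR, so declaring it on the protocol would let serving_chat.py and serving_completion.py call self.engine_client.beam_search(...) unconditionally.
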
41 changes: 41 additions & 0 deletions vllm/utils.py

@@ -17,6 +17,7 @@
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -1363,6 +1364,37 @@ def value(self):
return self._value


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []



Member: Move them to vllm/sequence.py?

Contributor Author: There's a circular import error if I put these classes in vllm/sequence.py, as BeamSearchSequence is needed in vllm/utils.py, but vllm/sequence.py indirectly imports from vllm/utils.py.

Member: Logically, vllm/utils.py should not import vllm/sequence.py. We should change the code if this is the case.

Contributor Author (@LunrEclipse), Oct 7, 2024: Should I also move create_sort_beams_key_function to vllm/sequence.py? That would solve the issue.

Member: Sure, go ahead!

Member: I'm not sure whether these should go in sequence.py, since that holds "internal" data structures used within the scheduler etc., and, if I understand correctly, BeamSearchSequence etc. are only used in the outer layer(s). Maybe better to have a dedicated file at the appropriate place in the tree for this?

Member: > Maybe better to have a dedicated file at the appropriate place in the tree for this?
Makes sense. How about vllm/beam_search.py?

def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
@@ -1380,3 +1412,12 @@ def get_beam_search_score(
seq_len -= 1

return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(tokenizer, length_penalty):

Member, suggested change:
  - def create_sort_beams_key_function(tokenizer, length_penalty):
  + def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id, length_penalty)

return sort_beams_key
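
Putting the two reviewer suggestions together (a dedicated module and an eos_token_id-based key factory), here is a purely hypothetical sketch of what a future vllm/beam_search.py could contain; this is not part of the PR:

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class BeamSearchSequence:
    """A beam search sequence; tokens include the prompt."""
    tokens: List[int]
    cum_logprob: float = 0.0
    text: Optional[str] = None


@dataclass
class BeamSearchOutput:
    """The best sequences found, one list entry per beam."""
    sequences: List[BeamSearchSequence]


class BeamSearchInstance:

    def __init__(self, prompt_tokens: List[int]):
        self.beams: List[BeamSearchSequence] = [
            BeamSearchSequence(tokens=prompt_tokens)
        ]
        self.completed: List[BeamSearchSequence] = []


def get_beam_search_score(tokens: List[int], cumulative_logprob: float,
                          eos_token_id: int,
                          length_penalty: float = 1.0) -> float:
    """Length-penalized score; a trailing EOS does not count toward length."""
    seq_len = len(tokens)
    if tokens[-1] == eos_token_id:
        seq_len -= 1
    return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(
        eos_token_id: int,
        length_penalty: float) -> Callable[[BeamSearchSequence], float]:
    # Per the suggested change above: take eos_token_id directly, so callers
    # do not need to hand over a tokenizer object.

    def sort_beams_key(x: BeamSearchSequence) -> float:
        return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
                                     length_penalty)

    return sort_beams_key

Callers would then build the key as create_sort_beams_key_function(tokenizer.eos_token_id, length_penalty), and vllm/utils.py would no longer need to know about beam search at all.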