
[Frontend] API support for beam search for MQLLMEngine #9117

Merged: 13 commits, Oct 8, 2024
43 changes: 19 additions & 24 deletions tests/entrypoints/openai/test_completion.py
@@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text

try:
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
except BadRequestError as e:
# the only allowed exception is when beam search is not supported
# in the default mqllmengine
assert "--disable-frontend-multiprocessing" in str(e)
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"

# test streaming
batch = await client.completions.create(
14 changes: 6 additions & 8 deletions vllm/engine/async_llm_engine.py
@@ -14,7 +14,6 @@
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.entrypoints.llm import BeamSearchSequence
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -32,8 +31,9 @@
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
get_beam_search_score, random_uuid, weak_bind)
from vllm.utils import (BeamSearchSequence, collect_from_async_generator,
create_sort_beams_key_function, deprecate_kwargs,
random_uuid, weak_bind)

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1052,16 +1052,14 @@ async def beam_search(
temperature = params.temperature
length_penalty = params.length_penalty

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)

tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer, length_penalty=length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
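The diff above swaps the inline sort_beams_key closure for a shared create_sort_beams_key_function helper in vllm.utils. A minimal sketch of what that helper is assumed to do, reconstructed from the removed closure (the actual implementation in vllm.utils may differ):

from vllm.utils import get_beam_search_score  # scorer the old closure called directly


def create_sort_beams_key_function(tokenizer, length_penalty: float):
    # Return the same key the removed closure computed: the beam's
    # length-penalized cumulative logprob, using the tokenizer's EOS id.
    def sort_beams_key(beam) -> float:
        return get_beam_search_score(beam.tokens, beam.cum_logprob,
                                     tokenizer.eos_token_id, length_penalty)

    return sort_beams_key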
113 changes: 107 additions & 6 deletions vllm/engine/multiprocessing/client.py
@@ -2,8 +2,8 @@
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union, overload)
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, overload)

import cloudpickle
import zmq
@@ -27,14 +27,17 @@
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
from vllm.utils import (BeamSearchSequence, collect_from_async_generator,
create_sort_beams_key_function, deprecate_kwargs,
random_uuid)

logger = init_logger(__name__)

@@ -441,6 +444,104 @@ def generate(
lora_request, trace_headers,
prompt_adapter_request, priority)

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty

tokenizer = await self.get_tokenizer(None)
Collaborator: I would add some sort of kwarg to show why it is None here.
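A minimal sketch of what that suggestion could look like; lora_request is an assumed keyword name here, not something this diff confirms:

# Hypothetical: pass the None by keyword so the reason for it is visible at the
# call site. The actual parameter name of get_tokenizer() may differ.
tokenizer = await self.get_tokenizer(lora_request=None)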

tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer, length_penalty=length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

tasks = []

request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
tasks.append(task)

output = await asyncio.gather(*tasks)

output = [x[0] for x in output]

logger.info(output)

new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]

if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]

completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_logprobs=None)

logger.info(beam_search_output)

yield beam_search_output

@overload # DEPRECATED
def encode(
self,
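A rough usage sketch for the new MQLLMEngineClient.beam_search method. Construction of the client is omitted, the BeamSearchParams field names are taken from the attributes the method reads, and the values chosen here are assumptions:

from vllm.sampling_params import BeamSearchParams


async def run_beam_search(client, prompt_token_ids):
    # Keep the top-2 beams and expand each beam by at most 5 new tokens.
    params = BeamSearchParams(beam_width=2,
                              max_tokens=5,
                              ignore_eos=False,
                              temperature=0.0,
                              length_penalty=1.0)
    # beam_search yields a single finished RequestOutput containing one
    # CompletionOutput per surviving beam.
    async for request_output in client.beam_search(prompt_token_ids,
                                                   "beam-search-demo",
                                                   params):
        for completion in request_output.outputs:
            print(completion.cumulative_logprob, completion.text)

# Run with asyncio.run(run_beam_search(client, prompt_token_ids)) once a client exists.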
37 changes: 3 additions & 34 deletions vllm/entrypoints/llm.py
@@ -1,7 +1,6 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)

@@ -28,43 +27,13 @@
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
from vllm.utils import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, Counter, deprecate_kwargs,
get_beam_search_score, is_list_of)

logger = init_logger(__name__)


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []


class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.

18 changes: 10 additions & 8 deletions vllm/entrypoints/openai/serving_chat.py
@@ -10,6 +10,7 @@

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
@@ -236,15 +237,16 @@ async def create_chat_completion(
log_tracing_disabled_warning()

if isinstance(sampling_params, BeamSearchParams):
if not isinstance(self.engine_client, AsyncLLMEngine):
raise ValueError(
"Beam search in the API server is only supported with"
" AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
assert isinstance(self.engine_client,
(AsyncLLMEngine,
MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'], request_id,
sampling_params)
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
18 changes: 10 additions & 8 deletions vllm/entrypoints/openai/serving_completion.py
@@ -9,6 +9,7 @@

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
@@ -150,15 +151,16 @@ async def create_completion(
log_tracing_disabled_warning()

if isinstance(sampling_params, BeamSearchParams):
if not isinstance(self.engine_client, AsyncLLMEngine):
raise ValueError(
"Beam search in the API server is only supported"
" with AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
assert isinstance(self.engine_client,
(AsyncLLMEngine,
Collaborator: On second thought, you don't actually need the assert. You can add beam_search to the EngineClientProtocol.

Contributor Author (@LunrEclipse, Oct 7, 2024): I don't think the base EngineClientProtocol has a beam_search function. I can add it in, though.

Collaborator: It does not. But the EngineClientProtocol defines the behavior of AsyncLLMEngine and MQLLMEngine, so now that both support it, you can expand EngineClientProtocol to include the beam_search API.

Member: I cannot find EngineClientProtocol. Does it exist now? @robertgshaw2-neuralmagic

Contributor Author: I think it's this?

Member: Okay, it should be the EngineClient class, but MQLLMEngineClient does not inherit from EngineClient. We can make it a future step to absorb the beam search implementation into EngineClient.

Collaborator: EngineClient is a protocol, and MQLLMEngine should inherit from it; if it doesn't, I'll submit a PR to make it (since we support the full API). On a train, so AFK. Either way, we are about to collapse MQLLMEngine and AsyncLLMEngine once we have PP working, so the concept of an EngineClient will be removed once this is done.

Member: Sounds good. I'll go ahead and merge this PR once it is ready, and after you make MQLLMEngine inherit from EngineClient, we can merge the separate beam search implementations in one place.

Member: Given that it's currently a Protocol, MQLLMEngine technically doesn't need to subclass it directly, but it would probably be good to do so anyway, and we could actually consider changing it to an ABC instead. I agree with @robertgshaw2-neuralmagic that this method should just be added to EngineClient, and we should not need these type assertions. Not directly related to this PR, but I also think we should consider renaming it to something like AsyncEngineClient, and having a way to obtain an AsyncEngineClient instance that doesn't involve explicit construction; that would replace explicit use of AsyncLLMEngine.

Member: Changes are welcome on this part!

(A rough sketch of adding beam_search to the EngineClient protocol follows this file's diff.)

MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"], request_id_item,
sampling_params)
prompt_inputs["prompt_token_ids"],
request_id,
Member: What's the difference between request_id and request_id_item?

Contributor Author: Ahh, yeah, it's supposed to be request_id_item; I was a bit careless while refactoring. Completion supports multiple prompts, so each prompt has its own request_id_item under a general request_id.
sampling_params,
)
else:
generator = self.engine_client.generate(
{
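The review thread on the isinstance assert above converges on declaring beam_search on the EngineClient protocol (vllm/engine/protocol.py) so the OpenAI serving layers need no type checks. A rough sketch of that suggestion, not code from this PR, with the exact protocol shape assumed:

from typing import AsyncGenerator, List, Protocol, Union

from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams


class EngineClient(Protocol):
    """Protocol implemented by both AsyncLLMEngine and MQLLMEngineClient."""

    def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        # Declaring the method here would let serving_chat.py and
        # serving_completion.py call self.engine_client.beam_search(...)
        # without the isinstance asserts added in this diff.
        ...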