[Frontend] API support for beam search for MQLLMEngine (vllm-project#9117)

Signed-off-by: Sumit Dubey <[email protected]>
LunrEclipse authored and sumitd2 committed Nov 14, 2024
1 parent 3bcb124 commit d1a3584
Showing 8 changed files with 215 additions and 106 deletions.
43 changes: 19 additions & 24 deletions tests/entrypoints/openai/test_completion.py
@@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text

try:
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
except BadRequestError as e:
# the only allowed exception is when beam search is not supported
# in the default mqllmengine
assert "--disable-frontend-multiprocessing" in str(e)
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"

# test streaming
batch = await client.completions.create(
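For context, the `try`/`except BadRequestError` wrapper removed above was only needed while the default multiprocessing frontend (MQLLMEngine) lacked beam search; with this commit the call is expected to succeed without `--disable-frontend-multiprocessing`. A minimal synchronous sketch of the same request, assuming a locally running vLLM OpenAI-compatible server (the URL, API key, and model name below are placeholders, not part of this commit):

```python
import openai

# Placeholder endpoint and model; adjust to your own vLLM deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="The capital of France is",
    n=2,
    max_tokens=5,
    temperature=0.0,
    # vLLM-specific flag, forwarded through the OpenAI client's extra_body.
    extra_body={"use_beam_search": True},
)

for choice in completion.choices:
    print(choice.index, repr(choice.text))
```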
61 changes: 61 additions & 0 deletions vllm/beam_search.py
@@ -0,0 +1,61 @@
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []


def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1

return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
length_penalty)

return sort_beams_key
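As a minimal sketch of how these helpers compose (the token IDs, cumulative logprobs, and EOS id below are invented for illustration):

```python
from vllm.beam_search import (BeamSearchSequence, create_sort_beams_key_function,
                              get_beam_search_score)

EOS_TOKEN_ID = 2  # assumed EOS id for this example

beams = [
    BeamSearchSequence(tokens=[1, 5, 9, EOS_TOKEN_ID], cum_logprob=-1.2),  # finished beam
    BeamSearchSequence(tokens=[1, 5, 7, 11], cum_logprob=-1.5),            # still growing
]

# Score one sequence; a trailing EOS token is excluded from the length term.
score = get_beam_search_score(beams[0].tokens, beams[0].cum_logprob,
                              EOS_TOKEN_ID, length_penalty=1.0)
print(score)  # -1.2 / 3 == -0.4

# Rank beams best-first with the shared key function, as the engines do.
sort_key = create_sort_beams_key_function(EOS_TOKEN_ID, length_penalty=1.0)
best_first = sorted(beams, key=sort_key, reverse=True)
print([b.tokens for b in best_first])
```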
12 changes: 5 additions & 7 deletions vllm/engine/async_llm_engine.py
@@ -7,14 +7,14 @@
from weakref import ReferenceType

import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.entrypoints.llm import BeamSearchSequence
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,7 +33,7 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
get_beam_search_score, random_uuid, weak_bind)
random_uuid, weak_bind)

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1052,16 +1052,14 @@ async def beam_search(
temperature = params.temperature
length_penalty = params.length_penalty

def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)

tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
113 changes: 107 additions & 6 deletions vllm/engine/multiprocessing/client.py
@@ -2,8 +2,8 @@
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union, overload)
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, overload)

import cloudpickle
import zmq
@@ -12,6 +12,7 @@
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
@@ -27,14 +28,16 @@
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
random_uuid)

logger = init_logger(__name__)

@@ -441,6 +444,104 @@ def generate(
lora_request, trace_headers,
prompt_adapter_request, priority)

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty

tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

tasks = []

request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
tasks.append(task)

output = await asyncio.gather(*tasks)

output = [x[0] for x in output]

logger.info(output)

new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]

if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]

completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_logprobs=None)

logger.info(beam_search_output)

yield beam_search_output

@overload # DEPRECATED
def encode(
self,
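A hedged sketch of how a caller might consume the new `MQLLMEngineClient.beam_search` generator; constructing the client itself requires a running engine process and is omitted here, and the prompt and parameter values are illustrative only:

```python
from vllm.sampling_params import BeamSearchParams

async def run_beam_search(client, prompt: str):
    # beam_width/max_tokens values are arbitrary for this example.
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    async for request_output in client.beam_search(prompt, "beam-demo", params):
        # A single finished RequestOutput is yielded; its outputs are the
        # beam_width best beams, sorted best-first.
        for completion in request_output.outputs:
            print(completion.index, completion.cumulative_logprob,
                  completion.text)
```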
37 changes: 3 additions & 34 deletions vllm/entrypoints/llm.py
@@ -1,12 +1,13 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)

from tqdm import tqdm

from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
@@ -28,43 +29,11 @@
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
from vllm.utils import Counter, deprecate_kwargs, is_list_of

logger = init_logger(__name__)


@dataclass
class BeamSearchSequence:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens: List[int]
cum_logprob: float = 0.0
text: Optional[str] = None


@dataclass
class BeamSearchOutput:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences: List[BeamSearchSequence]


class BeamSearchInstance:

def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
]
self.completed: List[BeamSearchSequence] = []


class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
18 changes: 10 additions & 8 deletions vllm/entrypoints/openai/serving_chat.py
@@ -10,6 +10,7 @@

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
@@ -236,15 +237,16 @@ async def create_chat_completion(
log_tracing_disabled_warning()

if isinstance(sampling_params, BeamSearchParams):
if not isinstance(self.engine_client, AsyncLLMEngine):
raise ValueError(
"Beam search in the API server is only supported with"
" AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
assert isinstance(self.engine_client,
(AsyncLLMEngine,
MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'], request_id,
sampling_params)
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
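The same beam-search path can also be reached through the chat completions route handled here; a sketch mirroring the completions example earlier, with placeholder server URL and model name:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="facebook/opt-125m",  # placeholder; any model served by vLLM
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    n=2,
    max_tokens=16,
    temperature=0.0,
    extra_body={"use_beam_search": True},  # mapped to BeamSearchParams server-side
)
print([choice.message.content for choice in chat.choices])
```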
(Diffs for the remaining 2 changed files are not shown.)
