From 5add5d7552053050a2ea5d17fd599906fac120b4 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 06:10:34 +0000 Subject: [PATCH 01/13] port over beam search logic --- vllm/engine/multiprocessing/client.py | 106 +++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 4 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b0d061dbab4a1..b40b52f276948 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union, overload) + Union, List, overload) import cloudpickle import zmq @@ -26,13 +26,14 @@ RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable +from vllm.entrypoints.llm import BeamSearchSequence from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType +from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import CompletionOutput, EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import deprecate_kwargs @@ -441,6 +442,103 @@ def generate( lora_request, trace_headers, prompt_adapter_request, priority) + async def beam_search( + self, + prompt: Union[PromptType, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + + tokenizer = await self.get_tokenizer() + tokenizedPrompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + tokenizedLength = len(tokenizedPrompt) + + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + logger.info(output) + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, + key=lambda x: x.cum_logprob, + reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, + key=lambda x: x.cum_logprob, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens, + index=i, + logprobs=beam.cum_logprob, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenizedPrompt, + prompt_logprobs=None) + + yield beam_search_output, RequestOutput + + @overload # DEPRECATED def encode( self, From 1375b59a8de73231f5b901b32d01ed49f7088b91 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 06:35:15 +0000 Subject: [PATCH 02/13] integrate mqllm engine --- vllm/engine/multiprocessing/client.py | 16 ++++++++++------ vllm/entrypoints/openai/serving_chat.py | 13 +++++++++---- vllm/entrypoints/openai/serving_completion.py | 12 +++++++----- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b40b52f276948..76b35587b12a2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -2,8 +2,8 @@ import copy import pickle from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union, List, overload) +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, overload) import cloudpickle import zmq @@ -31,11 +31,13 @@ from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, EmbeddingRequestOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import deprecate_kwargs +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, + random_uuid) logger = init_logger(__name__) @@ -447,6 +449,7 @@ async def beam_search( prompt: Union[PromptType, List[int]], request_id: str, params: BeamSearchParams, + lora_request: Optional[LoRARequest] = None ) -> AsyncGenerator[RequestOutput, None]: beam_width = params.beam_width @@ -454,7 +457,7 @@ async def beam_search( ignore_eos = params.ignore_eos temperature = params.temperature - tokenizer = await self.get_tokenizer() + tokenizer = await self.get_tokenizer(lora_request) tokenizedPrompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) tokenizedLength = len(tokenizedPrompt) @@ -536,8 +539,9 @@ async def beam_search( prompt_token_ids=tokenizedPrompt, prompt_logprobs=None) - yield beam_search_output, RequestOutput + logger.info(beam_search_output) + yield beam_search_output @overload # DEPRECATED def encode( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c4652be6fe821..3ffea35f4e427 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -10,6 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, @@ -236,15 +237,19 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if not isinstance(self.engine_client, AsyncLLMEngine): + if not isinstance(self.engine_client, AsyncLLMEngine) and \ + not isinstance(self.engine_client, MQLLMEngineClient): raise ValueError( "Beam search in the API server is only supported with" - " AsyncLLMEngine. please add " + " AsyncLLMEngine and MQLLMEngineClient. please add " "`--disable-frontend-multiprocessing` to " "use beam search.") result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], request_id, - sampling_params) + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + lora_request, + ) else: result_generator = self.engine_client.generate( engine_inputs, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index bf9e9850797a6..c2263b781f660 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block @@ -150,15 +151,16 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if not isinstance(self.engine_client, AsyncLLMEngine): + if not isinstance(self.engine_client, AsyncLLMEngine) and \ + not isinstance(self.engine_client, MQLLMEngineClient): raise ValueError( "Beam search in the API server is only supported" - " with AsyncLLMEngine. please add " - "`--disable-frontend-multiprocessing` to " - "use beam search.") + " with AsyncLLMEngine and MQLLMEngineClient." + " please add `--disable-frontend-multiprocessing`" + " to use beam search.") generator = self.engine_client.beam_search( prompt_inputs["prompt_token_ids"], request_id_item, - sampling_params) + sampling_params, lora_request) else: generator = self.engine_client.generate( { From b57bff2a2ddeb5abbdbd27ffb78610c8c6ce9281 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 06:47:07 +0000 Subject: [PATCH 03/13] fix engine differences --- vllm/entrypoints/openai/serving_chat.py | 22 ++++++++++++------- vllm/entrypoints/openai/serving_completion.py | 14 +++++++----- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3ffea35f4e427..67dd5f3355ab4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -237,19 +237,25 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if not isinstance(self.engine_client, AsyncLLMEngine) and \ - not isinstance(self.engine_client, MQLLMEngineClient): + if isinstance(self.engine_client, AsyncLLMEngine): + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + ) + elif isinstance(self.engine_client, MQLLMEngineClient): + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + lora_request, + ) + else: raise ValueError( "Beam search in the API server is only supported with" " AsyncLLMEngine and MQLLMEngineClient. please add " "`--disable-frontend-multiprocessing` to " "use beam search.") - result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], - request_id, - sampling_params, - lora_request, - ) else: result_generator = self.engine_client.generate( engine_inputs, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c2263b781f660..0afca1ba7e889 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -151,16 +151,20 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if not isinstance(self.engine_client, AsyncLLMEngine) and \ - not isinstance(self.engine_client, MQLLMEngineClient): + if isinstance(self.engine_client, AsyncLLMEngine): + generator = self.engine_client.beam_search( + prompt_inputs["prompt_token_ids"], request_id_item, + sampling_params) + elif isinstance(self.engine_client, MQLLMEngineClient): + generator = self.engine_client.beam_search( + prompt_inputs["prompt_token_ids"], request_id_item, + sampling_params, lora_request) + else: raise ValueError( "Beam search in the API server is only supported" " with AsyncLLMEngine and MQLLMEngineClient." " please add `--disable-frontend-multiprocessing`" " to use beam search.") - generator = self.engine_client.beam_search( - prompt_inputs["prompt_token_ids"], request_id_item, - sampling_params, lora_request) else: generator = self.engine_client.generate( { From 8384451b62432c9c517afe79e6e21493075353fe Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 20:33:31 +0000 Subject: [PATCH 04/13] update from comments --- tests/entrypoints/openai/test_completion.py | 43 +++++++++------------ vllm/engine/multiprocessing/client.py | 16 ++++---- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 61da5513cb130..cc72a49ebbbda 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): assert len(batch.choices) == 2 assert batch.choices[0].text == batch.choices[1].text - try: - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=prompts, - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but - # not necessary for official client. - use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" - except BadRequestError as e: - # the only allowed exception is when beam search is not supported - # in the default mqllmengine - assert "--disable-frontend-multiprocessing" in str(e) + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" # test streaming batch = await client.completions.create( diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 76b35587b12a2..f5129c8465258 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -37,7 +37,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid) + get_beam_search_score, random_uuid) logger = init_logger(__name__) @@ -456,6 +456,12 @@ async def beam_search( max_tokens = params.max_tokens ignore_eos = params.ignore_eos temperature = params.temperature + length_penalty = params.length_penalty + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, + length_penalty) tokenizer = await self.get_tokenizer(lora_request) tokenizedPrompt = prompt if isinstance( @@ -509,15 +515,11 @@ async def beam_search( else: new_beams.append(new_beam) - sorted_beams = sorted(new_beams, - key=lambda x: x.cum_logprob, - reverse=True) + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] completed.extend(all_beams) - sorted_completed = sorted(completed, - key=lambda x: x.cum_logprob, - reverse=True) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) best_beams = sorted_completed[:beam_width] for beam in best_beams: From f9843df338bbbaca8dd20027c338e3b2e1ae5971 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 21:16:52 +0000 Subject: [PATCH 05/13] update from review and refactor --- vllm/engine/async_llm_engine.py | 14 +++---- vllm/engine/multiprocessing/client.py | 17 ++++---- vllm/entrypoints/llm.py | 37 ++--------------- vllm/entrypoints/openai/serving_chat.py | 29 +++++-------- vllm/entrypoints/openai/serving_completion.py | 24 +++++------ vllm/utils.py | 41 +++++++++++++++++++ 6 files changed, 77 insertions(+), 85 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 50269493d64e9..f019cd6e00567 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,7 +14,6 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase -from vllm.entrypoints.llm import BeamSearchSequence from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster @@ -32,8 +31,9 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - get_beam_search_score, random_uuid, weak_bind) +from vllm.utils import (BeamSearchSequence, collect_from_async_generator, + create_sort_beams_key_function, deprecate_kwargs, + random_uuid, weak_bind) logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -1052,16 +1052,14 @@ async def beam_search( temperature = params.temperature length_penalty = params.length_penalty - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, - tokenizer.eos_token_id, - length_penalty) - tokenizer = await self.get_tokenizer() tokenizedPrompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) tokenizedLength = len(tokenizedPrompt) + sort_beams_key = create_sort_beams_key_function( + tokenizer, length_penalty=length_penalty) + beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, temperature=temperature) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f5129c8465258..bfdfcc1952ec8 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -26,7 +26,6 @@ RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable -from vllm.entrypoints.llm import BeamSearchSequence from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger @@ -36,8 +35,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - get_beam_search_score, random_uuid) +from vllm.utils import (BeamSearchSequence, collect_from_async_generator, + create_sort_beams_key_function, deprecate_kwargs, + random_uuid) logger = init_logger(__name__) @@ -449,7 +449,6 @@ async def beam_search( prompt: Union[PromptType, List[int]], request_id: str, params: BeamSearchParams, - lora_request: Optional[LoRARequest] = None ) -> AsyncGenerator[RequestOutput, None]: beam_width = params.beam_width @@ -458,16 +457,14 @@ async def beam_search( temperature = params.temperature length_penalty = params.length_penalty - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, - tokenizer.eos_token_id, - length_penalty) - - tokenizer = await self.get_tokenizer(lora_request) + tokenizer = await self.get_tokenizer(None) tokenizedPrompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) tokenizedLength = len(tokenizedPrompt) + sort_beams_key = create_sort_beams_key_function( + tokenizer, length_penalty=length_penalty) + beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, temperature=temperature) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 439f3769f9fbd..947b82b4afd9f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,7 +1,6 @@ import itertools import warnings from contextlib import contextmanager -from dataclasses import dataclass from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Union, cast, overload) @@ -28,43 +27,13 @@ get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score, - is_list_of) +from vllm.utils import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence, Counter, deprecate_kwargs, + get_beam_search_score, is_list_of) logger = init_logger(__name__) -@dataclass -class BeamSearchSequence: - """A sequence for beam search. - It keeps track of the tokens and the log probability of the sequence. - The text field is optional and will only be filled when the sequence is - about to be returned to the user. - """ - # The tokens includes the prompt. - tokens: List[int] - cum_logprob: float = 0.0 - text: Optional[str] = None - - -@dataclass -class BeamSearchOutput: - """The output of beam search. - It contains the list of the best beam search sequences. - The length of the list is equal to the beam width. - """ - sequences: List[BeamSearchSequence] - - -class BeamSearchInstance: - - def __init__(self, prompt_tokens: List[int]): - self.beams: List[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens) - ] - self.completed: List[BeamSearchSequence] = [] - - class LLM: """An LLM for generating texts from given prompts and sampling parameters. diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 67dd5f3355ab4..1e85167ea7619 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -237,25 +237,16 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if isinstance(self.engine_client, AsyncLLMEngine): - result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], - request_id, - sampling_params, - ) - elif isinstance(self.engine_client, MQLLMEngineClient): - result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], - request_id, - sampling_params, - lora_request, - ) - else: - raise ValueError( - "Beam search in the API server is only supported with" - " AsyncLLMEngine and MQLLMEngineClient. please add " - "`--disable-frontend-multiprocessing` to " - "use beam search.") + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + ) else: result_generator = self.engine_client.generate( engine_inputs, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 0afca1ba7e889..530fbf4fa0cd3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -151,20 +151,16 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - if isinstance(self.engine_client, AsyncLLMEngine): - generator = self.engine_client.beam_search( - prompt_inputs["prompt_token_ids"], request_id_item, - sampling_params) - elif isinstance(self.engine_client, MQLLMEngineClient): - generator = self.engine_client.beam_search( - prompt_inputs["prompt_token_ids"], request_id_item, - sampling_params, lora_request) - else: - raise ValueError( - "Beam search in the API server is only supported" - " with AsyncLLMEngine and MQLLMEngineClient." - " please add `--disable-frontend-multiprocessing`" - " to use beam search.") + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." + generator = self.engine_client.beam_search( + prompt_inputs["prompt_token_ids"], + request_id, + sampling_params, + ) else: generator = self.engine_client.generate( { diff --git a/vllm/utils.py b/vllm/utils.py index 1b7638c4a12ac..f59b6981e9d8b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,6 +17,7 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, ensure_future +from dataclasses import dataclass from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -1363,6 +1364,37 @@ def value(self): return self._value +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + def get_beam_search_score( tokens: List[int], cumulative_logprob: float, @@ -1380,3 +1412,12 @@ def get_beam_search_score( seq_len -= 1 return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(tokenizer, length_penalty): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, length_penalty) + + return sort_beams_key From a33ea39a746a89e2712f622e51bfe56525aece3b Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 21:29:10 +0000 Subject: [PATCH 06/13] remove assert and add kwarg --- vllm/engine/multiprocessing/client.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 7 ------- vllm/entrypoints/openai/serving_completion.py | 7 ------- 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index bfdfcc1952ec8..ec48b52b47441 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -457,7 +457,7 @@ async def beam_search( temperature = params.temperature length_penalty = params.length_penalty - tokenizer = await self.get_tokenizer(None) + tokenizer = await self.get_tokenizer(lora_request=None) tokenizedPrompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) tokenizedLength = len(tokenizedPrompt) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1e85167ea7619..5253ee5afdd11 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,8 +9,6 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, @@ -237,11 +235,6 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - assert isinstance(self.engine_client, - (AsyncLLMEngine, - MQLLMEngineClient)), \ - "Beam search is only supported with" \ - "AsyncLLMEngine and MQLLMEngineClient." result_generator = self.engine_client.beam_search( engine_inputs['prompt_token_ids'], request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 530fbf4fa0cd3..076e220b5615a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,8 +8,6 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block @@ -151,11 +149,6 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): - assert isinstance(self.engine_client, - (AsyncLLMEngine, - MQLLMEngineClient)), \ - "Beam search is only supported with" \ - "AsyncLLMEngine and MQLLMEngineClient." generator = self.engine_client.beam_search( prompt_inputs["prompt_token_ids"], request_id, From ac5520cffefff7654081306370253b70f2c11f63 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 21:31:57 +0000 Subject: [PATCH 07/13] add asserts back --- vllm/entrypoints/openai/serving_chat.py | 7 +++++++ vllm/entrypoints/openai/serving_completion.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5253ee5afdd11..1e85167ea7619 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,6 +9,8 @@ from fastapi import Request from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, @@ -235,6 +237,11 @@ async def create_chat_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." result_generator = self.engine_client.beam_search( engine_inputs['prompt_token_ids'], request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 076e220b5615a..530fbf4fa0cd3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,8 @@ from fastapi import Request from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block @@ -149,6 +151,11 @@ async def create_completion( log_tracing_disabled_warning() if isinstance(sampling_params, BeamSearchParams): + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." generator = self.engine_client.beam_search( prompt_inputs["prompt_token_ids"], request_id, From b3c5d05105a5f5cca5b4f433826504ac6269b8d7 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 22:04:53 +0000 Subject: [PATCH 08/13] change input --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/multiprocessing/client.py | 2 +- vllm/utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f019cd6e00567..c1f653dfe5c6f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1058,7 +1058,7 @@ async def beam_search( tokenizedLength = len(tokenizedPrompt) sort_beams_key = create_sort_beams_key_function( - tokenizer, length_penalty=length_penalty) + tokenizer.eos_token_id, length_penalty) beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index ec48b52b47441..8160670ce711c 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -463,7 +463,7 @@ async def beam_search( tokenizedLength = len(tokenizedPrompt) sort_beams_key = create_sort_beams_key_function( - tokenizer, length_penalty=length_penalty) + tokenizer.eos_token_id, length_penalty) beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, diff --git a/vllm/utils.py b/vllm/utils.py index f59b6981e9d8b..dc6da4c4a6c9b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1414,10 +1414,10 @@ def get_beam_search_score( return cumulative_logprob / (seq_len**length_penalty) -def create_sort_beams_key_function(tokenizer, length_penalty): +def create_sort_beams_key_function(eos_token_id: int, length_penalty): def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, - tokenizer.eos_token_id, length_penalty) + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) return sort_beams_key From b22e8bd455c0513e01626f7c997015c8382394c1 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 22:06:08 +0000 Subject: [PATCH 09/13] add typing --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index dc6da4c4a6c9b..6bce6686c44ab 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1414,7 +1414,7 @@ def get_beam_search_score( return cumulative_logprob / (seq_len**length_penalty) -def create_sort_beams_key_function(eos_token_id: int, length_penalty): +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): def sort_beams_key(x: BeamSearchSequence) -> float: return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, From 097479dd586700435e95bc82be3ae9e6b5d18bb0 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 22:22:30 +0000 Subject: [PATCH 10/13] change serving_completion back --- vllm/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 530fbf4fa0cd3..077312dd1414e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -158,7 +158,7 @@ async def create_completion( "AsyncLLMEngine and MQLLMEngineClient." generator = self.engine_client.beam_search( prompt_inputs["prompt_token_ids"], - request_id, + request_id_item, sampling_params, ) else: From 44334eee2ebb4bda54bd1eb631eba665c1040515 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Mon, 7 Oct 2024 22:47:42 +0000 Subject: [PATCH 11/13] move beam search classes --- vllm/engine/async_llm_engine.py | 6 ++-- vllm/engine/multiprocessing/client.py | 4 +-- vllm/entrypoints/llm.py | 7 +++-- vllm/sequence.py | 41 +++++++++++++++++++++++++++ vllm/utils.py | 41 --------------------------- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c1f653dfe5c6f..eaf26bc2ac7d7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -28,11 +28,11 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import (BeamSearchSequence, ExecuteModelRequest, + create_sort_beams_key_function) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (BeamSearchSequence, collect_from_async_generator, - create_sort_beams_key_function, deprecate_kwargs, +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, random_uuid, weak_bind) logger = init_logger(__name__) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8160670ce711c..dad2771bbf634 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -34,9 +34,9 @@ RequestOutput) from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import BeamSearchSequence, create_sort_beams_key_function from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import (BeamSearchSequence, collect_from_async_generator, - create_sort_beams_key_function, deprecate_kwargs, +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, random_uuid) logger = init_logger(__name__) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 947b82b4afd9f..a3dd50ab94785 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -23,13 +23,14 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) +from vllm.sequence import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import (BeamSearchInstance, BeamSearchOutput, - BeamSearchSequence, Counter, deprecate_kwargs, - get_beam_search_score, is_list_of) +from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score, + is_list_of) logger = init_logger(__name__) diff --git a/vllm/sequence.py b/vllm/sequence.py index 9116408a001ff..60f64964b641e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -20,6 +20,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.utils import get_beam_search_score if TYPE_CHECKING: from vllm.multimodal.base import MultiModalDataDict @@ -1365,3 +1366,43 @@ def clone( last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, async_callback=self.async_callback) + + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/vllm/utils.py b/vllm/utils.py index 6bce6686c44ab..1b7638c4a12ac 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,7 +17,6 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, ensure_future -from dataclasses import dataclass from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -1364,37 +1363,6 @@ def value(self): return self._value -@dataclass -class BeamSearchSequence: - """A sequence for beam search. - It keeps track of the tokens and the log probability of the sequence. - The text field is optional and will only be filled when the sequence is - about to be returned to the user. - """ - # The tokens includes the prompt. - tokens: List[int] - cum_logprob: float = 0.0 - text: Optional[str] = None - - -@dataclass -class BeamSearchOutput: - """The output of beam search. - It contains the list of the best beam search sequences. - The length of the list is equal to the beam width. - """ - sequences: List[BeamSearchSequence] - - -class BeamSearchInstance: - - def __init__(self, prompt_tokens: List[int]): - self.beams: List[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens) - ] - self.completed: List[BeamSearchSequence] = [] - - def get_beam_search_score( tokens: List[int], cumulative_logprob: float, @@ -1412,12 +1380,3 @@ def get_beam_search_score( seq_len -= 1 return cumulative_logprob / (seq_len**length_penalty) - - -def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): - - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, - length_penalty) - - return sort_beams_key From 4918f6afa2a72e5be1ea56f52c86b1bc553a3da1 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Tue, 8 Oct 2024 03:41:23 +0000 Subject: [PATCH 12/13] add new file for beam search --- vllm/beam_search.py | 59 +++++++++++++++++++++++++++ vllm/engine/async_llm_engine.py | 4 +- vllm/engine/multiprocessing/client.py | 2 +- vllm/entrypoints/llm.py | 7 ++-- vllm/sequence.py | 41 ------------------- vllm/utils.py | 19 --------- 6 files changed, 65 insertions(+), 67 deletions(-) create mode 100644 vllm/beam_search.py diff --git a/vllm/beam_search.py b/vllm/beam_search.py new file mode 100644 index 0000000000000..490c6141fceb2 --- /dev/null +++ b/vllm/beam_search.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from typing import List, Optional + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + +def get_beam_search_score( + tokens: List[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index eaf26bc2ac7d7..30e1a09981c57 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,6 +7,7 @@ from weakref import ReferenceType import vllm.envs as envs +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs @@ -28,8 +29,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import (BeamSearchSequence, ExecuteModelRequest, - create_sort_beams_key_function) +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (collect_from_async_generator, deprecate_kwargs, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index dad2771bbf634..820f678abeff5 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -12,6 +12,7 @@ from zmq.asyncio import Socket from vllm import PoolingParams +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -34,7 +35,6 @@ RequestOutput) from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import BeamSearchSequence, create_sort_beams_key_function from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import (collect_from_async_generator, deprecate_kwargs, random_uuid) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a3dd50ab94785..b0a8a66ec133f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,6 +6,8 @@ from tqdm import tqdm +from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence, get_beam_search_score) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, @@ -23,14 +25,11 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) -from vllm.sequence import (BeamSearchInstance, BeamSearchOutput, - BeamSearchSequence) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score, - is_list_of) +from vllm.utils import Counter, deprecate_kwargs, is_list_of logger = init_logger(__name__) diff --git a/vllm/sequence.py b/vllm/sequence.py index 60f64964b641e..9116408a001ff 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -20,7 +20,6 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.utils import get_beam_search_score if TYPE_CHECKING: from vllm.multimodal.base import MultiModalDataDict @@ -1366,43 +1365,3 @@ def clone( last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, async_callback=self.async_callback) - - -@dataclass -class BeamSearchSequence: - """A sequence for beam search. - It keeps track of the tokens and the log probability of the sequence. - The text field is optional and will only be filled when the sequence is - about to be returned to the user. - """ - # The tokens includes the prompt. - tokens: List[int] - cum_logprob: float = 0.0 - text: Optional[str] = None - - -@dataclass -class BeamSearchOutput: - """The output of beam search. - It contains the list of the best beam search sequences. - The length of the list is equal to the beam width. - """ - sequences: List[BeamSearchSequence] - - -class BeamSearchInstance: - - def __init__(self, prompt_tokens: List[int]): - self.beams: List[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens) - ] - self.completed: List[BeamSearchSequence] = [] - - -def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): - - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, - length_penalty) - - return sort_beams_key diff --git a/vllm/utils.py b/vllm/utils.py index 1b7638c4a12ac..e44365fa24990 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1361,22 +1361,3 @@ def dec(self, num=1): @property def value(self): return self._value - - -def get_beam_search_score( - tokens: List[int], - cumulative_logprob: float, - eos_token_id: int, - length_penalty: float = 1.0, -) -> float: - """Calculate the beam search score with length penalty. - - Adapted from - - https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 - """ - seq_len = len(tokens) - if tokens[-1] == eos_token_id: - seq_len -= 1 - - return cumulative_logprob / (seq_len**length_penalty) From e65ed315dda117d206162c3f739c21833c7c3e86 Mon Sep 17 00:00:00 2001 From: Brendan Wong Date: Tue, 8 Oct 2024 03:42:30 +0000 Subject: [PATCH 13/13] fix import --- vllm/beam_search.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 490c6141fceb2..04624b8b94432 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from typing import List, Optional + @dataclass class BeamSearchSequence: """A sequence for beam search. @@ -31,6 +32,7 @@ def __init__(self, prompt_tokens: List[int]): ] self.completed: List[BeamSearchSequence] = [] + def get_beam_search_score( tokens: List[int], cumulative_logprob: float,