From 08ab78e8fce449033fb5fdd0bb4032043e9628e2 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Tue, 15 Oct 2024 23:28:23 +0800 Subject: [PATCH 01/12] update of beam search function --- vllm/beam_search.py | 1 + vllm/engine/protocol.py | 69 ++++++++++++++----- vllm/entrypoints/openai/protocol.py | 2 + vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/sampling_params.py | 1 + 6 files changed, 57 insertions(+), 20 deletions(-) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 04624b8b94432..fe04fd9501bc6 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -13,6 +13,7 @@ class BeamSearchSequence: tokens: List[int] cum_logprob: float = 0.0 text: Optional[str] = None + finish_reason: Optional[str] = None @dataclass diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 16ceddf13511c..4df80751c7b02 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,5 +1,6 @@ import asyncio from abc import ABC, abstractmethod +from copy import deepcopy from typing import AsyncGenerator, List, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function @@ -69,24 +70,46 @@ async def beam_search( ignore_eos = params.ignore_eos temperature = params.temperature length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output tokenizer = await self.get_tokenizer(lora_request=None) - tokenizedPrompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - tokenizedLength = len(tokenizedPrompt) + + if isinstance(prompt, dict): + if "prompt" in prompt: + tokenized_prompt = tokenizer.encode(prompt.get("prompt")) + multi_modal_data = prompt.get("multi_modal_data") + mm_processor_kwargs = prompt.get("mm_processor_kwargs") + elif "prompt_token_ids" in prompt: + tokenized_prompt = tokenizer.encode(prompt.get("prompt")) + multi_modal_data = prompt.get("multi_modal_data") + mm_processor_kwargs = prompt.get("mm_processor_kwargs") + raise TypeError("Inputs in Dictionary type must be a TextPrompt or TokensPrompt") + else: + tokenized_prompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + multi_modal_data = None + mm_processor_kwargs = None + + tokenized_length = len(tokenized_prompt) sort_beams_key = create_sort_beams_key_function( tokenizer.eos_token_id, length_penalty) - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) - all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)] completed = [] for _ in range(max_tokens): prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) + TokensPrompt( + prompt_token_ids=beam.tokens, + multi_modal_data=deepcopy(multi_modal_data), # always the values from inputs + mm_processor_kwargs=deepcopy(mm_processor_kwargs) + ) for beam in all_beams ] @@ -112,16 +135,25 @@ async def beam_search( if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - if token_id == tokenizer.eos_token_id and \ not ignore_eos: - completed.append(new_beam) + completed.append( + BeamSearchSequence( + 
tokens=current_beam.tokens + [token_id] + if include_stop_str_in_output else current_beam.tokens, # + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop" + ) + ) else: - new_beams.append(new_beam) + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], # + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + ) + ) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] @@ -131,11 +163,11 @@ async def beam_search( best_beams = sorted_completed[:beam_width] for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + beam.text = tokenizer.decode(beam.tokens[tokenized_length:]) beam_search_output = RequestOutput( request_id=request_id, - prompt=prompt, + prompt=tokenizer.decode(tokenized_prompt), outputs=[ CompletionOutput( text=beam.text, @@ -143,10 +175,11 @@ async def beam_search( token_ids=beam.tokens, index=i, logprobs=beam.cum_logprob, + finish_reason=beam.finish_reason if beam.finish_reason is not None else "length" ) for (i, beam) in enumerate(best_beams) ], finished=True, - prompt_token_ids=tokenizedPrompt, + prompt_token_ids=tokenized_prompt, prompt_logprobs=None) yield beam_search_output diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6f1135f8093ba..335ad5ddcc972 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -302,6 +302,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output ) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: @@ -594,6 +595,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output ) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index acb56e4a886e1..70ed12ea9293d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -236,7 +236,7 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): result_generator = self.engine_client.beam_search( - engine_inputs['prompt_token_ids'], + engine_inputs, request_id, sampling_params, ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7aa4587e23c15..2408c3ab0caed 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -150,7 +150,7 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( - prompt_inputs["prompt_token_ids"], + prompt_inputs, request_id_item, sampling_params, ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4f2ae75e65f3a..412dd2b9ce9c4 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -489,3 +489,4 @@ class BeamSearchParams( ignore_eos: bool = False temperature: float = 0.0 length_penalty: float = 1.0 + include_stop_str_in_output: bool = False From cac55e181f131bd5855a6fe079f21e65792b0d20 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Wed, 16 Oct 2024 20:18:54 +0800 Subject: [PATCH 02/12] update of testing --- tests/entrypoints/openai/test_vision.py | 35 +++++++++++++++++++++++++ 1 
file changed, 35 insertions(+) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 81d79601124a7..e678c87e5dc67 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -160,6 +160,41 @@ async def test_single_chat_session_image_base64encoded( assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_base64encoded_beamsearch( + client: openai.AsyncOpenAI, model_name: str, image_url: str, + base64_encoded_image: Dict[str, str]): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": + f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + n=2, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(use_beam_search=True) + ) + assert len(chat_completion.choices) == 4 + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) From 2dde695b7a27e3c4ddb95d3314df832a01bc83f8 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Wed, 16 Oct 2024 13:28:10 +0000 Subject: [PATCH 03/12] fix error in implementation --- vllm/engine/protocol.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 4df80751c7b02..ad3a7efae7f88 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -80,10 +80,11 @@ async def beam_search( multi_modal_data = prompt.get("multi_modal_data") mm_processor_kwargs = prompt.get("mm_processor_kwargs") elif "prompt_token_ids" in prompt: - tokenized_prompt = tokenizer.encode(prompt.get("prompt")) + tokenized_prompt = prompt.get("prompt_token_ids") multi_modal_data = prompt.get("multi_modal_data") mm_processor_kwargs = prompt.get("mm_processor_kwargs") - raise TypeError("Inputs in Dictionary type must be a TextPrompt or TokensPrompt") + else: + raise TypeError("Inputs in Dictionary type must be a TextPrompt or TokensPrompt") else: tokenized_prompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) @@ -143,7 +144,6 @@ async def beam_search( if include_stop_str_in_output else current_beam.tokens, # cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, - finish_reason="stop" ) ) else: @@ -175,13 +175,12 @@ async def beam_search( token_ids=beam.tokens, index=i, logprobs=beam.cum_logprob, - finish_reason=beam.finish_reason if beam.finish_reason is not None else "length" ) for (i, beam) in enumerate(best_beams) ], finished=True, prompt_token_ids=tokenized_prompt, prompt_logprobs=None) - + yield beam_search_output @abstractmethod From eb92b7d858f6ffea16720652d9294bf46416d081 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Wed, 16 Oct 2024 15:36:11 +0000 Subject: [PATCH 04/12] add checking for logprobs and add more test cases --- tests/entrypoints/openai/test_vision.py | 49 +++++++++++++++++++++++-- vllm/engine/protocol.py | 3 +- vllm/entrypoints/openai/protocol.py | 4 ++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index e678c87e5dc67..a3f55674f1b78 100644 --- a/tests/entrypoints/openai/test_vision.py +++ 
b/tests/entrypoints/openai/test_vision.py @@ -104,6 +104,50 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, + model_name: str, image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + n=2, + max_tokens=10, + extra_body=dict(use_beam_search=True) + ) + assert len(chat_completion.choices) == 2 + assert chat_completion.choices[0].message.content != chat_completion.choices[1].message.content + + with pytest.raises(openai.BadRequestError) as exc_info: + await client.chat.completions.create(model=model_name, + messages=messages, + n=2, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(use_beam_search=True) + ) + + # Assert that the exception message is correct + assert "Only the `cumulative_logprob` of each selected sequence will be returned." in str(exc_info.value) + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @@ -188,11 +232,10 @@ async def test_single_chat_session_image_base64encoded_beamsearch( messages=messages, n=2, max_tokens=10, - logprobs=True, - top_logprobs=5, extra_body=dict(use_beam_search=True) ) - assert len(chat_completion.choices) == 4 + assert len(chat_completion.choices) == 2 + assert chat_completion.choices[0].message.content != chat_completion.choices[1].message.content @pytest.mark.asyncio diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index ad3a7efae7f88..0ff32bd9fad6b 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -144,6 +144,7 @@ async def beam_search( if include_stop_str_in_output else current_beam.tokens, # cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, + finish_reason="stop" ) ) else: @@ -172,7 +173,7 @@ async def beam_search( CompletionOutput( text=beam.text, cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens, + token_ids=beam.tokens[tokenized_length:], index=i, logprobs=beam.cum_logprob, ) for (i, beam) in enumerate(best_beams) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 335ad5ddcc972..d5ec4f7c36cdf 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -401,6 +401,10 @@ def check_logprobs(cls, data): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." ) + if data.get("logprobs") and data.get("use_beam_search"): + raise ValueError( + "Only the `cumulative_logprob` of each selected sequence will be returned." 
+ ) return data From 014d7535a16412f061d4473fb81bae1711836472 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Wed, 16 Oct 2024 15:50:55 +0000 Subject: [PATCH 05/12] formatting --- tests/entrypoints/openai/test_vision.py | 52 +++++++++++++------------ vllm/engine/protocol.py | 29 +++++++------- vllm/entrypoints/openai/protocol.py | 8 ++-- 3 files changed, 46 insertions(+), 43 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index a3f55674f1b78..459f978146109 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -104,11 +104,13 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, - model_name: str, image_url: str): + model_name: str, + image_url: str): messages = [{ "role": "user", @@ -126,27 +128,28 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, ], }] - chat_completion = await client.chat.completions.create(model=model_name, - messages=messages, - n=2, - max_tokens=10, - extra_body=dict(use_beam_search=True) - ) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + n=2, + max_tokens=10, + extra_body=dict(use_beam_search=True)) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[0].message.content != chat_completion.choices[1].message.content + assert chat_completion.choices[ + 0].message.content != chat_completion.choices[1].message.content with pytest.raises(openai.BadRequestError) as exc_info: - await client.chat.completions.create(model=model_name, - messages=messages, - n=2, - max_tokens=10, - logprobs=True, - top_logprobs=5, - extra_body=dict(use_beam_search=True) - ) + await client.chat.completions.create( + model=model_name, + messages=messages, + n=2, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(use_beam_search=True)) # Assert that the exception message is correct - assert "Only the `cumulative_logprob` of each selected sequence will be returned." 
in str(exc_info.value) + assert "Only the `cumulative_logprob` " in str(exc_info.value) @pytest.mark.asyncio @@ -228,14 +231,15 @@ async def test_single_chat_session_image_base64encoded_beamsearch( }, ], }] - chat_completion = await client.chat.completions.create(model=model_name, - messages=messages, - n=2, - max_tokens=10, - extra_body=dict(use_beam_search=True) - ) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + n=2, + max_tokens=10, + extra_body=dict(use_beam_search=True)) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[0].message.content != chat_completion.choices[1].message.content + assert chat_completion.choices[ + 0].message.content != chat_completion.choices[1].message.content @pytest.mark.asyncio diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 0ff32bd9fad6b..e52474eebb628 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -84,7 +84,8 @@ async def beam_search( multi_modal_data = prompt.get("multi_modal_data") mm_processor_kwargs = prompt.get("mm_processor_kwargs") else: - raise TypeError("Inputs in Dictionary type must be a TextPrompt or TokensPrompt") + raise TypeError( + "Dictionary input must be a TextPrompt or TokensPrompt") else: tokenized_prompt = prompt if isinstance( prompt, list) else tokenizer.encode(prompt) @@ -101,16 +102,18 @@ async def beam_search( max_tokens=1, temperature=temperature, ) - all_beams = [BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)] + all_beams = [ + BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0) + ] completed = [] for _ in range(max_tokens): prompts_batch = [ TokensPrompt( prompt_token_ids=beam.tokens, - multi_modal_data=deepcopy(multi_modal_data), # always the values from inputs - mm_processor_kwargs=deepcopy(mm_processor_kwargs) - ) + multi_modal_data=deepcopy( + multi_modal_data), # always the values from inputs + mm_processor_kwargs=deepcopy(mm_processor_kwargs)) for beam in all_beams ] @@ -140,21 +143,19 @@ async def beam_search( not ignore_eos: completed.append( BeamSearchSequence( - tokens=current_beam.tokens + [token_id] - if include_stop_str_in_output else current_beam.tokens, # + tokens=current_beam.tokens + + [token_id] if include_stop_str_in_output + else current_beam.tokens, # cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, - finish_reason="stop" - ) - ) + finish_reason="stop")) else: new_beams.append( BeamSearchSequence( - tokens=current_beam.tokens + [token_id], # + tokens=current_beam.tokens + [token_id], # cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, - ) - ) + )) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] @@ -181,7 +182,7 @@ async def beam_search( finished=True, prompt_token_ids=tokenized_prompt, prompt_logprobs=None) - + yield beam_search_output @abstractmethod diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d5ec4f7c36cdf..dfab436f836f6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -302,8 +302,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, - include_stop_str_in_output=self.include_stop_str_in_output - ) + include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens @@ -403,7 +402,7 @@ def check_logprobs(cls, data): ) if 
data.get("logprobs") and data.get("use_beam_search"): raise ValueError( - "Only the `cumulative_logprob` of each selected sequence will be returned." + "Only the `cumulative_logprob` of each output will be returned." ) return data @@ -599,8 +598,7 @@ def to_beam_search_params(self, ignore_eos=self.ignore_eos, temperature=temperature, length_penalty=self.length_penalty, - include_stop_str_in_output=self.include_stop_str_in_output - ) + include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens From 5f0e1cd1f6eb177c39fc4124087f9cb1a4dbd39d Mon Sep 17 00:00:00 2001 From: Qishuai Date: Thu, 17 Oct 2024 05:51:23 +0000 Subject: [PATCH 06/12] update BeamSequence, prompt preprocess and adding stop_reason --- vllm/beam_search.py | 8 +- vllm/engine/protocol.py | 77 +++++++++---------- vllm/entrypoints/openai/serving_chat.py | 7 +- vllm/entrypoints/openai/serving_completion.py | 7 +- 4 files changed, 52 insertions(+), 47 deletions(-) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index fe04fd9501bc6..d30af35e2624b 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,5 +1,8 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict @dataclass @@ -14,6 +17,9 @@ class BeamSearchSequence: cum_logprob: float = 0.0 text: Optional[str] = None finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None @dataclass diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e52474eebb628..5b30c5c77db49 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,12 +1,12 @@ import asyncio from abc import ABC, abstractmethod -from copy import deepcopy from typing import AsyncGenerator, List, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -61,6 +61,7 @@ def generate( async def beam_search( self, prompt: Union[PromptType, List[int]], + model_config: ModelConfig, request_id: str, params: BeamSearchParams, ) -> AsyncGenerator[RequestOutput, None]: @@ -72,27 +73,16 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - tokenizer = await self.get_tokenizer(lora_request=None) - - if isinstance(prompt, dict): - if "prompt" in prompt: - tokenized_prompt = tokenizer.encode(prompt.get("prompt")) - multi_modal_data = prompt.get("multi_modal_data") - mm_processor_kwargs = prompt.get("mm_processor_kwargs") - elif "prompt_token_ids" in prompt: - tokenized_prompt = prompt.get("prompt_token_ids") - multi_modal_data = prompt.get("multi_modal_data") - mm_processor_kwargs = prompt.get("mm_processor_kwargs") - else: - raise TypeError( - "Dictionary input must be a TextPrompt or TokensPrompt") - else: - tokenized_prompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - multi_modal_data = None - mm_processor_kwargs = 
None - - tokenized_length = len(tokenized_prompt) + tokenizer = await self.get_tokenizer() + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + + (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs + ) = self.input_preprocessor._extract_prompt_components( + prompt, + request_id=request_id, + ) + tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( tokenizer.eos_token_id, length_penalty) @@ -103,17 +93,18 @@ async def beam_search( temperature=temperature, ) all_beams = [ - BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0) + BeamSearchSequence(tokens=prompt_token_ids, + cum_logprob=0, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) ] completed = [] for _ in range(max_tokens): prompts_batch = [ - TokensPrompt( - prompt_token_ids=beam.tokens, - multi_modal_data=deepcopy( - multi_modal_data), # always the values from inputs - mm_processor_kwargs=deepcopy(mm_processor_kwargs)) + TokensPrompt(prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs) for beam in all_beams ] @@ -148,14 +139,18 @@ async def beam_search( else current_beam.tokens, # cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, - finish_reason="stop")) + finish_reason="stop", + stop_reason=tokenizer.eos_token_id)) else: new_beams.append( BeamSearchSequence( - tokens=current_beam.tokens + [token_id], # + tokens=current_beam.tokens + [token_id], cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, - )) + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs)) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] @@ -169,18 +164,20 @@ async def beam_search( beam_search_output = RequestOutput( request_id=request_id, - prompt=tokenizer.decode(tokenized_prompt), + prompt=prompt_text, outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.cum_logprob, - ) for (i, beam) in enumerate(best_beams) + CompletionOutput(text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.cum_logprob, + finish_reason=beam.finish_reason if + beam.finish_reason is not None else "length", + stop_reason=beam.stop_reason) + for (i, beam) in enumerate(best_beams) ], finished=True, - prompt_token_ids=tokenized_prompt, + prompt_token_ids=prompt_token_ids, prompt_logprobs=None) yield beam_search_output diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f7a578b77fa1d..8704f2d2f604d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -236,9 +236,10 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): result_generator = self.engine_client.beam_search( - engine_inputs, - request_id, - sampling_params, + prompt=engine_inputs, + model_config=self.model_config, + request_id=request_id, + params=sampling_params, ) else: result_generator = self.engine_client.generate( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index f4ad52a9e076c..631c7c20682b0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -150,9 +150,10 @@ async def create_completion( if 
isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( - prompt_inputs, - request_id_item, - sampling_params, + prompt=prompt_inputs, + model_config=self.model_config, + request_id=request_id, + params=sampling_params, ) else: generator = self.engine_client.generate( From 5a256cbf5cb249aeb03fc1868c88951632f97fe8 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Thu, 17 Oct 2024 06:34:39 +0000 Subject: [PATCH 07/12] fix the wrong declaration --- vllm/engine/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 5b30c5c77db49..80a5480e6525e 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -74,11 +74,11 @@ async def beam_search( include_stop_str_in_output = params.include_stop_str_in_output tokenizer = await self.get_tokenizer() - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + input_preprocessor = InputPreprocessor(model_config, + tokenizer) (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs - ) = self.input_preprocessor._extract_prompt_components( + ) = input_preprocessor._extract_prompt_components( prompt, request_id=request_id, ) From b01a6152418804949a830ec45504a3e8a5fcb891 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Thu, 17 Oct 2024 06:42:01 +0000 Subject: [PATCH 08/12] formatting --- vllm/engine/protocol.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 80a5480e6525e..6ed9efc271960 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -74,11 +74,10 @@ async def beam_search( include_stop_str_in_output = params.include_stop_str_in_output tokenizer = await self.get_tokenizer() - input_preprocessor = InputPreprocessor(model_config, - tokenizer) + input_preprocessor = InputPreprocessor(model_config, tokenizer) - (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs - ) = input_preprocessor._extract_prompt_components( + (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) = input_preprocessor._extract_prompt_components( prompt, request_id=request_id, ) From 8291a809eaa1b8e71cb6cc44462903fadd7da543 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Fri, 18 Oct 2024 09:50:29 +0000 Subject: [PATCH 09/12] remove checking for logprobs --- tests/entrypoints/openai/test_vision.py | 15 ++------------- vllm/entrypoints/openai/protocol.py | 4 ---- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 459f978146109..d04c52efd2140 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -133,24 +133,13 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, messages=messages, n=2, max_tokens=10, + logprobs=True, + top_logprobs=5, extra_body=dict(use_beam_search=True)) assert len(chat_completion.choices) == 2 assert chat_completion.choices[ 0].message.content != chat_completion.choices[1].message.content - with pytest.raises(openai.BadRequestError) as exc_info: - await client.chat.completions.create( - model=model_name, - messages=messages, - n=2, - max_tokens=10, - logprobs=True, - top_logprobs=5, - extra_body=dict(use_beam_search=True)) - - # Assert that the exception message is correct - assert "Only the `cumulative_logprob` " in str(exc_info.value) - @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) 
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index dfab436f836f6..84d1fd9eb8886 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -400,10 +400,6 @@ def check_logprobs(cls, data): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." ) - if data.get("logprobs") and data.get("use_beam_search"): - raise ValueError( - "Only the `cumulative_logprob` of each output will be returned." - ) return data From a682b63b11ebcabbb4ec12e14ec660f1e57fdcf9 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Fri, 18 Oct 2024 10:12:58 +0000 Subject: [PATCH 10/12] format --- vllm/engine/protocol.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 2b3957d3ffdc4..64484dc53f71d 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -136,8 +136,9 @@ async def beam_search( BeamSearchSequence( tokens=current_beam.tokens + [token_id] if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs], + else current_beam.tokens, + logprobs=current_beam.logprobs + + [logprobs], cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, finish_reason="stop", @@ -146,7 +147,8 @@ async def beam_search( new_beams.append( BeamSearchSequence( tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], + logprobs=current_beam.logprobs + + [logprobs], cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, multi_modal_data=current_beam. From bb53cbd2896e4777839ea823a58fa395af82ad4c Mon Sep 17 00:00:00 2001 From: Qishuai Date: Fri, 18 Oct 2024 13:47:24 +0000 Subject: [PATCH 11/12] output beam's logprobs to Output's logprobs --- vllm/engine/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 64484dc53f71d..06db62125756f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -174,7 +174,7 @@ async def beam_search( cumulative_logprob=beam.cum_logprob, token_ids=beam.tokens[tokenized_length:], index=i, - logprobs=beam.cum_logprob, + logprobs=beam.logprobs, finish_reason=beam.finish_reason if beam.finish_reason is not None else "length", stop_reason=beam.stop_reason) From 3b7ab9264103ff5171c7d75307b10d4c765510c0 Mon Sep 17 00:00:00 2001 From: Qishuai Date: Sat, 19 Oct 2024 01:25:27 +0000 Subject: [PATCH 12/12] update calling of beam_search from serving_completion based on latest main --- vllm/entrypoints/openai/serving_completion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 599304446f2b1..da521a6012530 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -150,7 +150,10 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( - prompt=prompt_inputs, + prompt={ + "prompt_token_ids": + prompt_inputs["prompt_token_ids"] + }, model_config=self.model_config, request_id=request_id, params=sampling_params,
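
Usage note for the series as a whole: the sketch below mirrors the beam-search test cases added in PATCH 02/12 and trimmed in PATCH 09/12, showing how a client exercises the new path end to end through the OpenAI-compatible server. It is a minimal sketch, not part of any patch above; the base URL, API key, model name, and image URL are placeholders rather than values taken from the patches, and any running vLLM server with a vision model is assumed to behave the same way. On the server side, `use_beam_search` sent via `extra_body` routes the request into `to_beam_search_params()`, which after this series also forwards `include_stop_str_in_output` into `BeamSearchParams`.

import asyncio

import openai

# Placeholder endpoint and model; point these at your own vLLM server.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "your-vision-model"


async def beam_search_chat(image_url: str) -> None:
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }]
    # `n` is the beam width; `use_beam_search` is passed through `extra_body`,
    # exactly as in tests/entrypoints/openai/test_vision.py in this series.
    chat_completion = await client.chat.completions.create(
        model=MODEL,
        messages=messages,
        n=2,
        max_tokens=10,
        extra_body=dict(use_beam_search=True),
    )
    # Each returned choice is one beam; with the patched engine code,
    # finish_reason is "stop" when EOS was reached and "length" otherwise.
    for choice in chat_completion.choices:
        print(choice.finish_reason, choice.message.content)


if __name__ == "__main__":
    # Placeholder image URL for illustration only.
    asyncio.run(beam_search_chat("https://example.com/sample.jpg"))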