[Frontend] re-enable multi-modality input in the new beam search implementation #9427

Merged

21 commits
08ab78e
update of beam search function
FerdinandZhong Oct 15, 2024
cac55e1
update of testing
FerdinandZhong Oct 16, 2024
358f89c
Merge remote-tracking branch 'upstream/main' into beam_search_multi_m…
FerdinandZhong Oct 16, 2024
2dde695
fix error in implementation
FerdinandZhong Oct 16, 2024
eb92b7d
add checking for logprobs and add more test cases
FerdinandZhong Oct 16, 2024
014d753
formatting
FerdinandZhong Oct 16, 2024
eae5b9b
Merge remote-tracking branch 'upstream/main' into beam_search_multi_m…
FerdinandZhong Oct 17, 2024
5f0e1cd
update BeamSequence, prompt preprocess and adding stop_reason
FerdinandZhong Oct 17, 2024
6e29318
Merge branch 'beam_search_multi_modality' of https://github.com/Ferdi…
FerdinandZhong Oct 17, 2024
5a256cb
fix the wrong declaration
FerdinandZhong Oct 17, 2024
b01a615
formatting
FerdinandZhong Oct 17, 2024
bc74931
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 18, 2024
8291a80
remove checking for logprobs
FerdinandZhong Oct 18, 2024
a682b63
format
FerdinandZhong Oct 18, 2024
8940743
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 18, 2024
bb53cbd
output beam's logprobs to Output's logprobs
FerdinandZhong Oct 18, 2024
c275ae3
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 19, 2024
3b7ab92
update calling of beam_search from serving_completion based on latest…
FerdinandZhong Oct 19, 2024
f96fa9a
Merge branch 'main' of github.com:vllm-project/vllm into beam_search_…
FerdinandZhong Oct 22, 2024
314a31e
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 28, 2024
8705266
Merge branch 'main' of https://github.com/vllm-project/vllm into beam…
FerdinandZhong Oct 29, 2024
71 changes: 71 additions & 0 deletions tests/entrypoints/openai/test_vision.py
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
                                                    model_name: str,
                                                    image_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded_beamsearch(
        client: openai.AsyncOpenAI, model_name: str, image_url: str,
        base64_encoded_image: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url":
                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_tokens=10,
        extra_body=dict(use_beam_search=True))
    assert len(chat_completion.choices) == 2
    assert chat_completion.choices[
        0].message.content != chat_completion.choices[1].message.content


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
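For reference, the tests above request beam search through the OpenAI-compatible server by passing the flag via extra_body, since the official openai client has no use_beam_search argument. A minimal standalone sketch, not part of this diff; the server URL, API key, model name and image URL are placeholders for a locally running vLLM instance serving a vision model:

import openai

# Placeholders: point these at your own vLLM server and multimodal model.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/duck.jpg"}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }],
    n=2,
    max_tokens=10,
    # use_beam_search is a vLLM extension, so it is sent through extra_body.
    extra_body=dict(use_beam_search=True),
)
for choice in chat_completion.choices:
    print(choice.message.content)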
9 changes: 8 additions & 1 deletion vllm/beam_search.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from vllm.sequence import Logprob

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalDataDict


@dataclass
class BeamSearchSequence:
@@ -16,6 +19,10 @@ class BeamSearchSequence:
    logprobs: List[Dict[int, Logprob]]
    cum_logprob: float = 0.0
    text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None


@dataclass
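The new fields let each beam carry its prompt's multi-modal inputs and processor kwargs so they can be re-submitted on every expansion step. An illustrative construction, not from this diff; the token ids and image are made up, and MultiModalDataDict is assumed to be a plain modality-to-data mapping such as {"image": <PIL image>}:

from PIL import Image

from vllm.beam_search import BeamSearchSequence

image = Image.new("RGB", (64, 64))  # stand-in for a real input image
beam = BeamSearchSequence(
    tokens=[1, 2, 3],               # prompt token ids accumulated so far
    logprobs=[],
    cum_logprob=0.0,
    multi_modal_data={"image": image},  # forwarded on every expansion step
    mm_processor_kwargs={},             # optional processor overrides
)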
88 changes: 57 additions & 31 deletions vllm/engine/protocol.py
@@ -6,6 +6,7 @@
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -59,7 +60,8 @@ def generate(

    async def beam_search(
        self,
        prompt: Union[str, List[int]],
        prompt: Union[PromptType, List[int]],
        model_config: ModelConfig,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
@@ -69,32 +71,40 @@
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = await self.get_tokenizer(lora_request=None)
        if isinstance(prompt, str):
            tokenized_prompt = tokenizer.encode(prompt)
            prompt_text = prompt
        else:
            tokenized_prompt = prompt
            prompt_text = None
        tokenized_length = len(tokenized_prompt)
        tokenizer = await self.get_tokenizer()
        input_preprocessor = InputPreprocessor(model_config, tokenizer)

        (prompt_text, prompt_token_ids, multi_modal_data,
         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
             prompt,
             request_id=request_id,
         )
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(tokens=tokenized_prompt,
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               cum_logprob=0)
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []

        for _ in range(max_tokens):
            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

@@ -120,17 +130,31 @@ async def beam_search(
                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            logprobs=current_beam.logprobs + [logprobs],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                            completed.append(new_beam)
                            completed.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens +
                                    [token_id] if include_stop_str_in_output
                                    else current_beam.tokens,
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    finish_reason="stop",
                                    stop_reason=tokenizer.eos_token_id))
                        else:
                            new_beams.append(new_beam)
                            new_beams.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs))

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]
@@ -151,16 +175,18 @@
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                ) for (i, beam) in enumerate(best_beams)
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=tokenized_prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output
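The beams are ranked with create_sort_beams_key_function(tokenizer.eos_token_id, length_penalty), whose implementation is not shown in this diff. For orientation only, a length-penalised beam score is commonly computed along these lines; this is an assumption, not necessarily vLLM's exact formula:

from typing import List


def beam_score(cum_logprob: float, tokens: List[int], eos_token_id: int,
               length_penalty: float) -> float:
    # Normalise the cumulative log-probability by sequence length so that
    # longer beams are not unfairly penalised; a trailing EOS is not counted.
    seq_len = len(tokens)
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1
    return cum_logprob / (seq_len**length_penalty)

Beams are sorted by such a score in descending order after each step, and only the top beam_width candidates are kept, as the loop above does with sorted_beams[:beam_width].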
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -308,7 +308,7 @@ def to_beam_search_params(self,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
        )
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
@@ -606,7 +606,7 @@ def to_beam_search_params(self,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
        )
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -236,9 +236,10 @@ async def create_chat_completion(

            if isinstance(sampling_params, BeamSearchParams):
                result_generator = self.engine_client.beam_search(
                    engine_inputs['prompt_token_ids'],
                    request_id,
                    sampling_params,
                    prompt=engine_inputs,
                    model_config=self.model_config,
                    request_id=request_id,
                    params=sampling_params,
                )
            else:
                result_generator = self.engine_client.generate(
10 changes: 7 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -150,9 +150,13 @@ async def create_completion(

                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt_inputs["prompt_token_ids"],
                        request_id_item,
                        sampling_params,
                        prompt={
                            "prompt_token_ids":
                            prompt_inputs["prompt_token_ids"]
                        },
                        model_config=self.model_config,
                        request_id=request_id,
                        params=sampling_params,
                    )
                else:
                    generator = self.engine_client.generate(
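The same flag also works on the completions endpoint, which this file wires up. A short sketch, not part of this diff; the server URL, API key and model name are placeholders:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/opt-125m",  # any model served by this vLLM instance
    prompt="The capital of France is",
    n=2,
    max_tokens=10,
    extra_body=dict(use_beam_search=True),
)
print([choice.text for choice in completion.choices])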
1 change: 1 addition & 0 deletions vllm/sampling_params.py
@@ -500,3 +500,4 @@ class BeamSearchParams(
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False
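Offline callers construct the parameters directly; the new field defaults to False, so existing behaviour is unchanged. A small example with arbitrary values, using the field names as accessed in vllm/engine/protocol.py above:

from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(
    beam_width=2,
    max_tokens=10,
    ignore_eos=False,
    temperature=0.0,
    length_penalty=1.0,
    include_stop_str_in_output=True,  # keep the EOS token in returned beams
)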