[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by: Peter Salas <[email protected]>
petersalas authored Nov 1, 2024
1 parent d151fde · commit 6c0b7f5
Showing 53 changed files with 914 additions and 282 deletions.
6 changes: 1 addition & 5 deletions examples/offline_inference_audio_language.py
@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name,
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids

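
With chunked prefill no longer disabled for multi-modal inputs, the example collapses to the single-line constructor above. The following is a minimal end-to-end sketch of driving it, assuming the Ultravox checkpoint and the (audio, sampling_rate) multi-modal input format used elsewhere in this repository; the model name, audio asset, and question are illustrative, not taken from this commit.

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"  # assumed; substitute the example's model
question = "What is recited in the audio?"

# Build a chat-templated prompt with one audio placeholder, mirroring the example.
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{"role": "user", "content": f"<|reserved_special_token_0|>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name, limit_mm_per_prompt={"audio": 1})
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate
        },
    },
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
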
2 changes: 2 additions & 0 deletions tests/kernels/utils.py
@@ -869,6 +869,7 @@ def make_test_metadata(
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
@@ -914,6 +915,7 @@ def make_test_metadata(
return attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
91 changes: 74 additions & 17 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -2,8 +2,10 @@

import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

@@ -17,6 +19,13 @@
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"

CHUNKED_PREFILL_KWARGS = {
"enable_chunked_prefill": True,
"max_num_seqs": 2,
# Use a very small limit to exercise chunked prefill.
"max_num_batched_tokens": 16
}


@pytest.fixture(scope="session")
def audio_assets():
@@ -30,6 +39,26 @@ def audio(request):
return AudioAsset(request.param)


@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
with vllm_runner(model, dtype=dtype, enforce_eager=True,
**kwargs) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
},
**kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
@@ -162,8 +185,9 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
num_logprobs: int, vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -175,17 +199,18 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.asyncio
async def test_online_inference(client, audio_assets):
"""Exercises online inference with/without chunked prefill enabled."""

messages = [{
"role":
"user",
"content": [
*[{
"type": "audio_url",
"audio_url": {
"url": audio.url
}
} for audio in audio_assets],
{
"type":
"text",
"text":
f"What's happening in these {len(audio_assets)} audio clips?"
},
],
}]

chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10)

assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
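
The new `server`/`client` fixtures and `test_online_inference` cover the OpenAI-compatible serving path with and without chunked prefill. Below is a rough sketch of the equivalent manual flow, assuming a server started with flags mirroring the fixture (and, for the chunked-prefill variant, the CHUNKED_PREFILL_KWARGS flags); the model name, audio URLs, and asset count of two are placeholders, not taken from this commit.

# Server side (mirrors the fixture's arguments):
#   vllm serve fixie-ai/ultravox-v0_3 --dtype=bfloat16 --max-model-len=4096 \
#       --enforce-eager --limit-mm-per-prompt=audio=2 \
#       --enable-chunked-prefill=True --max-num-seqs=2 --max-num-batched-tokens=16
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    messages = [{
        "role": "user",
        "content": [
            {"type": "audio_url", "audio_url": {"url": "https://example.com/clip0.ogg"}},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/clip1.ogg"}},
            {"type": "text", "text": "What's happening in these 2 audio clips?"},
        ],
    }]
    completion = await client.chat.completions.create(
        model="fixie-ai/ultravox-v0_3",  # assumed MODEL_NAME
        messages=messages,
        max_tokens=10,
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
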
14 changes: 7 additions & 7 deletions tests/multimodal/test_processor_kwargs.py
@@ -5,8 +5,8 @@
import pytest
import torch

from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

@@ -56,7 +56,7 @@ def custom_dummy_data_factory(self,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
return DummyData(seq_data, None)

with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS


### Test overrides for the max token count per multimodal instance
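
The registry now returns a `DummyData` object instead of a bare `(seq_data, mm_data)` tuple. The following sketch shows the shape implied by the usages above; the field order and the extra placeholder field are inferred from this diff, not copied from the source.

from typing import Any, Dict, NamedTuple, Optional

from vllm.sequence import SequenceData


class DummyDataSketch(NamedTuple):
    """Hypothetical stand-in for vllm.inputs.registry.DummyData."""
    seq_data: SequenceData
    multi_modal_data: Optional[Dict[str, Any]] = None
    # Assumed third field: this PR also threads per-modality placeholder
    # ranges through the dummy data used for profiling.
    multi_modal_placeholders: Optional[Dict[str, Any]] = None

# Usage mirroring the patched factory and assertion above:
#   return DummyData(seq_data, None)
#   assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
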
57 changes: 45 additions & 12 deletions tests/multimodal/test_utils.py
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer = AutoTokenizer.from_pretrained(model)

test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]),
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
("<image><image>", [3, 2], "<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000]),
("Image:<image>Image:<image>!", [3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
]

for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
(
"<image>",
2,
"<image><image>",
[32000, 32000],
[{ "offset": 0, "length": 2 }],
),
(
"<image><image>",
2,
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 2 }]),
(
"<image><image>",
[3, 2],
"<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000],
[{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
),
(
"Image:<image>Image:<image>!",
[3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
),
(
"<image>",
[3, 2],
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 3 }],
),
] # yapf: disable

for (
prompt,
repeat_count,
expected_prompt,
expected_token_ids,
expected_ranges,
) in test_cases:
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer,
prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt,
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
)
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
assert ranges == expected_ranges
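
As a small illustration, the new `ranges` return value pins down the placeholder positions in the padded token sequence. The snippet below uses the fourth test case above; whether the ranges are plain dicts or typed objects is an implementation detail of the PR.

new_token_ids = [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]
ranges = [{"offset": 2, "length": 3}, {"offset": 7, "length": 2}]

for item_idx, rng in enumerate(ranges):
    start = rng["offset"]
    end = start + rng["length"]
    # Every position inside the range holds the <image> placeholder token (32000).
    assert all(tok == 32000 for tok in new_token_ids[start:end])
    print(f"multi-modal item {item_idx} occupies token positions {start}:{end}")
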
3 changes: 3 additions & 0 deletions tests/worker/test_model_input.py
@@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
11 changes: 11 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -7,6 +7,8 @@

import torch

from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
@@ -108,6 +110,15 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
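
Below is a minimal sketch of how a model's forward pass might consume these index maps to scatter multi-modal embeddings into their placeholder positions. It assumes `IndexMap` carries parallel `src`/`dest` index lists; the field names and the helper are assumptions for illustration, not this PR's implementation.

from typing import List

import torch


def merge_multimodal_embeddings_sketch(
        inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size]
        mm_embeds: torch.Tensor,      # [num_mm_tokens, hidden_size]
        src: List[int],               # indices into the flattened mm embeddings
        dest: List[int]) -> torch.Tensor:  # indices into the token sequence
    # Overwrite the placeholder positions with their multi-modal embeddings;
    # all other token embeddings are left untouched.
    inputs_embeds[dest] = mm_embeds[src]
    return inputs_embeds

# e.g. with index_map = attn_metadata.multi_modal_placeholder_index_maps["audio"]:
#   merge_multimodal_embeddings_sketch(inputs_embeds, audio_embeds,
#                                      index_map.src, index_map.dest)
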
3 changes: 3 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -215,6 +215,8 @@ def prefill_metadata(
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@@ -243,6 +245,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,