diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 235db72eee4b9..86eddb576c42a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -90,6 +90,7 @@ steps: - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/openai + - pytest -v -s entrypoints/test_chat_utils.py - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index be90394511f89..e1d4055763e5f 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -1,7 +1,13 @@ """An example showing how to use vLLM to serve VLMs. Launch the vLLM server with the following command: + +(single image inference with Llava) vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + +(multi-image inference with Phi-3.5-vision-instruct) +vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ + --trust-remote-code --limit-mm-per-prompt image=2 """ import base64 @@ -84,3 +90,36 @@ def encode_image_base64_from_url(image_url: str) -> str: result = chat_completion_from_base64.choices[0].message.content print(f"Chat completion output:{result}") + +# Multi-image input inference +image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" +image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" +chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_url.choices[0].message.content +print(f"Chat completion output:{result}") diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 3783b7cd66a6a..c3a6c65be1d90 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from unittest.mock import MagicMock +from vllm.config import MultiModalConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -20,6 +21,7 @@ class MockModelConfig: max_model_len = 100 tokenizer_revision = None embedding_mode = False + multimodal_config = MultiModalConfig() @dataclass diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index d2ef3c2071efb..f61fa127b7d06 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,11 +6,10 @@ from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer +from ...utils import RemoteOpenAIServer -MODEL_NAME = "llava-hf/llava-1.5-7b-hf" -LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja" -assert LLAVA_CHAT_TEMPLATE.exists() +MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" +MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -24,13 +23,9 @@ @pytest.fixture(scope="module") def server(): args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "4096", - "--enforce-eager", - "--chat-template", - str(LLAVA_CHAT_TEMPLATE), + "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs", + "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_url: str): + image_urls: List[str]): messages = [{ "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { + *({ "type": "image_url", "image_url": { "url": image_url } - }, + } for image_url in image_urls), { "type": "text", "text": "What's in this image?" @@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ], }] - with pytest.raises(openai.BadRequestError): # test multi-image input - await client.chat.completions.create( + if len(image_urls) > MAXIMUM_IMAGES: + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + else: + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_tokens=10, temperature=0.0, ) - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py new file mode 100644 index 0000000000000..53f99189beb1c --- /dev/null +++ b/tests/entrypoints/test_chat_utils.py @@ -0,0 +1,305 @@ +import warnings + +import pytest +from PIL import Image + +from vllm.assets.image import ImageAsset +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.multimodal.utils import encode_image_base64 +from vllm.transformers_utils.tokenizer_group import TokenizerGroup + +PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + + +@pytest.fixture(scope="module") +def phi3v_model_config(): + return ModelConfig(PHI3V_MODEL_ID, + PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + +@pytest.fixture(scope="module") +def phi3v_tokenizer(): + return TokenizerGroup( + tokenizer_id=PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + +@pytest.fixture(scope="module") +def image_url(): + image = ImageAsset('cherry_blossom') + base64 = encode_image_base64(image.pil_image) + return f"data:image/jpeg;base64,{base64}" + + +@pytest.mark.asyncio +async def test_parse_chat_messages_with_image_url(phi3v_model_config, + phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert isinstance(mm_data["image"], Image.Image) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images(phi3v_model_config, + phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_placeholder_already_in_prompt( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " + "other one?" + }] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_across_messages( + phi3v_model_config, phi3v_tokenizer, image_url): + conversation, mm_future = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" + }, + ] + mm_data = await mm_future + assert set(mm_data.keys()) == {"image"} + assert len(mm_data["image"]) == 2 + + +@pytest.mark.asyncio +async def test_parse_chat_messages_rejects_too_many_images_in_one_message( + phi3v_model_config, phi3v_tokenizer, image_url): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_rejects_too_many_images_across_messages( + phi3v_model_config, phi3v_tokenizer, image_url): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], phi3v_model_config, phi3v_tokenizer) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c5368ac3bf026..c70c6d9330b10 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,9 +1,10 @@ +import asyncio import codecs -from dataclasses import dataclass +from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union) +from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, Union) # yapf conflicts with isort for this block # yapf: disable @@ -80,10 +81,90 @@ class ConversationMessage(TypedDict): content: str -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] +class MultiModalItemTracker: + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + self._consumed_items = {k: 0 for k in self._allowed_items} + self._futures: List[Awaitable[MultiModalDataDict]] = [] + + @staticmethod + @lru_cache(maxsize=None) + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int): + return tokenizer.decode(token_index) + + def add(self, modality: Literal["image", "audio"], + mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]: + """ + Adds the multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = self._consumed_items.get(modality, 0) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._consumed_items[modality] = current_count + self._futures.append(mm_future) + + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + model_type = self._model_config.hf_config.model_type + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): + # These models do not use image tokens in the prompt + return None + if model_type.startswith("llava"): + return MultiModalItemTracker._cached_token_str( + self._tokenizer, + self._model_config.hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat"): + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + @staticmethod + async def _combine(futures: List[Awaitable[MultiModalDataDict]]): + mm_lists: Mapping[str, List[object]] = defaultdict(list) + + # Merge all the multi-modal items + for single_mm_data in (await asyncio.gather(*futures)): + for mm_key, mm_item in single_mm_data.items(): + if isinstance(mm_item, list): + mm_lists[mm_key].extend(mm_item) + else: + mm_lists[mm_key].append(mm_item) + + # Unpack any single item lists for models that don't expect multiple. + return { + mm_key: mm_list[0] if len(mm_list) == 1 else mm_list + for mm_key, mm_list in mm_lists.items() + } + + def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]: + return MultiModalItemTracker._combine( + self._futures) if self._futures else None def load_chat_template( @@ -112,44 +193,30 @@ def load_chat_template( return resolved_chat_template -@lru_cache(maxsize=None) -def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, - modality: Literal["image", "audio"]) -> Optional[str]: - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - model_type = model_config.hf_config.model_type - if modality == "image": - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type == "minicpmv": - return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return tokenizer.decode(model_config.hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat"): - return "" - - raise TypeError(f"Unknown model type: {model_type}") - elif modality == "audio": - if model_type == "ultravox": - return "<|reserved_special_token_0|>" - raise TypeError(f"Unknown model type: {model_type}") - else: - raise TypeError(f"Unknown modality: {modality}") - - # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_token_str: str, +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], text_prompt: str) -> str: """Combine multimodal prompts for a multimodal language model""" - # NOTE: For now we assume all model architectures use the same - # placeholder + text prompt format. This may change in the future. - return f"{placeholder_token_str}\n{text_prompt}" + # Look through the text prompt to check for missing placeholders + missing_placeholders = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) _TextParser = TypeAdapter(ChatCompletionContentPartTextParam) @@ -160,12 +227,12 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + mm_tracker: MultiModalItemTracker, +) -> List[ConversationMessage]: texts: List[str] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - modality: Literal["image", "audio"] = "image" + + # multimodal placeholder_string : count + mm_placeholder_counts: Dict[str, int] = {} for part in parts: part_type = part["type"] @@ -173,11 +240,6 @@ def _parse_chat_message_content_parts( text = _TextParser.validate_python(part)["text"] texts.append(text) elif part_type == "image_url": - modality = "image" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - image_url = _ImageParser.validate_python(part)["image_url"] if image_url.get("detail", "auto") != "auto": @@ -185,60 +247,44 @@ def _parse_chat_message_content_parts( "'image_url.detail' is currently not supported and " "will be ignored.") - image_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(image_future) + image_coro = async_get_and_parse_image(image_url["url"]) + placeholder = mm_tracker.add("image", image_coro) + if placeholder: + mm_placeholder_counts[placeholder] = mm_placeholder_counts.get( + placeholder, 0) + 1 elif part_type == "audio_url": - modality = "audio" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - audio_url = _AudioParser.validate_python(part)["audio_url"] - audio_future = async_get_and_parse_audio(audio_url["url"]) - mm_futures.append(audio_future) + audio_coro = async_get_and_parse_audio(audio_url["url"]) + placeholder = mm_tracker.add("audio", audio_coro) + if placeholder: + mm_placeholder_counts[placeholder] = mm_placeholder_counts.get( + placeholder, 0) + 1 else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) - if mm_futures: - placeholder_token_str = _mm_token_str(model_config, tokenizer, - modality) - if placeholder_token_str is not None: - if placeholder_token_str in text_prompt: - logger.warning( - "Detected multi-modal token string in the text prompt. " - "Skipping prompt formatting.") - else: - text_prompt = _get_full_multimodal_text_prompt( - placeholder_token_str=placeholder_token_str, - text_prompt=text_prompt, - ) - - messages = [ConversationMessage(role=role, content=text_prompt)] - - return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) + return [ConversationMessage(role=role, content=text_prompt)] def _parse_chat_message_content( - message: ChatCompletionMessageParam, - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + message: ChatCompletionMessageParam, + mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]: role = message["role"] content = message.get("content") if content is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) + return [] if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) + return [ConversationMessage(role=role, content=content)] return _parse_chat_message_content_parts( role, content, # type: ignore - model_config, - tokenizer, + mm_tracker, ) @@ -246,18 +292,16 @@ def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, -) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: +) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) for msg in messages: - parse_result = _parse_chat_message_content(msg, model_config, - tokenizer) + sub_messages = _parse_chat_message_content(msg, mm_tracker) - conversation.extend(parse_result.messages) - mm_futures.extend(parse_result.mm_futures) + conversation.extend(sub_messages) - return conversation, mm_futures + return conversation, mm_tracker.all_mm_data() def apply_chat_template( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d31ac4995fe2f..f7576509d06c8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -94,7 +94,7 @@ async def create_chat_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages( request.messages, model_config, tokenizer) tool_dicts = None if request.tools is None else [ @@ -116,12 +116,8 @@ async def create_chat_completion( mm_data: Optional[MultiModalDataDict] = None try: - if len(mm_futures): - # since we support only single mm data currently - assert len( - mm_futures - ) == 1, "Multiple 'image_url' input is currently not supported." - mm_data = await mm_futures[0] + if mm_data_future: + mm_data = await mm_data_future except Exception as e: logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 1aeabb7a7d729..fc9ca29e9cf86 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -65,10 +65,10 @@ async def create_tokenize( if isinstance(request, TokenizeChatRequest): model_config = self.model_config - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages( request.messages, model_config, tokenizer) - if mm_futures: + if mm_data_future: logger.warning( "Multi-modal inputs are ignored during tokenization")