From d06bfd344dd9dc8d53381f59705ad039ad213ba0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 23 Sep 2024 01:44:48 -0600 Subject: [PATCH] [Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657) Signed-off-by: Alex-Brooks Signed-off-by: Sumit Dubey --- tests/engine/test_arg_utils.py | 21 ++ .../decoder_only/vision_language/test_qwen.py | 29 +- tests/models/utils.py | 35 ++ tests/multimodal/test_processor_kwargs.py | 339 ++++++++++++++++++ vllm/config.py | 6 +- vllm/engine/arg_utils.py | 8 + vllm/engine/llm_engine.py | 3 +- vllm/entrypoints/llm.py | 2 + vllm/inputs/registry.py | 38 +- vllm/multimodal/base.py | 19 +- vllm/multimodal/image.py | 10 +- vllm/multimodal/registry.py | 9 + vllm/multimodal/video.py | 9 +- vllm/transformers_utils/image_processor.py | 64 ---- vllm/transformers_utils/processor.py | 65 +++- vllm/utils.py | 48 +++ 16 files changed, 589 insertions(+), 116 deletions(-) create mode 100644 tests/multimodal/test_processor_kwargs.py delete mode 100644 vllm/transformers_utils/image_processor.py diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 8dd200b35d0f3..360ac1bfbad93 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_bad_nullable_kvs(arg): with pytest.raises(ArgumentTypeError): nullable_kvs(arg) + + +@pytest.mark.parametrize(("arg", "expected"), [ + (None, None), + ("{}", {}), + ('{"num_crops": 4}', { + "num_crops": 4 + }), + ('{"foo": {"bar": "baz"}}', { + "foo": { + "bar": "baz" + } + }), +]) +def test_mm_processor_kwargs_prompt_parser(arg, expected): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + if arg is None: + args = parser.parse_args([]) + else: + args = parser.parse_args(["--mm-processor-kwargs", arg]) + assert args.mm_processor_kwargs == expected diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b7606..638fb68b8f872 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,14 +5,13 @@ import torch from PIL.Image import Image -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, VllmRunner, _ImageAssets) -from ...utils import check_logprobs_close +from ...utils import build_model_context, check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component @@ -42,32 +41,6 @@ IMG_SIZE = 448 -def build_model_context(model_name: str, - tokenizer_name: Optional[str] = None, - trust_remote_code: bool = False): - """Creates an InputContext for a given model. - - Args: - model_name: Name of the model being considered. - tokenizer_name: Name of the tokenizer being considered. - trust_remote_code: Whether or not to allow loading remote code. - - Returns: - InputContext for the model being considered. - """ - if tokenizer_name is None: - tokenizer_name = model_name - model_config = ModelConfig( - model_name, - tokenizer_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float32", - seed=0, - ) - return InputContext(model_config) - - @pytest.fixture() def input_mapper_for_qwen(): # Lazy import to avoid initializing CUDA during test collection diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e31a1d6eefed..eb6254f181827 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,6 +1,8 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union +from vllm.config import ModelConfig +from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -240,3 +242,36 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False, + mm_processor_kwargs: Optional[Dict] = None, + limit_mm_per_prompt: Optional[Dict] = None): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + mm_processor_kwargs: optional processor kwargs for to be leveraged + in the input processor, mapper, dummy data creation, etc. + limit_mm_per_prompt: Multimodal limits. + + Returns: + InputContext for the model being considered. + """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + return InputContext(model_config) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py new file mode 100644 index 0000000000000..5529ccd4fa570 --- /dev/null +++ b/tests/multimodal/test_processor_kwargs.py @@ -0,0 +1,339 @@ +from array import array +from typing import Mapping +from unittest.mock import patch + +import pytest +import torch + +from vllm.inputs import InputContext, LLMInputs +from vllm.inputs.registry import InputRegistry +from vllm.multimodal import MultiModalRegistry +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from ..models.utils import build_model_context + +# Used for fast tests where the model doesn't matter +DUMMY_MODEL_ID = "facebook/opt-125m" +# Used for tests that need a multimodal model +MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + +# For mm_processor_kwargs - we test overrides by defining mocks for each place +# it is used, and ensuring that we can pass processor kwargs an override value +# to receive the intended result for things like sequence length etc. +DEFAULT_NUM_CROPS = 4 +NUM_CROPS_OVERRIDE = 16 + + +# Mocks for all of the places that we use the mm_processor_kwargs +# to override values in different callables +@pytest.fixture +def use_processor_mock(): + """Patches the internal model input processor with an override callable.""" + + def custom_processor(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops=DEFAULT_NUM_CROPS): + # For testing purposes, we don't worry about the llm inputs / return + # type validation, and just return the value of the kwarg that we + # clobber. + return num_crops + + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", + return_value=custom_processor): + yield + + +@pytest.fixture +def use_dummy_data_mock(): + """Patches the internal model input processor with an override callable.""" + + def custom_dummy_data_factory(self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops=DEFAULT_NUM_CROPS): + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + return seq_data, None + + with patch( + "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): + yield + + +# Lazy import to avoid CUDA reinitialization error +def mm_model_cls(): + from vllm.model_executor.models.phi3v import Phi3VForCausalLM + + return Phi3VForCausalLM + + +# lambda whose signature matches max token calcs extra & mapper + extra kwargs +get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops +custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { + "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +} + + +### Test for default processor logic & mm_processor_kwargs wrapping +def test_default_processor_is_a_noop(): + """Ensure that by default, there is no processor override.""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID) + processor = dummy_registry.create_input_processor(ctx.model_config) + proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") + proc_outputs = processor(inputs=proc_inputs) + assert proc_inputs is proc_outputs + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_processor_default_kwargs(use_processor_mock, num_crops): + """Ensure input processors can use processor kwargs.""" + dummy_registry = InputRegistry() + # If we have a value for num_crops, pass the override value and make + # sure we get that value as a return-value from out mock processor, + # otherwise fall back to the default value + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + processor = dummy_registry.create_input_processor(ctx.model_config) + + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == expected_num_crops + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_processor_with_sad_kwarg_overrides(use_processor_mock, + mm_processor_kwargs): + """Ensure that input processors filter out invalid mm_processor_kwargs""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + + processor = dummy_registry.create_input_processor(ctx.model_config) + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == DEFAULT_NUM_CROPS + + +### Test overrides for the dummy data +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + """Ensure dummy data factories can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the mm_processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + ctx.model_config, seq_len=-1, mm_registry=mm_registry) + assert len(seq_data.prompt_token_ids) == expected_seq_count + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, + mm_processor_kwargs): + """Ensure the dummy data factory filters out invalid mm_processor_kwargs""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the mm_processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + ctx.model_config, seq_len=-1, mm_registry=mm_registry) + assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + + +### Test overrides for the max token count per multimodal instance +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_max_tokens_kwarg_overrides(num_crops): + """Ensure max token calcs can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {mm_model_cls(): get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + ctx.model_config) + + assert expected_seq_count == max_multimodal_tokens + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): + """Ensure that max token calcs filters out invalid mm_processor_kwargs""" + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # Similar before, but since these kwargs get filtered, + # we always get our default value back. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {mm_model_cls(): get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + ctx.model_config) + + assert max_multimodal_tokens == DEFAULT_NUM_CROPS + + +### Test overrides for the mapper +@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) +def test_default_mapper_with_processer_kwargs(image_assets, num_crops): + """Ensure that the mapper processor kwargs can fall back to HF models.""" + # NOTE - we don't validate bad inputs for the default mapper, because it's + # through the automodel interface in transformers, so we can't easily + # inspect what kwargs are or are not allowed. + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] + assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_custom_mapper_kwarg_overrides(image_assets, num_crops): + """Ensure custom mappers can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {mm_model_cls(): custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_custom_mapper_with_sad_kwarg_overrides(image_assets, + mm_processor_kwargs): + """Ensure that custom mappers filters out invalid mm_processor_kwargs""" + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {mm_model_cls(): custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 diff --git a/vllm/config.py b/vllm/config.py index 7a15606836dcc..fae2d44f174bd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,6 +122,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. """ def __init__(self, @@ -150,7 +152,8 @@ def __init__(self, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, - config_format: ConfigFormat = ConfigFormat.AUTO) -> None: + config_format: ConfigFormat = ConfigFormat.AUTO, + mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -184,6 +187,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc + self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4139eca9c1832..ca6034ddbe5c5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,6 +175,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -513,6 +514,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'e.g.: `image=16,video=2` allows a maximum of 16 ' 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) + parser.add_argument( + '--mm-processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the multimodal input mapping/processing,' + 'e.g., image processor. For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -822,6 +829,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 39409757d3812..80dde804addac 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -235,7 +235,7 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s)", + "use_async_output_proc=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -268,6 +268,7 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c7548ca4bcfbd..a86c51d23b34d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,6 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' @@ -174,6 +175,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2df61a9149629..6ab23d1c4b769 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -9,6 +9,7 @@ from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.utils import get_allowed_kwarg_only_overrides from .data import LLMInputs @@ -68,12 +69,17 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], + **mm_processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. Note: :data:`InputProcessor` is not applied to the dummy data. + + The :code:`mm_processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. """ ... @@ -152,6 +158,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): + return self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -174,15 +184,15 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) - seq_data, mm_data = dummy_factory( - InputContext(model_config), - seq_len, - _MultiModalCounts(mm_counts), - ) + seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids @@ -229,6 +239,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_model_input_processor(self, model_cls: Type[nn.Module]): + return self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + def process_input(self, model_config: "ModelConfig", inputs: LLMInputs) -> LLMInputs: """ @@ -243,15 +257,17 @@ def process_input(self, model_config: "ModelConfig", from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) + processor = self._get_model_input_processor(model_cls) - processor = self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=model_config.mm_processor_kwargs) - return processor(InputContext(model_config), inputs) + return processor(InputContext(model_config), inputs, + **mm_processor_kwargs) def create_input_processor(self, model_config: "ModelConfig"): """ - Create an input processor (see :meth:`process_input`) for a + Create an input processor (see :meth:`_process_input`) for a specific model. """ return functools.partial(self.process_input, model_config) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 032964fe0ac4e..87d3a4576f332 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -14,7 +14,8 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import JSONTree, is_list_of, json_map_leaves +from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, + json_map_leaves) logger = init_logger(__name__) @@ -256,11 +257,20 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) + # Only get processor kwargs at mapping time if we are not using the + # input mapper; no overrides are used on the default here because they + # should be passed to the huggingface resource at initialization time. + if mapper is not None and mapper != self._default_input_mapper: + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + mapper, overrides=model_config.mm_processor_kwargs) + else: + mm_processor_kwargs = {} + if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: @@ -333,7 +343,10 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - max_mm_tokens = max_mm_tokens(InputContext(model_config)) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + max_mm_tokens, overrides=model_config.mm_processor_kwargs) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **mm_processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 6cdde949bc2b1..31b1c3f93411a 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -6,7 +6,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -23,9 +23,14 @@ def get_data_key(self) -> str: return "image" def _get_hf_image_processor(self, model_config: ModelConfig): + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) def _default_input_mapper( self, @@ -37,6 +42,7 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) + if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 745fc715caf45..3940e1671b57a 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,6 +138,15 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. + # + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 4401d13157923..39e75dbaf6872 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,7 +6,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_video_processor +from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import is_list_of @@ -37,9 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py deleted file mode 100644 index 4cffac3724ba8..0000000000000 --- a/vllm/transformers_utils/image_processor.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import cast - - -def get_video_processor( - processor_name: str, - trust_remote_code: bool = False, -): - """ - Gets a processor for the given model name via HuggingFace. - """ - from transformers import AutoProcessor - - try: - processor = AutoProcessor.from_pretrained(processor_name) - video_processor = processor.video_processor - - except ValueError as e: - if not trust_remote_code: - err_msg = ( - "Failed to load the processor. If the processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - return video_processor - - -def get_image_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - **kwargs, -): - """Gets an image processor for the given model name via HuggingFace.""" - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoImageProcessor - from transformers.image_processing_utils import BaseImageProcessor - - try: - processor = AutoImageProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return cast(BaseImageProcessor, processor) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 2001746c5f7f9..98663f7f0bd07 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,13 +1,13 @@ -from typing import cast +from typing import Any, cast def get_processor( processor_name: str, - *args, + *args: Any, trust_remote_code: bool = False, - **kwargs, + **kwargs: Any, ): - """Gets a processor for the given model name via HuggingFace.""" + """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor @@ -35,3 +35,60 @@ def get_processor( raise e return cast(ProcessorMixin, processor) + + +def get_image_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseImageProcessor, processor) + + +def get_video_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers.image_processing_utils import BaseImageProcessor + + processor = get_processor( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cast(BaseImageProcessor, processor.video_processor) diff --git a/vllm/utils.py b/vllm/utils.py index b1513b91a06c6..db2ef146e38ea 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,6 +4,7 @@ import datetime import enum import gc +import inspect import os import random import socket @@ -1237,6 +1238,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + overrides: Potential overrides to be used when invoking the callable. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + allowed_override_names = [ + name for name, param in inspect.signature(callable).parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + # Drop any mm_processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_override_names + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + logger.warning( + "The following intended overrides are not keyword-only args " + "and and will be dropped: %s", dropped_keys) + + return filtered_overrides + + # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType.