diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py
index 2b1d3c5b43b44..b0e7264e89118 100644
--- a/tests/models/test_paligemma.py
+++ b/tests/models/test_paligemma.py
@@ -129,7 +129,7 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["float", "half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 16548c6c1e8c7..7e0888b5f5abd 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -277,6 +277,7 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 2af2bedd8e48e..8a2bacbd96b67 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -19,7 +19,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")
 
-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             "The image token '%s' was detected in the prompt and "
             "will be removed. Please follow the proper prompt format"
@@ -214,7 +214,9 @@ def _parse_and_validate_image_input(
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
 
-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)
 
         selected_image_features = image_outputs.last_hidden_state
 
@@ -236,9 +238,12 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_features)
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:
 
         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -263,6 +268,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
+                                           None,
                                            inputs_embeds=inputs_embeds)
 
         return hidden_states
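
The core fix in `_image_pixels_to_features` addresses the newly parametrized `"half"` test case: when the model is loaded in half precision, the vision tower's weights are fp16 while the multimodal pipeline still supplies fp32 pixel values, and PyTorch's linear layers raise a dtype mismatch on mixed inputs. The patch resolves this by casting the pixels to the parameter dtype before the forward pass. Below is a minimal, self-contained sketch of that pattern; `ToyVisionTower` and `image_pixels_to_features` are illustrative stand-ins rather than vLLM code, and bfloat16 is used so the snippet also runs on CPU:

```python
import torch
import torch.nn as nn


class ToyVisionTower(nn.Module):
    """Stand-in for SiglipVisionModel; names here are illustrative only."""

    def __init__(self) -> None:
        super().__init__()
        self.patch_embed = nn.Linear(16, 8)

    def get_input_embeddings(self) -> nn.Linear:
        return self.patch_embed

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.patch_embed(pixel_values)


def image_pixels_to_features(tower: ToyVisionTower,
                             pixel_values: torch.Tensor) -> torch.Tensor:
    # Same pattern as the patch: read the parameter dtype off the input
    # embedding layer and cast the incoming pixels to match it.
    target_dtype = tower.get_input_embeddings().weight.dtype
    return tower(pixel_values.to(dtype=target_dtype))


tower = ToyVisionTower().to(torch.bfloat16)          # "half"-style weights
pixels = torch.randn(2, 16)                          # fp32 input, as in the bug
features = image_pixels_to_features(tower, pixels)
assert features.dtype == torch.bfloat16              # no dtype mismatch
```

Reading the dtype from the module's own weights, rather than from a config value, keeps the cast correct regardless of which precision the model was loaded in.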