
[Bugfix] Fix dtype mismatch in PaliGemma (#6367)
DarkLight1337 authored Jul 12, 2024
1 parent aea19f0 commit 024ad87
Showing 3 changed files with 12 additions and 5 deletions.
tests/models/test_paligemma.py (2 changes: 1 addition & 1 deletion)
@@ -129,7 +129,7 @@ def run_test(
         [0.25, 0.5, 1.0],
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["float", "half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
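This one-line test change is what exercises the bug: every existing combination of model, size factor, max_tokens, and num_logprobs now also runs with dtype="half". A standalone sketch of how the parametrize axes multiply (the test body and dtype mapping here are hypothetical, not from the repo):

import pytest
import torch

# vLLM-style string dtypes mapped to torch dtypes (mapping assumed for this sketch).
STR_DTYPE_TO_TORCH = {"float": torch.float32, "half": torch.float16}

# Each entry in the "dtype" list is crossed with every other parametrized axis,
# so appending "half" doubles the number of generated test cases.
@pytest.mark.parametrize("dtype", ["float", "half"])
@pytest.mark.parametrize("batch_size", [1, 4])
def test_dtype_is_applied(dtype: str, batch_size: int) -> None:
    x = torch.zeros(batch_size, 8, dtype=STR_DTYPE_TO_TORCH[dtype])
    assert x.dtype == STR_DTYPE_TO_TORCH[dtype]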
vllm/model_executor/models/gemma.py (1 change: 1 addition & 0 deletions)
@@ -277,6 +277,7 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
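The new parameter aligns GemmaModel.forward with the signature vLLM uses for models that can participate in pipeline parallelism; defaulting it to None keeps every existing call site valid. A toy illustration of why the defaulted slot is backward compatible (all names in this sketch are hypothetical):

from typing import List, Optional

class IntermediateTensors:
    """Stand-in for vLLM's container of tensors exchanged between pipeline stages."""

# Before: forward(positions, kv_caches, attn_metadata, inputs_embeds=None)
# After:  an intermediate_tensors slot is inserted ahead of inputs_embeds.
def forward(positions: List[int],
            kv_caches: List[object],
            attn_metadata: object,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[List[float]] = None) -> List[int]:
    # Gemma itself ignores the new argument; it exists so the signature
    # matches what the model runner expects across all models.
    return positions

# Callers that passed inputs_embeds by keyword are unaffected by the insertion:
forward([1, 2], [], None, inputs_embeds=[0.5])
# New-style callers can fill the slot positionally, as PaliGemma does below:
forward([1, 2], [], None, None, inputs_embeds=[0.5])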
vllm/model_executor/models/paligemma.py (14 changes: 10 additions & 4 deletions)
@@ -19,7 +19,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings
@@ -111,7 +111,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")
 
-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             "The image token '%s' was detected in the prompt and "
             "will be removed. Please follow the proper prompt format"
@@ -214,7 +214,9 @@ def _parse_and_validate_image_input(
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
 
-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)
 
         selected_image_features = image_outputs.last_hidden_state
 
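This hunk is the bugfix proper. With the model loaded in half precision, the SigLIP vision tower's weights are float16, but image preprocessing can hand over float32 pixel values, and the tower's first layer then fails on the mismatch. Casting the input to the parameter dtype, as the commit does via get_input_embeddings().weight.dtype, resolves it. A self-contained reproduction (layer shape arbitrary; float16 matmul on CPU needs a reasonably recent PyTorch):

import torch
import torch.nn as nn

vision_layer = nn.Linear(4, 4).to(dtype=torch.float16)  # weights loaded in half
pixel_values = torch.randn(1, 4, dtype=torch.float32)   # preprocessor output

try:
    vision_layer(pixel_values)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")  # e.g. "expected ... Half but found Float"

# The fix: follow the module's own parameter dtype instead of assuming one.
target_dtype = vision_layer.weight.dtype
out = vision_layer(pixel_values.to(dtype=target_dtype))
print(out.dtype)  # torch.float16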
@@ -236,9 +238,12 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_features)
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:
 
         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -263,6 +268,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
             positions,
             kv_caches,
             attn_metadata,
+            None,
             inputs_embeds=inputs_embeds)
 
         return hidden_states
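PaliGemma does not use pipeline-parallel intermediate tensors itself, so the call site fills the new slot of the Gemma language model with an explicit None. With the cast in place, the model can be served in half precision, which is exactly what the new test axis exercises. A hedged end-to-end sketch against vLLM's offline API; the multimodal input format changed across 2024 releases, so the dict shape below is an assumption to check against your version's docs:

from PIL import Image
from vllm import LLM, SamplingParams

# dtype="half" loads all weights, including the SigLIP vision tower, in float16.
# Before this commit, float32 pixel values then crashed the vision forward pass.
llm = LLM(model="google/paligemma-3b-mix-224", dtype="half")

image = Image.open("example.jpg")  # any local test image
outputs = llm.generate(
    {"prompt": "caption en", "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)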
