Skip to content

Commit

Permalink
[Bugfix] Fix PaliGemma MMP (vllm-project#6930)
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 authored Jul 30, 2024
1 parent 6e063ea commit c66c7f8
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
Expand Down Expand Up @@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
super().__init__()

self.linear = ColumnParallelLinear(vision_hidden_size,
projection_dim,
bias=True)
self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear(image_features)
hidden_states = self.linear(image_features)
return hidden_states


Expand Down

0 comments on commit c66c7f8

Please sign in to comment.