diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8a2bacbd96b67..2af48b6bc190f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,7 +9,6 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -133,12 +132,10 @@ class PaliGemmaMultiModalProjector(nn.Module): def __init__(self, vision_hidden_size: int, projection_dim: int): super().__init__() - self.linear = ColumnParallelLinear(vision_hidden_size, - projection_dim, - bias=True) + self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.linear(image_features) + hidden_states = self.linear(image_features) return hidden_states