
Fix InternVL
DarkLight1337 committed Aug 28, 2024
1 parent 88748cb commit 3adeb01
Showing 2 changed files with 12 additions and 12 deletions.
20 changes: 10 additions & 10 deletions vllm/model_executor/models/internvl.py
@@ -42,10 +42,10 @@

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    data: torch.Tensor
     """
     Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`

     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
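To make the new shape contract concrete, here is a small sketch (the sizes, including the 448x448 input resolution, are illustrative assumptions rather than values from this commit): the patch dimension is folded into the batch dimension, so the model now receives a plain 4-D tensor.

    import torch

    # Hypothetical sizes: 2 requests, 1 image each, 1 thumbnail + 6 tiles per image.
    batch_size, num_images, num_patches = 2, 1, 6
    pixel_values = torch.randn(batch_size * num_images, 1 + num_patches, 3, 448, 448)

    # Old contract: a 5-D tensor (or a list of 4-D tensors when patch counts vary).
    # New contract: the patch dimension is merged into the batch dimension.
    flat = pixel_values.flatten(0, 1)
    print(flat.shape)  # torch.Size([14, 3, 448, 448])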
@@ -54,7 +54,7 @@ class InternVLImagePixelInputs(TypedDict):

 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    data: torch.Tensor
     """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
@@ -358,7 +358,7 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         x = x.permute(0, 2, 1, 3).contiguous()
         return x

-    def extract_feature(self, pixel_values):
+    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
         vit_embeds = self.vision_model(pixel_values=pixel_values)
         vit_embeds = vit_embeds[:, 1:, :]
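As a rough shape walk-through of the now-typed `extract_feature` (all concrete sizes below are illustrative assumptions, not values from this commit): the ViT output drops its leading CLS token, `pixel_shuffle` with `scale_factor=0.5` trades spatial resolution for channels, and `mlp1` projects the result into the language model's hidden size.

    import torch

    # Assumed sizes: 14 flattened tiles, a 32x32 token grid per tile (plus CLS),
    # ViT hidden size 1024, LLM hidden size 4096.
    n, grid, vit_hidden, llm_hidden = 14, 32, 1024, 4096
    vit_embeds = torch.randn(n, 1 + grid * grid, vit_hidden)

    vit_embeds = vit_embeds[:, 1:, :]        # drop CLS -> (n, 1024 tokens, vit_hidden)
    h = w = int(vit_embeds.shape[1] ** 0.5)  # 32 x 32 token grid
    vit_embeds = vit_embeds.reshape(n, h, w, -1)

    # pixel_shuffle(scale_factor=0.5) halves each spatial side and quadruples the
    # channels: (14, 32, 32, 1024) -> (14, 16, 16, 4096), i.e. 256 tokens per tile;
    # mlp1 then maps those channels to llm_hidden, giving (14, 256, 4096).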

@@ -371,9 +371,7 @@ def extract_feature(self, pixel_values):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -382,10 +380,11 @@ def _validate_shape(d: torch.Tensor):
             actual_dims = tuple(d.shape)

             if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
+                expected_expr = str(expected_dims)
                 raise ValueError(
                     "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")

         for d in data:
             _validate_shape(d)
@@ -420,7 +419,8 @@ def _parse_and_validate_image_input(

             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
             )

         raise AssertionError("This line should be unreachable.")
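Putting the parsing change together: `flatten_bn(..., concat=True)` merges the per-request batch dimension and the extra `.flatten(0, 1)` merges the patch dimension, so `_validate_pixel_values` only ever sees 4-D data and checks each tile against `(3, H, W)`. The sketch below uses a toy stand-in for `flatten_bn` and an assumed 448x448 image size; it illustrates the intent rather than the actual vLLM helper.

    import torch
    from typing import List, Union

    def toy_flatten_bn(x: Union[torch.Tensor, List[torch.Tensor]],
                       concat: bool = False) -> torch.Tensor:
        # Toy stand-in: merge the per-request tensors along the batch dim.
        if isinstance(x, torch.Tensor):
            return x
        return torch.cat(list(x)) if concat else torch.stack(list(x))

    # Two requests, one image each, every image split into 1 thumbnail + 6 tiles.
    reqs = [torch.randn(1, 7, 3, 448, 448), torch.randn(1, 7, 3, 448, 448)]

    data = toy_flatten_bn(reqs, concat=True).flatten(0, 1)
    print(data.shape)  # torch.Size([14, 3, 448, 448])

    # Validation now simply iterates over flattened tiles.
    expected_dims = (3, 448, 448)
    assert all(tuple(d.shape) == expected_dims for d in data)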
4 changes: 2 additions & 2 deletions vllm/multimodal/base.py
@@ -18,7 +18,7 @@

 logger = init_logger(__name__)

-NestedTensors = Union[List["NestedTensors"], torch.Tensor]
+NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -61,7 +61,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:

     tensors_ = cast(List[torch.Tensor], stacked)
     if any(t.shape != tensors_[0].shape for t in tensors_):
         # The tensors have incompatible shapes and can't be stacked.
-        return stacked
+        return tensors_

     return torch.stack(tensors_)
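On the base.py side, `cast` is a no-op at runtime, so returning `tensors_` instead of `stacked` does not change behavior; it simply matches the widened `NestedTensors` union, which now names `List[torch.Tensor]` explicitly. A minimal sketch of the stack-or-keep-as-list behavior follows (a toy re-implementation under that reading, not the actual vLLM helper).

    import torch
    from typing import List, Union

    NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]

    def toy_try_stack(tensors: List[torch.Tensor]) -> NestedTensors:
        # Stack into one batched tensor when every element has the same shape;
        # otherwise keep the list so variable-sized images survive batching.
        if any(t.shape != tensors[0].shape for t in tensors):
            return tensors
        return torch.stack(tensors)

    same = [torch.zeros(7, 3, 448, 448), torch.zeros(7, 3, 448, 448)]
    diff = [torch.zeros(7, 3, 448, 448), torch.zeros(4, 3, 448, 448)]

    print(toy_try_stack(same).shape)  # torch.Size([2, 7, 3, 448, 448])
    print(type(toy_try_stack(diff)))  # <class 'list'>: shapes differ, keep the list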

