[Model] LLaVA model refactor (vllm-project#4910)
DarkLight1337 authored May 20, 2024
1 parent b57e6c5 · commit 6287537
Showing 1 changed file with 107 additions and 30 deletions.
vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds


class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
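# The "type" key is a discriminator: checking it at runtime (see
# _process_image_input below) also lets a type checker narrow the union
# to the matching TypedDict variant.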


class LlavaForConditionalGeneration(VisionLanguageModelBase):

def __init__(self,
@@ -102,6 +117,90 @@ def __init__(self,
config.vocab_size, logit_scale)
self.sampler = Sampler()

def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
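        # Every dimension except the batch dimension must match the
        # image_input_shape configured in the engine arguments.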
if list(data.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")

return data

def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
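        # Returns a tagged LlavaImageInputs dict, or None when no image
        # accompanies the request.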
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if data is None:
return None

if expected_input_type == ImageInputType.PIXEL_VALUES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image pixel vector should be a tensor, "
f"but received type: {type(data)}")

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
)
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image feature vector should be a tensor, "
f"but received type: {type(data)}")

return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
)

return None

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
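        # "default" drops the CLS token embedding at index 0, keeping only
        # the patch embeddings; "full" keeps the CLS token as well.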
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
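        # config.vision_feature_layer (-2 by default for LLaVA) indexes into
        # the encoder's hidden states, so features come from an intermediate
        # layer rather than the final output; hence output_hidden_states=True.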

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]

return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)

def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None

pixel_values = inputs["data"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)

def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
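        # Pixel inputs are first encoded by the vision tower; precomputed
        # image features skip it. Both paths end at the projector, which maps
        # the features into the language model's embedding space.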
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]

return self.multi_modal_projector(image_features)

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@@ -144,42 +243,20 @@ def forward(self,
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
-        if image_input is not None:
-            if list(image_input.shape[1:]) != list(
-                    self.vision_language_config.image_input_shape[1:]):
-                raise ValueError(
-                    f"The expected image tensor shape is batch dimension "
-                    f"plus "
-                    f"{self.vision_language_config.image_input_shape[1:]}."
-                    f" You supplied {image_input.shape}. "
-                    f"If you are using vLLM's entrypoint, make sure your "
-                    f"supplied image input is consistent with "
-                    f"image_input_shape in engine args.")
-            if self.vision_tower is not None:
-                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-                image_outputs = self.vision_tower(image_input,
-                                                  output_hidden_states=True)
-                image_features = image_outputs.hidden_states[
-                    self.config.vision_feature_layer]
-                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
-                if self.config.vision_feature_select_strategy == "default":
-                    image_features = image_features[:, 1:]
-                elif self.config.vision_feature_select_strategy == "full":
-                    image_features = image_features
-                else:
-                    raise ValueError(
-                        f"Unexpected select feature strategy: "
-                        f"{self.config.vision_feature_select_strategy}")
-            else:
-                image_features = image_input
-            vision_embeddings = self.multi_modal_projector(image_features)
+        parsed_image_input = self._parse_and_validate_image_input(image_input)
+
+        if parsed_image_input is not None:
+            vision_embeddings = self._process_image_input(parsed_image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)

input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model(input_ids,
positions,
kv_caches,
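The expected shapes quoted in forward()'s docstring can be sanity-checked with a small standalone snippet. This is a sketch, not part of the commit; the 577 = 1 CLS token + 24 x 24 patches layout assumes the CLIP ViT-L/14 tower at 336 px resolution that LLaVA-1.5 uses:

import torch

# Hypothetical encoder output for one image: 1 CLS token + 24*24 patches.
hidden_state = torch.randn(1, 577, 1024)

default = hidden_state[:, 1:]  # "default": drop CLS -> (1, 576, 1024)
full = hidden_state            # "full": keep CLS   -> (1, 577, 1024)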

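For reference, the parse-then-dispatch pattern this refactor introduces, reduced to a self-contained sketch (all names here are illustrative, not part of vLLM):

from typing import Literal, Optional, TypedDict, Union

import torch


class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor


class FeatureInputs(TypedDict):
    type: Literal["image_features"]
    data: torch.Tensor


ImageInputs = Union[PixelInputs, FeatureInputs]


def parse_input(data: object, expect_pixels: bool) -> Optional[ImageInputs]:
    # Validate once at the boundary and return a tagged dict.
    if data is None:
        return None
    if not isinstance(data, torch.Tensor):
        raise TypeError(f"expected a tensor, got {type(data)}")
    if expect_pixels:
        return PixelInputs(type="pixel_values", data=data)
    return FeatureInputs(type="image_features", data=data)


def process_input(inputs: ImageInputs) -> torch.Tensor:
    # Checking the "type" tag narrows the union, so each branch knows
    # exactly which kind of data it holds.
    if inputs["type"] == "pixel_values":
        return inputs["data"].flatten(1)  # stand-in for encode + project
    return inputs["data"]                 # stand-in for project only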