[Bugfix] Fixes Phi3v & Ultravox Multimodal EmbeddingInputs (vllm-project#8979)

Signed-off-by: Sumit Dubey <[email protected]>
hhzhang16 authored and sumitd2 committed Nov 14, 2024
1 parent 0e48174 commit 48aea73
Showing 2 changed files with 43 additions and 25 deletions.
20 changes: 14 additions & 6 deletions vllm/model_executor/models/phi3v.py
@@ -467,9 +467,10 @@ def input_processor_for_phi3v(ctx: InputContext,
                                              input_height=h,
                                              num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
+        image_feature_size = [image_data.shape[0]]
+        image_data = [image_data]
     elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

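The hunk above changes the shape convention for embedding inputs: each per-image embedding tensor is 2D, `(num_image_tokens, hidden_size)`, so the feature size is read from dim 0. A minimal sketch with invented sizes:

```python
import torch

# Sizes are illustrative only, not taken from any Phi-3-Vision config.
single = torch.randn(576, 3072)  # one image: (num_image_tokens, hidden_size)
batch = [torch.randn(576, 3072), torch.randn(144, 3072)]

assert [single.shape[0]] == [576]                 # single-tensor branch
assert [t.shape[0] for t in batch] == [576, 144]  # list-of-tensors branch
```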
@@ -611,9 +612,6 @@ def _parse_and_validate_image_input(
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
-            return None
-
         if pixel_values is None and image_embeds is None:
             return None
 
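Why the guard changed: with embedding inputs, `pixel_values` is `None` but `image_embeds` is not, so the old early return silently discarded valid embeddings. A toy check (shape invented):

```python
import torch

pixel_values = None
image_embeds = torch.randn(1, 576, 3072)  # invented shape

# Old guard returned None here; the combined guard falls through to the embeds.
assert pixel_values is None
assert not (pixel_values is None and image_embeds is None)
```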
@@ -650,7 +648,17 @@ def _process_image_input(
     ) -> torch.Tensor:
 
         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            image_data = image_input["data"]
+            if is_list_of(image_data, torch.Tensor):
+                # it's already a list of tensors
+                return image_data
+            if len(image_data.shape) == 3:
+                # 3D tensor
+                return list(torch.unbind(image_data, dim=0))
+            raise ValueError(
+                "We expect batched 2D tensors;"
+                "this can be either a list of 2D tensors or a single 3D tensor."
+            )
 
         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
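A quick illustration of the new normalization in `_process_image_input` (sizes invented): a batched 3D embeddings tensor is unbound into the list-of-2D form that the model consumes per image.

```python
import torch

embeds_3d = torch.randn(2, 576, 3072)  # (num_images, num_tokens, hidden_size), invented
per_image = list(torch.unbind(embeds_3d, dim=0))
assert len(per_image) == 2 and per_image[0].shape == (576, 3072)
```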
48 changes: 29 additions & 19 deletions vllm/model_executor/models/ultravox.py
@@ -38,6 +38,7 @@
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 
@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]
 
+    # If the audio inputs are embeddings, no need for preprocessing
+    if is_list_of(data, torch.Tensor, check="all"):
+        return MultiModalInputs({"audio_embeds": data})
+
     audio_features = []
     for audio_input in data:
         if not isinstance(audio_input, tuple):
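The mapper addition short-circuits preprocessing when embeddings are supplied. A standalone sketch with `is_list_of` stubbed in (the real helper lives in `vllm.utils`; the tensor shape is hypothetical):

```python
import torch

def is_list_of(xs, typ, check="all"):
    # Simplified stand-in for vllm.utils.is_list_of.
    return isinstance(xs, list) and all(isinstance(x, typ) for x in xs)

data = [torch.randn(1, 100, 4096)]  # hypothetical precomputed audio embeddings
if is_list_of(data, torch.Tensor, check="all"):
    mapped = {"audio_embeds": data}  # the real mapper wraps this in MultiModalInputs
```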
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         audios = [audios]
 
     audio_token_counts = []
-    for audio_data, sample_rate in audios:
-        audio_length = audio_data.shape[0]
-        if sample_rate != feature_extractor.sampling_rate:
-            # Account for resampling.
-            adjustment = feature_extractor.sampling_rate / sample_rate
-            audio_length = math.ceil(adjustment * audio_length)
-
-        feature_extractor_output_length = math.ceil(
-            (audio_length - (feature_extractor.hop_length - 1)) /
-            feature_extractor.hop_length)
-
-        uv_config = ctx.get_hf_config(UltravoxConfig)
-        audio_num_tokens = min(
-            max(
-                1,
-                math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
-            get_ultravox_max_audio_tokens(ctx))
-        audio_token_counts.append(audio_num_tokens)
+    for audio in audios:
+        if isinstance(audio, torch.Tensor):
+            audio_num_tokens = audio.shape[1]
+            audio_token_counts.append(audio_num_tokens)
+        else:
+            audio_data, sample_rate = audio
+            audio_length = audio_data.shape[0]
+            if sample_rate != feature_extractor.sampling_rate:
+                # Account for resampling.
+                adjustment = feature_extractor.sampling_rate / sample_rate
+                audio_length = math.ceil(adjustment * audio_length)
+
+            feature_extractor_output_length = math.ceil(
+                (audio_length - (feature_extractor.hop_length - 1)) /
+                feature_extractor.hop_length)
+
+            uv_config = ctx.get_hf_config(UltravoxConfig)
+            audio_num_tokens = min(
+                max(
+                    1,
+                    math.ceil(feature_extractor_output_length /
+                              (uv_config.stack_factor * 2))),
+                get_ultravox_max_audio_tokens(ctx))
+            audio_token_counts.append(audio_num_tokens)
 
     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

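For raw audio, the retained token-count arithmetic can be checked by hand. A worked example under assumed Whisper-style extractor defaults (`sampling_rate=16000`, `hop_length=160`) and an assumed `stack_factor` of 8, with the `get_ultravox_max_audio_tokens` clamp omitted:

```python
import math

sampling_rate, hop_length, stack_factor = 16000, 160, 8  # assumed values
audio_length = 16000  # one second of audio, already at the target rate

frames = math.ceil((audio_length - (hop_length - 1)) / hop_length)  # -> 100
num_tokens = max(1, math.ceil(frames / (stack_factor * 2)))         # -> 7
assert (frames, num_tokens) == (100, 7)
```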
