From 48aea734553c9283ee32327aeaea708dca6a5cda Mon Sep 17 00:00:00 2001
From: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
Date: Fri, 4 Oct 2024 22:05:37 -0700
Subject: [PATCH] [Bugfix] Fixes Phi3v & Ultravox Multimodal EmbeddingInputs
 (#8979)

Signed-off-by: Sumit Dubey
---
 vllm/model_executor/models/phi3v.py    | 20 +++++++----
 vllm/model_executor/models/ultravox.py | 48 ++++++++++++++++----------
 2 files changed, 43 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index ebfffb25360cd..b875a83f876be 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -467,9 +467,10 @@ def input_processor_for_phi3v(ctx: InputContext,
                                              input_height=h,
                                              num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
+        image_feature_size = [image_data.shape[0]]
+        image_data = [image_data]
     elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -611,9 +612,6 @@ def _parse_and_validate_image_input(
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)

-        if pixel_values is None:
-            return None
-
         if pixel_values is None and image_embeds is None:
             return None

@@ -650,7 +648,17 @@ def _process_image_input(
     ) -> torch.Tensor:

         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            image_data = image_input["data"]
+            if is_list_of(image_data, torch.Tensor):
+                # it's already a list of 2D tensors
+                return image_data
+            if len(image_data.shape) == 3:
+                # a single 3D tensor: split it into a list of 2D tensors
+                return list(torch.unbind(image_data, dim=0))
+            raise ValueError(
+                "We expect batched 2D tensors; "
+                "this can be either a list of 2D tensors or a single 3D tensor."
+            )

         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index daa6e72dd1002..101cf38c96b01 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -38,6 +38,7 @@
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP

@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]

+    # If the audio inputs are embeddings, no preprocessing is needed
+    if is_list_of(data, torch.Tensor, check="all"):
+        return MultiModalInputs({"audio_embeds": data})
+
     audio_features = []
     for audio_input in data:
         if not isinstance(audio_input, tuple):
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         audios = [audios]

     audio_token_counts = []
-    for audio_data, sample_rate in audios:
-        audio_length = audio_data.shape[0]
-        if sample_rate != feature_extractor.sampling_rate:
-            # Account for resampling.
-            adjustment = feature_extractor.sampling_rate / sample_rate
-            audio_length = math.ceil(adjustment * audio_length)
-
-        feature_extractor_output_length = math.ceil(
-            (audio_length - (feature_extractor.hop_length - 1)) /
-            feature_extractor.hop_length)
-
-        uv_config = ctx.get_hf_config(UltravoxConfig)
-        audio_num_tokens = min(
-            max(
-                1,
-                math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
-            get_ultravox_max_audio_tokens(ctx))
-        audio_token_counts.append(audio_num_tokens)
+    for audio in audios:
+        if isinstance(audio, torch.Tensor):
+            audio_num_tokens = audio.shape[1]
+            audio_token_counts.append(audio_num_tokens)
+        else:
+            audio_data, sample_rate = audio
+            audio_length = audio_data.shape[0]
+            if sample_rate != feature_extractor.sampling_rate:
+                # Account for resampling.
+                adjustment = feature_extractor.sampling_rate / sample_rate
+                audio_length = math.ceil(adjustment * audio_length)
+
+            feature_extractor_output_length = math.ceil(
+                (audio_length - (feature_extractor.hop_length - 1)) /
+                feature_extractor.hop_length)
+
+            uv_config = ctx.get_hf_config(UltravoxConfig)
+            audio_num_tokens = min(
+                max(
+                    1,
+                    math.ceil(feature_extractor_output_length /
+                              (uv_config.stack_factor * 2))),
+                get_ultravox_max_audio_tokens(ctx))
+            audio_token_counts.append(audio_num_tokens)

     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
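
Reviewer note (illustrative, not part of the patch): the raw-waveform branch above keeps the original
token-count formula; only the new torch.Tensor branch bypasses it. A worked example of that arithmetic,
assuming a 10-second clip at a Whisper-style 16 kHz feature extractor with hop_length=160 and a
stack_factor of 8 (these constants are assumptions, not taken from the patch):

import math

audio_length = 10 * 16000          # samples, already at the extractor's sampling rate (assumed)
hop_length = 160                   # feature_extractor.hop_length (assumed)
stack_factor = 8                   # uv_config.stack_factor (assumed)

# Number of feature-extractor frames produced for the clip.
frames = math.ceil((audio_length - (hop_length - 1)) / hop_length)   # -> 1000
# Audio tokens after stacking; the patch additionally caps this at
# get_ultravox_max_audio_tokens(ctx).
tokens = max(1, math.ceil(frames / (stack_factor * 2)))              # -> 63
print(frames, tokens)

For precomputed embeddings, the new branch skips all of this and simply reads the token count from
audio.shape[1].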
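
A second illustrative sketch of the embedding shapes the patched code paths accept, assuming vLLM's
offline LLM API and the usual "image"/"audio" multi_modal_data keys; the model name, prompt template,
and tensor sizes are placeholders, not taken from the patch:

import torch
from vllm import LLM

# Phi3v: either one 3D tensor (num_images, feature_size, hidden_size) ...
image_embeds = torch.rand(2, 1921, 3072)                 # hypothetical sizes
# ... or the equivalent list of 2D (feature_size, hidden_size) tensors;
# _process_image_input unbinds the 3D form into exactly such a list.
image_embeds_as_list = list(torch.unbind(image_embeds, dim=0))

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
out = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe both images.<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image_embeds},
})

# Ultravox: a per-audio embedding tensor whose dim 1 is the token count, passed the same
# way under the "audio" key; the new input mapper returns it as "audio_embeds" without
# running the feature extractor.
audio_embeds = torch.rand(1, 63, 4096)                   # hypothetical sizes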