diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 50b3d4c6a89533..ce3bae57c8d88e 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -467,6 +467,7 @@ def forward(
             (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
@@ -474,69 +475,67 @@ def forward(
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
 
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                 )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
             else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
                 )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -597,12 +596,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -613,7 +606,7 @@ def prepare_inputs_for_generation(
             **kwargs,
         )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values"] = pixel_values
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 0cbda9cfd64b74..5a49337b2b5d96 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -846,6 +846,7 @@ def forward(
             (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
@@ -861,74 +862,73 @@ def forward(
                 vision_feature_select_strategy=vision_feature_select_strategy,
                 image_newline=self.image_newline,
             )
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            if input_ids.shape[1] != 1:
+                inputs_embeds = inputs_embeds.to(image_features.dtype)
+                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
+                    image_features,
+                    feature_lens,
+                    inputs_embeds,
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    labels=labels,
                 )
-                if input_ids.shape[1] != 1:
-                    inputs_embeds = inputs_embeds.to(image_features.dtype)
-                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                        image_features,
-                        feature_lens,
-                        inputs_embeds,
-                        input_ids,
-                        attention_mask,
-                        position_ids,
-                        labels=labels,
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
 
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
 
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-                n_image_features = image_features.shape[0]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -990,11 +990,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1007,7 +1002,7 @@ def prepare_inputs_for_generation(
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes
 
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 96f4373afd9ec6..44b372535d70bd 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -1110,17 +1110,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- extra custom processing
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1133,7 +1122,7 @@ def prepare_inputs_for_generation(
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index c1ed7571941b9e..e9974e920493ff 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -623,17 +623,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- extra custom processing
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -646,7 +635,7 @@ def prepare_inputs_for_generation(
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index c4ec1b5196929a..86f4c8bd11cde7 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -714,17 +714,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -735,7 +724,7 @@ def prepare_inputs_for_generation(
             **kwargs,
         )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values_images"] = pixel_values_images
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index dd7baa34406fb0..68211ad046fa69 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -461,72 +461,71 @@ def forward(
             (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
             )
 
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                 )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # in the case one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
             else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
                 )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # in the case one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -585,12 +584,6 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -601,7 +594,7 @@ def prepare_inputs_for_generation(
             **kwargs,
        )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values
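
Note (illustrative, not part of the diff): the deprecation warnings above point to the non-legacy path, where image tokens are expanded by the processor instead of inside forward(). A minimal sketch of that setup, assuming the llava-hf/llava-1.5-7b-hf checkpoint and its ViT-L/14 vision tower; the checkpoint name, patch size, and strategy value below are assumptions to be checked against the actual model config:

    from transformers import AutoProcessor

    # Assumed checkpoint; any LLaVa-family checkpoint works the same way.
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # With these two attributes set, the processor expands each image placeholder to the full
    # image sequence length, so legacy_processing evaluates to False and the removed code paths never run.
    processor.patch_size = 14  # assumed: must match the vision tower's patch size
    processor.vision_feature_select_strategy = "default"  # assumed: must match the model config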