diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 2071e41e672fca..e1acfcefa09442 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -190,6 +190,7 @@ Flax), PyTorch, and/or TensorFlow. | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | | Data2VecAudio | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ | +| Data2VecVision | ❌ | ❌ | ✅ | ❌ | ❌ | | DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ | | Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/data2vec.mdx b/docs/source/en/model_doc/data2vec.mdx index f84593d0f9a59b..87f5a2fb43411c 100644 --- a/docs/source/en/model_doc/data2vec.mdx +++ b/docs/source/en/model_doc/data2vec.mdx @@ -33,10 +33,13 @@ Models and code are available at www.github.com/pytorch/fairseq/tree/master/exam Tips: -- Both Data2VecAudio and Data2VecText have been trained using the same self-supervised learning method. - In the case of Data2VecAudio, preprocessing is identical to [`RobertaModel`], including tokenization. +- Data2VecAudio, Data2VecText, and Data2VecVision have all been trained using the same self-supervised learning method. +- For Data2VecAudio, preprocessing is identical to [`Wav2Vec2Model`], including feature extraction +- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization. +- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction. + +This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten) -This model was contributed by [edugp](https://huggingface.co/edugp). The original code can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec). @@ -48,12 +51,16 @@ The original code can be found [here](https://github.com/pytorch/fairseq/tree/ma [[autodoc]] Data2VecAudioConfig +## Data2VecVisionConfig + +[[autodoc]] Data2VecVisionConfig + + ## Data2VecAudioModel [[autodoc]] Data2VecAudioModel - forward - ## Data2VecAudioForAudioFrameClassification [[autodoc]] Data2VecAudioForAudioFrameClassification @@ -108,3 +115,18 @@ The original code can be found [here](https://github.com/pytorch/fairseq/tree/ma [[autodoc]] Data2VecTextForQuestionAnswering - forward + +## Data2VecVisionModel + +[[autodoc]] Data2VecVisionModel + - forward + +## Data2VecVisionForImageClassification + +[[autodoc]] Data2VecVisionForImageClassification + - forward + +## Data2VecVisionForSemanticSegmentation + +[[autodoc]] Data2VecVisionForSemanticSegmentation + - forward diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 6020b9fe70a2e3..f8e88b3501a818 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -54,6 +54,7 @@ Ready-made configurations include the following architectures: - BlenderbotSmall - CamemBERT - Data2VecText +- Data2VecVision - DistilBERT - ELECTRA - FlauBERT diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a399cbe3633881..779025cd56a59c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -170,7 +170,13 @@ "models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"], "models.cpm": ["CpmTokenizer"], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], - "models.data2vec": ["DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Data2VecAudioConfig", "Data2VecTextConfig"], + "models.data2vec": [ + "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Data2VecAudioConfig", + "Data2VecTextConfig", + "Data2VecVisionConfig", + ], "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], "models.decision_transformer": ["DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "DecisionTransformerConfig"], @@ -868,6 +874,7 @@ [ "DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST", "DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", + "DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST", "Data2VecAudioForAudioFrameClassification", "Data2VecAudioForCTC", "Data2VecAudioForSequenceClassification", @@ -882,6 +889,10 @@ "Data2VecTextForTokenClassification", "Data2VecTextModel", "Data2VecTextPreTrainedModel", + "Data2VecVisionForImageClassification", + "Data2VecVisionForSemanticSegmentation", + "Data2VecVisionModel", + "Data2VecVisionPreTrainedModel", ] ) _import_structure["models.deberta"].extend( @@ -2555,7 +2566,13 @@ from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig from .models.cpm import CpmTokenizer from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer - from .models.data2vec import DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig, Data2VecTextConfig + from .models.data2vec import ( + DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, + DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP, + Data2VecAudioConfig, + Data2VecTextConfig, + Data2VecVisionConfig, + ) from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config from .models.decision_transformer import ( @@ -3151,6 +3168,7 @@ from .models.data2vec import ( DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST, DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, + DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, Data2VecAudioForAudioFrameClassification, Data2VecAudioForCTC, Data2VecAudioForSequenceClassification, @@ -3165,6 +3183,10 @@ Data2VecTextForTokenClassification, Data2VecTextModel, Data2VecTextPreTrainedModel, + Data2VecVisionForImageClassification, + Data2VecVisionForSemanticSegmentation, + Data2VecVisionModel, + Data2VecVisionPreTrainedModel, ) from .models.deberta import ( DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 88235528d6a6f6..210199aec81858 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -59,6 +59,7 @@ ("layoutlmv2", "LayoutLMv2Config"), ("plbart", "PLBartConfig"), ("beit", "BeitConfig"), + ("data2vec-vision", "Data2VecVisionConfig"), ("rembert", "RemBertConfig"), ("visual_bert", "VisualBertConfig"), ("canine", "CanineConfig"), @@ -162,6 +163,7 @@ ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -349,12 +351,18 @@ ("layoutxlm", "LayoutXLM"), ("data2vec-audio", "Data2VecAudio"), ("data2vec-text", "Data2VecText"), + ("data2vec-vision", "Data2VecVision"), ("dit", "DiT"), ] ) SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( - [("openai-gpt", "openai"), ("data2vec-audio", "data2vec"), ("data2vec-text", "data2vec")] + [ + ("openai-gpt", "openai"), + ("data2vec-audio", "data2vec"), + ("data2vec-text", "data2vec"), + ("data2vec-vision", "data2vec"), + ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fbe3b50c98a973..78d32ad0708910 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -55,6 +55,7 @@ ("layoutlmv2", "LayoutLMv2Model"), ("plbart", "PLBartModel"), ("beit", "BeitModel"), + ("data2vec-vision", "Data2VecVisionModel"), ("rembert", "RemBertModel"), ("visual_bert", "VisualBertModel"), ("canine", "CanineModel"), @@ -290,6 +291,7 @@ ("vit", "ViTForImageClassification"), ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), ("beit", "BeitForImageClassification"), + ("data2vec-vision", "Data2VecVisionForImageClassification"), ("segformer", "SegformerForImageClassification"), ("imagegpt", "ImageGPTForImageClassification"), ( @@ -321,6 +323,7 @@ [ # Model for Semantic Segmentation mapping ("beit", "BeitForSemanticSegmentation"), + ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"), ("segformer", "SegformerForSemanticSegmentation"), ("dpt", "DPTForSemanticSegmentation"), ] diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 80a49fcd20bd76..73eaf26b7ee30a 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -702,7 +702,8 @@ def forward( pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] return BeitModelOutputWithPooling( last_hidden_state=sequence_output, @@ -713,7 +714,7 @@ def forward( class BeitPooler(nn.Module): - def __init__(self, config: BeitModel) -> None: + def __init__(self, config: BeitConfig) -> None: super().__init__() self.layernorm = ( nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None @@ -736,7 +737,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: BEIT_START_DOCSTRING, ) class BeitForMaskedImageModeling(BeitPreTrainedModel): - def __init__(self, config: BeitModel) -> None: + def __init__(self, config: BeitConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -817,7 +818,7 @@ def forward( masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels) if not return_dict: - output = (prediction_scores,) + outputs[2:] + output = (prediction_scores,) + outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return MaskedLMOutput( @@ -836,7 +837,7 @@ def forward( BEIT_START_DOCSTRING, ) class BeitForImageClassification(BeitPreTrainedModel): - def __init__(self, config: BeitModel) -> None: + def __init__(self, config: BeitConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -1237,7 +1238,7 @@ def forward( return_dict=return_dict, ) - encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2] + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] # only keep certain features, and reshape # note that we do +1 as the encoder_hidden_states also includes the initial embeddings @@ -1268,9 +1269,9 @@ def forward( if not return_dict: if output_hidden_states: - output = (logits,) + outputs[2:] + output = (logits,) + outputs[1:] else: - output = (logits,) + outputs[3:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SemanticSegmenterOutput( diff --git a/src/transformers/models/data2vec/__init__.py b/src/transformers/models/data2vec/__init__.py index d93c64d1e324ae..a1296fd334e365 100644 --- a/src/transformers/models/data2vec/__init__.py +++ b/src/transformers/models/data2vec/__init__.py @@ -31,6 +31,11 @@ "Data2VecTextConfig", "Data2VecTextOnnxConfig", ], + "configuration_data2vec_vision": [ + "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Data2VecVisionConfig", + "Data2VecVisionOnnxConfig", + ], } if is_torch_available(): @@ -54,6 +59,14 @@ "Data2VecTextModel", "Data2VecTextPreTrainedModel", ] + _import_structure["modeling_data2vec_vision"] = [ + "DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST", + "Data2VecVisionForImageClassification", + "Data2VecVisionForMaskedImageModeling", + "Data2VecVisionForSemanticSegmentation", + "Data2VecVisionModel", + "Data2VecVisionPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_data2vec_audio import DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig @@ -62,6 +75,11 @@ Data2VecTextConfig, Data2VecTextOnnxConfig, ) + from .configuration_data2vec_vision import ( + DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP, + Data2VecVisionConfig, + Data2VecVisionOnnxConfig, + ) if is_torch_available(): from .modeling_data2vec_audio import ( @@ -84,6 +102,14 @@ Data2VecTextModel, Data2VecTextPreTrainedModel, ) + from .modeling_data2vec_vision import ( + DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, + Data2VecVisionForImageClassification, + Data2VecVisionForMaskedImageModeling, + Data2VecVisionForSemanticSegmentation, + Data2VecVisionModel, + Data2VecVisionPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/data2vec/configuration_data2vec_vision.py b/src/transformers/models/data2vec/configuration_data2vec_vision.py new file mode 100644 index 00000000000000..5508f4d9e7e779 --- /dev/null +++ b/src/transformers/models/data2vec/configuration_data2vec_vision.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Data2VecVision model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/data2vec-vision-base-ft": "https://huggingface.co/facebook/data2vec-vision-base-ft/resolve/main/config.json", +} + + +class Data2VecVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate + an Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Data2VecVision + [facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 8092): + Vocabulary size of the Data2VecVision model. Defines the number of different image tokens that can be used + during pre-training. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`): + Indices of the feature maps to use for semantic segmentation. + pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + + Example: + + ```python + >>> from transformers import Data2VecVisionModel, Data2VecVisionConfig + + >>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration + >>> configuration = Data2VecVisionConfig() + + >>> # Initializing a model from the data2vec_vision-base-patch16-224-in22k style configuration + >>> model = Data2VecVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "data2vec-vision" + + def __init__( + self, + vocab_size=8192, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + is_encoder_decoder=False, + image_size=224, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + out_indices=[3, 5, 7, 11], + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + semantic_loss_ignore_index=255, + **kwargs + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling + # decode head attributes (semantic segmentation) + self.out_indices = out_indices + self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class Data2VecVisionOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py new file mode 100755 index 00000000000000..b375167c8de83c --- /dev/null +++ b/src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +import argparse +import json + +import torch +from PIL import Image + +from huggingface_hub import hf_hub_download +from timm.models import create_model +from transformers import ( + BeitFeatureExtractor, + Data2VecVisionConfig, + Data2VecVisionForImageClassification, + Data2VecVisionModel, +) + + +def create_rename_keys(config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec."): + prefix = "backbone." if is_semantic else "" + + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append( + (f"{prefix}blocks.{i}.norm1.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_before.weight") + ) + rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append( + (f"{prefix}blocks.{i}.attn.proj.weight", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.weight") + ) + rename_keys.append( + (f"{prefix}blocks.{i}.attn.proj.bias", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.bias") + ) + rename_keys.append( + (f"{prefix}blocks.{i}.norm2.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_after.weight") + ) + rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append( + (f"{prefix}blocks.{i}.mlp.fc1.weight", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.weight") + ) + rename_keys.append( + (f"{prefix}blocks.{i}.mlp.fc1.bias", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.bias") + ) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"{hf_prefix}encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"{hf_prefix}encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + (f"{prefix}cls_token", f"{hf_prefix}embeddings.cls_token"), + (f"{prefix}patch_embed.proj.weight", f"{hf_prefix}embeddings.patch_embeddings.projection.weight"), + (f"{prefix}patch_embed.proj.bias", f"{hf_prefix}embeddings.patch_embeddings.projection.bias"), + ] + ) + + if has_lm_head: + # mask token + shared relative position bias + layernorm + rename_keys.extend( + [ + ("mask_token", f"{hf_prefix}embeddings.mask_token"), + ( + "rel_pos_bias.relative_position_bias_table", + f"{hf_prefix}encoder.relative_position_bias.relative_position_bias_table", + ), + ( + "rel_pos_bias.relative_position_index", + f"{hf_prefix}encoder.relative_position_bias.relative_position_index", + ), + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + elif is_semantic: + # semantic segmentation classification heads + rename_keys.extend( + [ + ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), + ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), + ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), + ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), + ] + ) + else: + # layernorm + classification head + rename_keys.extend( + [ + ("fc_norm.weight", f"{hf_prefix}pooler.layernorm.weight"), + ("fc_norm.bias", f"{hf_prefix}pooler.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec_vision."): + for i in range(config.num_hidden_layers): + prefix = "backbone." if is_semantic else "" + # queries, keys and values + in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight") + q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias") + v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias") + + state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias + state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias + + # gamma_1 and gamma_2 + # we call them lambda because otherwise they are renamed when using .from_pretrained + gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1") + gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2") + + state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_1"] = gamma_1 + state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_2"] = gamma_2 + + # relative_position bias table + index + if not has_lm_head: + # each layer has its own relative position bias + table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table") + index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index") + + state_dict[ + f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table" + ] = table + state_dict[ + f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index" + ] = index + + +def get_args(): + parser = argparse.ArgumentParser( + "Convert Data2VecVision to HF for image classification and pretraining", add_help=False + ) + parser.add_argument("--hf_checkpoint_name", type=str) + parser.add_argument("--input_size", default=224, type=int, help="images input size") + parser.add_argument("--beit_checkpoint", default="", help="beit checkpoint") + + return parser.parse_args() + + +def load_beit_model(args, is_finetuned, is_large): + def load_state_dict(model, state_dict, prefix="", ignore_missing="relative_position_index"): + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model, prefix=prefix) + + warn_missing_keys = [] + ignore_missing_keys = [] + for key in missing_keys: + keep_flag = True + for ignore_key in ignore_missing.split("|"): + if ignore_key in key: + keep_flag = False + break + if keep_flag: + warn_missing_keys.append(key) + else: + ignore_missing_keys.append(key) + + missing_keys = warn_missing_keys + + if len(missing_keys) > 0: + print( + "Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys + ) + ) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)) + if len(ignore_missing_keys) > 0: + print( + "Ignored weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, ignore_missing_keys + ) + ) + if len(error_msgs) > 0: + print("\n".join(error_msgs)) + + model_kwargs = { + "pretrained": False, + "use_shared_rel_pos_bias": True, + "use_abs_pos_emb": False, + "init_values": 0.1, + } + + if is_finetuned: + model_kwargs.update( + { + "num_classes": 1000, + "use_mean_pooling": True, + "init_scale": 0.001, + "use_rel_pos_bias": True, + } + ) + + model = create_model( + "beit_large_patch16_224" if is_large else "beit_base_patch16_224", + **model_kwargs, + ) + patch_size = model.patch_embed.patch_size + args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1]) + checkpoint = torch.load(args.beit_checkpoint, map_location="cpu") + + print(f"Load ckpt from {args.beit_checkpoint}") + checkpoint_model = None + for model_key in ("model", "module"): + if model_key in checkpoint: + checkpoint_model = checkpoint[model_key] + print(f"Load state_dict by model_key = {model_key}") + break + + all_keys = list(checkpoint_model.keys()) + for key in all_keys: + if "relative_position_index" in key: + checkpoint_model.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = checkpoint_model[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.state_dict()[key].size() + dst_patch_shape = model.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + + load_state_dict(model, checkpoint_model, prefix="") + + return model + + +def main(): + args = get_args() + + is_finetuned = "ft1k" in args.hf_checkpoint_name + is_large = "large" in args.hf_checkpoint_name + + if is_finetuned: + # To convert Beit's data2vec_vision to HF you need to copy + # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py + # into this folder. + import modeling_finetune # noqa: F401 + else: + # To convert Beit's data2vec_vision to HF you need to copy + # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py + # into this folder + # IMPORTANT: Note that for now we've only converted the down-stream + # model and not the full pretrained model. This means for the integration + # test you need to add a `return x` after the following line: + # https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197 + # to make the integration test pass. + import modeling_cyclical # noqa: F401 + + # 1. Create model config + config = Data2VecVisionConfig() + if is_finetuned: + config.use_relative_position_bias = True + config.use_shared_relative_position_bias = False + config.use_mean_pooling = True + config.num_labels = 1000 + + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + else: + config.use_relative_position_bias = False + config.use_shared_relative_position_bias = True + config.use_mean_pooling = False + + if is_large: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + + # 2. Load Beit model + orig_model = load_beit_model(args, is_finetuned, is_large) + orig_model.eval() + + # 3. Forward Beit model + feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False) + image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png") + encoding = feature_extractor(images=image, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + orig_args = (pixel_values,) if is_finetuned else (pixel_values, None) + with torch.no_grad(): + orig_model_output = orig_model(*orig_args) + + # 4. Load HF Data2VecVision model + if is_finetuned: + hf_model = Data2VecVisionForImageClassification(config) + hf_model.eval() + has_lm_head = False + hf_prefix = "data2vec_vision." + else: + hf_model = Data2VecVisionModel(config) + hf_model.eval() + has_lm_head = True + hf_prefix = "" + + rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head) + state_dict = orig_model.state_dict() + for src, dest in rename_keys: + val = state_dict.pop(src) + state_dict[dest] = val + + read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head) + missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) + print("HF missing", missing_keys) + print("HF unexpected_keys", unexpected_keys) + + # 5. Forward HF Data2VecVision model + with torch.no_grad(): + hf_model_output = hf_model(pixel_values) + + hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state + + # 6. Compare + max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item() + + print(f"max_absolute_diff = {max_absolute_diff}") + success = torch.allclose(hf_output, orig_model_output, atol=1e-3) + print("Do both models output the same tensors?", "🔥" if success else "💩") + if not success: + raise Exception("Something went wRoNg") + + # 7. Save + print(f"Saving to {args.hf_checkpoint_name}") + hf_model.save_pretrained(args.hf_checkpoint_name) + feature_extractor.save_pretrained(args.hf_checkpoint_name) + + +if __name__ == "__main__": + main() + # Run the following to convert checkpoints + # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \ + # --beit_checkpoint ./pretrained_base.pt \ + # --hf_checkpoint_name "./data2vec-vision-base" + # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \ + # --beit_checkpoint ./finetuned_base.pt \ + # --hf_checkpoint_name "./data2vec-vision-base-ft1k" + # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \ + # --beit_checkpoint ./pretrained_large.pt \ + # --hf_checkpoint_name "./data2vec-vision-large" + # python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \ + # --beit_checkpoint ./finetuned_large.pt \ + # --hf_checkpoint_name "./data2vec-vision-large-ft1k" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py new file mode 100644 index 00000000000000..3e3d4cc4f3418e --- /dev/null +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -0,0 +1,1212 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Data2VecVision model.""" + + +import collections.abc +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + SemanticSegmenterOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_data2vec_vision import Data2VecVisionConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Data2VecVisionConfig" +_FEAT_EXTRACTOR_FOR_DOC = "BeitFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k" +_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote" + +DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/data2vec-vision-base-ft1k", + # See all Data2VecVision models at https://huggingface.co/models?filter=data2vec-vision +] + + +@dataclass +# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision +class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling): + """ + Class for outputs of [`Data2VecVisionModel`]. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.DropPath +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision +class Data2VecVisionEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = PatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: + + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copied from transformers.models.beit.modeling_beit.PatchEmbeddings +class PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768 + ) -> None: + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + # FIXME look at relaxing size constraints + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + + return x + + +# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision +class Data2VecVisionSelfAttention(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + if window_size: + self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) + else: + self.relative_position_bias = None + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Add relative position bias if present. + if self.relative_position_bias is not None: + attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision +class Data2VecVisionSelfOutput(nn.Module): + """ + The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due + to the layernorm applied before each block. + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision +class Data2VecVisionAttention(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.attention = Data2VecVisionSelfAttention(config, window_size=window_size) + self.output = Data2VecVisionSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision +class Data2VecVisionIntermediate(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision +class Data2VecVisionOutput(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision +class Data2VecVisionLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__( + self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0 + ) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Data2VecVisionAttention(config, window_size=window_size) + self.intermediate = Data2VecVisionIntermediate(config) + self.output = Data2VecVisionOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + if init_values > 0: + self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True) + self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True) + else: + self.lambda_1, self.lambda_2 = None, None + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Data2VecVision, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision +class Data2VecVisionRelativePositionBias(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None: + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, config.num_attention_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 + ) # Wh*Ww,Wh*Ww,nH + + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision +class Data2VecVisionEncoder(nn.Module): + def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: + super().__init__() + self.config = config + if config.use_shared_relative_position_bias: + self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size) + else: + self.relative_position_bias = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + self.layer = nn.ModuleList( + [ + Data2VecVisionLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + ) + for i in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + relative_position_bias = ( + self.relative_position_bias() if self.relative_position_bias is not None else None + ) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision +class Data2VecVisionPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Data2VecVisionConfig + base_model_prefix = "data2vec_vision" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, Data2VecVisionEncoder): + module.gradient_checkpointing = value + + +DATA2VEC_VISION_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DATA2VEC_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`BeitFeatureExtractor`]. See + [`BeitFeatureExtractor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.", + DATA2VEC_VISION_START_DOCSTRING, +) +# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False +class Data2VecVisionModel(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None: + super().__init__(config) + self.config = config + + self.embeddings = Data2VecVisionEmbeddings(config) + self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape) + + self.layernorm = ( + nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Data2VecVisionModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return Data2VecVisionModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision +class Data2VecVisionPooler(nn.Module): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + self.layernorm = ( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(patch_tokens.mean(1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + +@add_start_docstrings( + """ + Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of + the final hidden states of the patch tokens) e.g. for ImageNet. + """, + DATA2VEC_VISION_START_DOCSTRING, +) +# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision +class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision +class Data2VecVisionConvModule(nn.Module): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + dilation=dilation, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision +class Data2VecVisionPyramidPoolingModule(nn.ModuleList): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + align_corners (bool): align_corners argument of F.interpolate. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + Data2VecVisionConvModule(self.in_channels, self.channels, kernel_size=1), + ) + ) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = nn.functional.interpolate( + ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners + ) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision +class Data2VecVisionUperHead(nn.Module): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://arxiv.org/abs/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__() + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] + self.channels = config.hidden_size + self.align_corners = False + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + # PSP Module + self.psp_modules = Data2VecVisionPyramidPoolingModule( + self.pool_scales, + self.in_channels[-1], + self.channels, + align_corners=self.align_corners, + ) + self.bottleneck = Data2VecVisionConvModule( + self.in_channels[-1] + len(self.pool_scales) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1) + fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = Data2VecVisionConvModule( + len(self.in_channels) * self.channels, + self.channels, + kernel_size=3, + padding=1, + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # build laterals + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate( + laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners + ) + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = nn.functional.interpolate( + fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision +class Data2VecVisionFCNHead(nn.Module): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented of + [FCNNet](https://arxiv.org/abs/1411.4038>). + + Args: + config (Data2VecVisionConfig): Configuration. + in_channels + kernel_size (int): The kernel size for convs in the head. Default: 3. + dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + config: Data2VecVisionConfig, + in_index: int = 2, + kernel_size: int = 3, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: + super().__init__() + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + Data2VecVisionConvModule( + self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + for i in range(self.num_convs - 1): + convs.append( + Data2VecVisionConvModule( + self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation + ) + ) + if self.num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = Data2VecVisionConvModule( + self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = self.convs(hidden_states) + if self.concat_input: + output = self.conv_cat(torch.cat([hidden_states, output], dim=1)) + output = self.classifier(output) + return output + + +@add_start_docstrings( + """ + Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + DATA2VEC_VISION_START_DOCSTRING, +) +# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision +class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False) + + # FPNs + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + nn.BatchNorm2d(config.hidden_size), + nn.GELU(), + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2), + ) + self.fpn3 = nn.Identity() + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + + # Semantic segmentation head(s) + self.decode_head = Data2VecVisionUperHead(config) + self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + main_loss = loss_fct(upsampled_logits, labels) + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import Data2VecVisionFeatureExtractor, Data2VecVisionForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = Data2VecVisionFeatureExtractor.from_pretrained("facebook/data2vec-vision-base") + >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] + batch_size = pixel_values.shape[0] + patch_resolution = self.config.image_size // self.config.patch_size + features = [ + x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features + ] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + logits = self.decode_head(features) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + else: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c4485d91bfd75c..898848d5ba16e4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1222,6 +1222,9 @@ def __init__(self, *args, **kwargs): DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None +DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST = None + + class Data2VecAudioForAudioFrameClassification(metaclass=DummyObject): _backends = ["torch"] @@ -1320,6 +1323,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Data2VecVisionForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Data2VecVisionForSemanticSegmentation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Data2VecVisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Data2VecVisionPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/data2vec/test_modeling_data2vec_vision.py b/tests/data2vec/test_modeling_data2vec_vision.py new file mode 100644 index 00000000000000..6005e9b379593b --- /dev/null +++ b/tests/data2vec/test_modeling_data2vec_vision.py @@ -0,0 +1,444 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Data2VecVision model. """ + + +import inspect +import unittest + +from transformers import Data2VecVisionConfig +from transformers.models.auto import get_values +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import ( + MODEL_MAPPING, + Data2VecVisionForImageClassification, + Data2VecVisionForSemanticSegmentation, + Data2VecVisionModel, + ) + from transformers.models.data2vec.modeling_data2vec_vision import ( + DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, + to_2tuple, + ) + + +if is_vision_available(): + from PIL import Image + + from transformers import BeitFeatureExtractor + + +class Data2VecVisionModelTester: + def __init__( + self, + parent, + vocab_size=100, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + out_indices=[0, 1, 2, 3], + ): + self.parent = parent + self.vocab_size = 100 + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + self.out_indices = out_indices + self.num_labels = num_labels + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + pixel_labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels, pixel_labels + + def get_config(self): + return Data2VecVisionConfig( + vocab_size=self.vocab_size, + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + out_indices=self.out_indices, + ) + + def create_and_check_model(self, config, pixel_values, labels, pixel_labels): + model = Data2VecVisionModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): + config.num_labels = self.type_sequence_label_size + model = Data2VecVisionForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels): + config.num_labels = self.num_labels + model = Data2VecVisionForSemanticSegmentation(config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2) + ) + result = model(pixel_values, labels=pixel_labels) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels, pixel_labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as Data2VecVision does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = ( + (Data2VecVisionModel, Data2VecVisionForImageClassification, Data2VecVisionForSemanticSegmentation) + if is_torch_available() + else () + ) + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = Data2VecVisionModelTester(self) + self.config_tester = ConfigTester( + self, config_class=Data2VecVisionConfig, has_text_modality=False, hidden_size=37 + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # Data2VecVision does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_segmentation(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs) + + def test_training(self): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in [*get_values(MODEL_MAPPING)]: + continue + + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in [*get_values(MODEL_MAPPING)] or not model_class.supports_gradient_checkpointing: + continue + # TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + # this can then be incorporated into _prepare_for_class in test_modeling_common.py + elif model_class.__name__ == "Data2VecVisionForSemanticSegmentation": + batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape + inputs_dict["labels"] = torch.zeros( + [self.model_tester.batch_size, height, width], device=torch_device + ).long() + model = model_class(config) + model.gradient_checkpointing_enable() + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # we skip lambda parameters as these require special initial values + # determined by config.layer_scale_init_value + if "lambda" in name: + continue + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + self.assertEqual(out_len + 1, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # Data2VecVision has a different seq_length + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 1 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = Data2VecVisionModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class Data2VecVisionModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + BeitFeatureExtractor.from_pretrained("facebook/data2vec-vision-base-ft1k") + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head_imagenet_1k(self): + model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base-ft1k").to( + torch_device + ) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(logits.shape, expected_shape) + + expected_slice = torch.tensor([0.3277, -0.1395, 0.0911]).to(torch_device) + + self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4)) + + expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]] + self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2) diff --git a/utils/check_repo.py b/utils/check_repo.py index 99af0935527497..d64d70c3a9943b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -186,6 +186,7 @@ [ ("data2vec-text", "data2vec"), ("data2vec-audio", "data2vec"), + ("data2vec-vision", "data2vec"), ] ) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 153404c8814128..077dd5f13af434 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -15,6 +15,7 @@ src/transformers/models/blenderbot/modeling_blenderbot.py src/transformers/models/blenderbot_small/modeling_blenderbot_small.py src/transformers/models/convnext/modeling_convnext.py src/transformers/models/data2vec/modeling_data2vec_audio.py +src/transformers/models/data2vec/modeling_data2vec_vision.py src/transformers/models/deit/modeling_deit.py src/transformers/models/dpt/modeling_dpt.py src/transformers/models/electra/modeling_electra.py