diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index d88f26a3089bfd..d699af419aa258 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -18,6 +18,7 @@ import collections.abc import math from dataclasses import dataclass +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -99,7 +100,7 @@ def to_2tuple(x): # Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py -def drop_path(x, drop_prob: float = 0.0, training: bool = False): +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). @@ -122,11 +123,11 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: @@ -141,7 +142,7 @@ class BeitEmbeddings(nn.Module): """ - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) @@ -162,7 +163,7 @@ def __init__(self, config): self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values, bool_masked_pos=None): + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: embeddings = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() @@ -189,7 +190,9 @@ class PatchEmbeddings(nn.Module): Image to Patch Embedding. 
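For reference, the stochastic-depth operation that `drop_path` implements (and that the annotations above now document as `Tensor -> Tensor`) can be sketched standalone as follows; this is an illustrative re-statement, not the library function itself:

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    # Identity when disabled or at inference time.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # Rescale the kept samples so the expected value matches the input.
    return x.div(keep_prob) * random_tensor


x = torch.ones(4, 3, 8)
out = drop_path_sketch(x, drop_prob=0.5, training=True)
print(out.shape)  # torch.Size([4, 3, 8]); roughly half the samples are zeroed out
```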
""" - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__( + self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768 + ) -> None: super().__init__() image_size = to_2tuple(image_size) patch_size = to_2tuple(patch_size) @@ -202,7 +205,7 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values): + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape # FIXME look at relaxing size constraints if height != self.image_size[0] or width != self.image_size[1]: @@ -215,7 +218,7 @@ def forward(self, pixel_values): class BeitSelfAttention(nn.Module): - def __init__(self, config, window_size=None): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -243,7 +246,13 @@ def transpose_for_scores(self, x): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -291,12 +300,12 @@ class BeitSelfOutput(nn.Module): layernorm applied before each block. 
""" - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor, gamma=None): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -304,7 +313,7 @@ def forward(self, hidden_states, input_tensor, gamma=None): class BeitAttention(nn.Module): - def __init__(self, config, window_size=None): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: super().__init__() self.attention = BeitSelfAttention(config, window_size=window_size) self.output = BeitSelfOutput(config) @@ -328,7 +337,13 @@ def prune_heads(self, heads): self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) attention_output = self.output(self_outputs[0], hidden_states) @@ -338,7 +353,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False, relati class BeitIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -346,7 +361,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -354,12 +369,12 @@ def forward(self, hidden_states): class BeitOutput(nn.Module): - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -369,7 +384,7 @@ def forward(self, hidden_states): class BeitLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" - def __init__(self, config, window_size=None, drop_path_rate=0.0): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 @@ -387,7 +402,13 @@ def __init__(self, config, window_size=None, drop_path_rate=0.0): else: self.lambda_1, self.lambda_2 = None, None - def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + 
relative_position_bias: Optional["BeitRelativePositionBias"] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention head_mask, @@ -422,7 +443,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False, relati class BeitRelativePositionBias(nn.Module): - def __init__(self, config, window_size): + def __init__(self, config: BeitConfig, window_size: tuple) -> None: super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 @@ -451,7 +472,7 @@ def __init__(self, config, window_size): self.register_buffer("relative_position_index", relative_position_index) - def forward(self): + def forward(self) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 ) # Wh*Ww,Wh*Ww,nH @@ -460,7 +481,7 @@ def forward(self): class BeitEncoder(nn.Module): - def __init__(self, config, window_size=None): + def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: super().__init__() self.config = config if config.use_shared_relative_position_bias: @@ -484,12 +505,12 @@ def __init__(self, config, window_size=None): def forward( self, - hidden_states, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -606,7 +627,7 @@ def _set_gradient_checkpointing(self, module, value=False): BEIT_START_DOCSTRING, ) class BeitModel(BeitPreTrainedModel): - def __init__(self, config, add_pooling_layer=True): + def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None: super().__init__(config) self.config = config @@ -643,13 +664,13 @@ class PreTrainedModel ) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BeitModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -691,13 +712,13 @@ def forward( class BeitPooler(nn.Module): - def __init__(self, config): + def __init__(self, config: BeitModel) -> None: super().__init__() self.layernorm = ( nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None ) - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.layernorm is not None: # Mean pool the final hidden states of the patch tokens patch_tokens = hidden_states[:, 1:, :] @@ -714,7 +735,7 @@ def forward(self, hidden_states): 
BEIT_START_DOCSTRING, ) class BeitForMaskedImageModeling(BeitPreTrainedModel): - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -731,14 +752,14 @@ def __init__(self, config): @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -814,7 +835,7 @@ def forward( BEIT_START_DOCSTRING, ) class BeitForImageClassification(BeitPreTrainedModel): - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -836,13 +857,13 @@ def __init__(self, config): ) def forward( self, - pixel_values=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., @@ -904,7 +925,7 @@ class BeitConvModule(nn.Module): Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. """ - def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int], str] = 0, + bias: bool = False, + dilation: Union[int, Tuple[int, int]] = 1, + ) -> None: super().__init__() self.conv = nn.Conv2d( in_channels=in_channels, @@ -917,7 +946,7 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False self.bn = nn.BatchNorm2d(out_channels) self.activation = nn.ReLU() - def forward(self, input): + def forward(self, input: torch.Tensor) -> torch.Tensor: output = self.conv(input) output = self.bn(output) output = self.activation(output) @@ -939,7 +968,7 @@ class BeitPyramidPoolingModule(nn.ModuleList): Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
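For context, `BeitConvModule` is just a Conv -> BatchNorm -> ReLU block; a condensed, illustrative equivalent (not the library class) written in the same fully-typed style:

```python
import torch
from torch import nn


class ConvModuleSketch(nn.Module):
    """Conv -> BatchNorm -> ReLU, the building block the segmentation heads stack."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, dilation: int = 1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.activation(self.bn(self.conv(input)))


block = ConvModuleSketch(768, 256, kernel_size=3, padding=1)
print(block(torch.randn(1, 768, 14, 14)).shape)  # (1, 256, 14, 14)
```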
""" - def __init__(self, pool_scales, in_channels, channels, align_corners): + def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None: super().__init__() self.pool_scales = pool_scales self.align_corners = align_corners @@ -953,7 +982,7 @@ def __init__(self, pool_scales, in_channels, channels, align_corners): ) ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: ppm_outs = [] for ppm in self: ppm_out = ppm(x) @@ -972,7 +1001,7 @@ class BeitUperHead(nn.Module): Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. """ - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__() self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) @@ -1019,7 +1048,7 @@ def psp_forward(self, inputs): return output - def forward(self, encoder_hidden_states): + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: # build laterals laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] @@ -1064,7 +1093,9 @@ class BeitFCNHead(nn.Module): Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. """ - def __init__(self, config, in_index=2, kernel_size=3, dilation=1): + def __init__( + self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1 + ) -> None: super().__init__() self.in_channels = config.hidden_size self.channels = config.auxiliary_channels @@ -1096,7 +1127,7 @@ def __init__(self, config, in_index=2, kernel_size=3, dilation=1): self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1) - def forward(self, encoder_hidden_states): + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: # just take the relevant feature maps hidden_states = encoder_hidden_states[self.in_index] output = self.convs(hidden_states) @@ -1113,7 +1144,7 @@ def forward(self, encoder_hidden_states): BEIT_START_DOCSTRING, ) class BeitForSemanticSegmentation(BeitPreTrainedModel): - def __init__(self, config): + def __init__(self, config: BeitConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -1160,13 +1191,13 @@ def compute_loss(self, logits, auxiliary_logits, labels): @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SemanticSegmentationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): Ground truth semantic segmentation maps for computing the loss. 
Indices should be in `[0, ..., diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 9696db6a8776ef..91b673ff6f5c6d 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -18,7 +18,7 @@ import collections.abc import math from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -77,7 +77,7 @@ class DeiTEmbeddings(nn.Module): """ - def __init__(self, config, use_mask_token=False): + def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None: super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) @@ -93,7 +93,7 @@ def __init__(self, config, use_mask_token=False): self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, pixel_values, bool_masked_pos=None): + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: embeddings = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() @@ -117,7 +117,13 @@ class PatchEmbeddings(nn.Module): """ - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__( + self, + image_size: int = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + num_channels: int = 3, + embed_dim: int = 768, + ) -> None: super().__init__() image_size = to_2tuple(image_size) patch_size = to_2tuple(patch_size) @@ -128,7 +134,7 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values): + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape # FIXME look at relaxing size constraints if height != self.image_size[0] or width != self.image_size[1]: @@ -141,7 +147,7 @@ def forward(self, pixel_values): # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT class DeiTSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -159,12 +165,14 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -205,12 +213,12 @@ class DeiTSelfOutput(nn.Module): layernorm applied before each block. 
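The `bool_masked_pos: Optional[torch.BoolTensor]` argument that the embedding `forward` methods now annotate drives the mask-token substitution used for masked image modeling; a small standalone sketch of that substitution:

```python
import torch

batch_size, seq_len, hidden_size = 2, 196, 768
embeddings = torch.randn(batch_size, seq_len, hidden_size)
mask_token = torch.zeros(1, 1, hidden_size)

# `bool_masked_pos` marks which patches are masked (True) for masked image modeling.
bool_masked_pos = torch.zeros(batch_size, seq_len, dtype=torch.bool)
bool_masked_pos[:, :75] = True

mask_tokens = mask_token.expand(batch_size, seq_len, -1)
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
# Masked positions take the mask token, the rest keep their patch embedding.
embeddings = embeddings * (1 - w) + mask_tokens * w
print(embeddings[0, 0].abs().sum(), embeddings[0, 100].abs().sum())  # first is 0 (masked), second is not
```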
""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -220,13 +228,13 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT class DeiTAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.attention = DeiTSelfAttention(config) self.output = DeiTSelfOutput(config) self.pruned_heads = set() - def prune_heads(self, heads): + def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( @@ -244,7 +252,12 @@ def prune_heads(self, heads): self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) @@ -255,7 +268,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT class DeiTIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -263,7 +276,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -273,12 +286,12 @@ def forward(self, hidden_states): # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT class DeiTOutput(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -291,7 +304,7 @@ def forward(self, hidden_states, input_tensor): class DeiTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 @@ -301,7 +314,12 @@ def __init__(self, config): self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + 
hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in DeiT, layernorm is applied before self-attention head_mask, @@ -327,7 +345,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT class DeiTEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)]) @@ -335,12 +353,12 @@ def __init__(self, config): def forward( self, - hidden_states, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -395,7 +413,7 @@ class DeiTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True - def _init_weights(self, module): + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization @@ -407,7 +425,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: if isinstance(module, DeiTEncoder): module.gradient_checkpointing = value @@ -451,7 +469,7 @@ def _set_gradient_checkpointing(self, module, value=False): DEIT_START_DOCSTRING, ) class DeiTModel(DeiTPreTrainedModel): - def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None: super().__init__(config) self.config = config @@ -464,7 +482,7 @@ def __init__(self, config, add_pooling_layer=True, use_mask_token=False): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): + def get_input_embeddings(self) -> PatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune): @@ -486,12 +504,12 @@ class PreTrainedModel ) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -554,7 +572,7 @@ def forward(self, hidden_states): DEIT_START_DOCSTRING, ) class DeiTForMaskedImageModeling(DeiTPreTrainedModel): - def __init__(self, config): + def __init__(self, config: DeiTConfig) -> None: super().__init__(config) self.deit = 
DeiTModel(config, add_pooling_layer=False, use_mask_token=True) @@ -571,13 +589,13 @@ def __init__(self, config): @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -662,7 +680,7 @@ def forward( DEIT_START_DOCSTRING, ) class DeiTForImageClassification(DeiTPreTrainedModel): - def __init__(self, config): + def __init__(self, config: DeiTConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -678,13 +696,13 @@ def __init__(self, config): @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., @@ -811,7 +829,7 @@ class token). DEIT_START_DOCSTRING, ) class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel): - def __init__(self, config): + def __init__(self, config: DeiTConfig) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -838,12 +856,12 @@ def __init__(self, config): ) def forward( self, - pixel_values=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.deit( diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index b96846574a1688..a885937e2aca52 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -389,12 +389,12 @@ class ViltSelfOutput(nn.Module): layernorm applied before each block. 
""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -438,7 +438,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_att # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt class ViltIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -446,7 +446,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -456,12 +456,12 @@ def forward(self, hidden_states): # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt class ViltOutput(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 6422755e62b12a..8c89cf9cac53ca 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -17,6 +17,7 @@ import collections.abc import math +from typing import Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -76,7 +77,7 @@ class ViTEmbeddings(nn.Module): """ - def __init__(self, config, use_mask_token=False): + def __init__(self, config, use_mask_token: bool = False) -> None: super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) @@ -92,7 +93,7 @@ def __init__(self, config, use_mask_token=False): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.config = config - def interpolate_pos_encoding(self, embeddings, height, width): + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. 
@@ -123,7 +124,12 @@ def interpolate_pos_encoding(self, embeddings, height, width): patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - def forward(self, pixel_values, bool_masked_pos=None, interpolate_pos_encoding=False): + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) @@ -157,7 +163,13 @@ class PatchEmbeddings(nn.Module): """ - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__( + self, + image_size: int = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + num_channels: int = 3, + embed_dim: int = 768, + ): super().__init__() image_size = to_2tuple(image_size) patch_size = to_2tuple(patch_size) @@ -168,7 +180,7 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) - def forward(self, pixel_values, interpolate_pos_encoding=False): + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if not interpolate_pos_encoding: if height != self.image_size[0] or width != self.image_size[1]: @@ -180,7 +192,7 @@ def forward(self, pixel_values, interpolate_pos_encoding=False): class ViTSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -198,12 +210,14 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -243,12 +257,12 @@ class ViTSelfOutput(nn.Module): layernorm applied before each block. 
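The `transpose_for_scores` helper (now typed `Tensor -> Tensor`) only reshapes; a quick shape check with the usual ViT dimensions:

```python
import torch

batch_size, seq_len, hidden_size = 2, 197, 768
num_attention_heads = 12
attention_head_size = hidden_size // num_attention_heads

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# Split the hidden dimension into heads and move the head axis before the sequence axis.
x = hidden_states.view(batch_size, seq_len, num_attention_heads, attention_head_size)
x = x.permute(0, 2, 1, 3)
print(x.shape)  # (2, 12, 197, 64) -- the layout expected by the attention-score matmul
```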
""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -257,13 +271,13 @@ def forward(self, hidden_states, input_tensor): class ViTAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.attention = ViTSelfAttention(config) self.output = ViTSelfOutput(config) self.pruned_heads = set() - def prune_heads(self, heads): + def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( @@ -281,7 +295,12 @@ def prune_heads(self, heads): self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) @@ -291,7 +310,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): class ViTIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -299,7 +318,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -308,12 +327,12 @@ def forward(self, hidden_states): class ViTOutput(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -325,7 +344,7 @@ def forward(self, hidden_states, input_tensor): class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 @@ -335,7 +354,12 @@ def __init__(self, config): self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_attention_outputs = self.attention( 
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention head_mask, @@ -360,7 +384,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): class ViTEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) @@ -368,12 +392,12 @@ def __init__(self, config): def forward( self, - hidden_states, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -427,7 +451,7 @@ class ViTPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True - def _init_weights(self, module): + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Slightly different from the TF version which uses truncated_normal for initialization @@ -439,7 +463,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: if isinstance(module, ViTEncoder): module.gradient_checkpointing = value @@ -485,7 +509,7 @@ def _set_gradient_checkpointing(self, module, value=False): VIT_START_DOCSTRING, ) class ViTModel(ViTPreTrainedModel): - def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): super().__init__(config) self.config = config @@ -498,10 +522,10 @@ def __init__(self, config, add_pooling_layer=True, use_mask_token=False): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self): + def get_input_embeddings(self) -> PatchEmbeddings: return self.embeddings.patch_embeddings - def _prune_heads(self, heads_to_prune): + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: """ Prunes heads of the model. 
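The `heads_to_prune` mapping now annotated as `Dict[int, List[int]]` is keyed by layer index. A hedged usage sketch via the public `prune_heads` wrapper, assuming a randomly initialized `ViTModel` (no checkpoint download):

```python
from transformers import ViTConfig, ViTModel

model = ViTModel(ViTConfig())            # randomly initialized, default 12 layers / 12 heads
# Layer index -> list of attention heads to remove in that layer.
heads_to_prune = {0: [0, 1], 2: [2, 3]}
model.prune_heads(heads_to_prune)        # public wrapper that dispatches to _prune_heads

print(model.encoder.layer[0].attention.attention.num_attention_heads)  # 12 - 2 == 10
print(model.encoder.layer[1].attention.attention.num_attention_heads)  # untouched: 12
```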
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel @@ -520,13 +544,13 @@ class PreTrainedModel ) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - interpolate_pos_encoding=None, - return_dict=None, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -590,7 +614,7 @@ def forward(self, hidden_states): VIT_START_DOCSTRING, ) class ViTForMaskedImageModeling(ViTPreTrainedModel): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__(config) self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True) @@ -607,14 +631,14 @@ def __init__(self, config): @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values=None, - bool_masked_pos=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - interpolate_pos_encoding=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedLMOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -700,7 +724,7 @@ def forward( VIT_START_DOCSTRING, ) class ViTForImageClassification(ViTPreTrainedModel): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__(config) self.num_labels = config.num_labels @@ -722,14 +746,14 @@ def __init__(self, config): ) def forward( self, - pixel_values=None, - head_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - interpolate_pos_encoding=None, - return_dict=None, - ): + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. 
Indices should be in `[0, ..., diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 61c11fb6cda4c8..8b3d0f90dc5f7d 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -19,7 +19,7 @@ import math from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Set, Tuple, Union import numpy as np import torch @@ -318,7 +318,7 @@ def forward(self, pixel_values): # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention class ViTMAESelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -336,12 +336,14 @@ def __init__(self, config): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -382,12 +384,12 @@ class ViTMAESelfOutput(nn.Module): layernorm applied before each block. """ - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -397,13 +399,13 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE class ViTMAEAttention(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.attention = ViTMAESelfAttention(config) self.output = ViTMAESelfOutput(config) self.pruned_heads = set() - def prune_heads(self, heads): + def prune_heads(self, heads: Set[int]) -> None: if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( @@ -421,7 +423,12 @@ def prune_heads(self, heads): self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_outputs = self.attention(hidden_states, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) @@ -432,7 +439,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # Copied from transformers.models.vit.modeling_vit.ViTIntermediate class ViTMAEIntermediate(nn.Module): - def __init__(self, config): + def __init__(self, config) 
-> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -440,7 +447,7 @@ def __init__(self, config): else: self.intermediate_act_fn = config.hidden_act - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -450,12 +457,12 @@ def forward(self, hidden_states): # Copied from transformers.models.vit.modeling_vit.ViTOutput class ViTMAEOutput(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) - def forward(self, hidden_states, input_tensor): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) @@ -468,7 +475,7 @@ def forward(self, hidden_states, input_tensor): class ViTMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 @@ -478,7 +485,12 @@ def __init__(self, config): self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, hidden_states, head_mask=None, output_attentions=False): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention head_mask, @@ -504,7 +516,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE class ViTMAEEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)]) @@ -512,12 +524,12 @@ def __init__(self, config): def forward( self, - hidden_states, - head_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None
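The `Union[tuple, BaseModelOutput]`-style return annotations added throughout reflect the `return_dict` switch; a small smoke test with a deliberately tiny, randomly initialized config (the keyword arguments used here are assumed to be the standard `ViTConfig` ones):

```python
import torch
from transformers import ViTConfig, ViTModel

config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = ViTModel(config)
pixel_values = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    as_output = model(pixel_values, return_dict=True)    # a ModelOutput dataclass
    as_tuple = model(pixel_values, return_dict=False)    # a plain tuple with the same tensors

print(type(as_output).__name__, as_output.last_hidden_state.shape)  # (1, 17, 32): 16 patches + [CLS]
print(type(as_tuple).__name__, as_tuple[0].shape)
```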