Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve vision models #17731

Merged
merged 16 commits into from
Jun 24, 2022
62 changes: 26 additions & 36 deletions src/transformers/models/beit/modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,7 @@ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
"""


# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)


# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Expand All @@ -112,16 +102,16 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
argument.
"""
if drop_prob == 0.0 or not training:
return x
return input
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
output = input.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
class BeitDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob: Optional[float] = None) -> None:
Expand Down Expand Up @@ -151,12 +141,7 @@ def __init__(self, config: BeitConfig) -> None:
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
else:
self.mask_token = None
self.patch_embeddings = PatchEmbeddings(
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.hidden_size,
)
self.patch_embeddings = BeitPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
Expand Down Expand Up @@ -184,38 +169,43 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo
return embeddings


# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
class BeitPatchEmbeddings(nn.Module):
"""
Image to Patch Embedding.
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""

def __init__(
self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
) -> None:
def __init__(self, config):
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size

image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches
self.patch_shape = patch_shape

self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
# FIXME look at relaxing size constraints
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)

return x
return embeddings


class BeitSelfAttention(nn.Module):
Expand Down Expand Up @@ -393,7 +383,7 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop
self.intermediate = BeitIntermediate(config)
self.output = BeitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

init_values = config.layer_scale_init_value
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/beit/modeling_flax_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class FlaxBeitPatchEmbeddings(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation

def setup(self):
self.num_channels = self.config.num_channels
image_size = self.config.image_size
patch_size = self.config.patch_size
num_patches = (image_size // patch_size) * (image_size // patch_size)
Expand All @@ -187,6 +188,11 @@ def setup(self):
)

def __call__(self, pixel_values):
num_channels = pixel_values.shape[-1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.projection(pixel_values)
batch_size, _, _, channels = embeddings.shape
return jnp.reshape(embeddings, (batch_size, -1, channels))
Expand Down Expand Up @@ -603,7 +609,7 @@ def __init__(
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
Expand Down
37 changes: 24 additions & 13 deletions src/transformers/models/convnext/modeling_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,36 +53,41 @@
]


# Stochastic depth implementation
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
Connect' is a different form of dropout in a separate paper... See discussion:
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return x
return input
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
output = input.div(keep_prob) * random_tensor
return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
class ConvNextDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob=None):
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob

def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)


class ConvNextLayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
Expand Down Expand Up @@ -122,8 +127,14 @@ def __init__(self, config):
config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
)
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
self.num_channels = config.num_channels

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
num_channels = pixel_values.shape[1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embeddings = self.patch_embeddings(pixel_values)
embeddings = self.layernorm(embeddings)
return embeddings
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/models/convnext/modeling_tf_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import numpy as np
import tensorflow as tf

from transformers import shape_list

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
Expand Down Expand Up @@ -77,11 +79,18 @@ def __init__(self, config, **kwargs):
bias_initializer="zeros",
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
self.num_channels = config.num_channels

def call(self, pixel_values):
if isinstance(pixel_values, dict):
pixel_values = pixel_values["pixel_values"]

num_channels = shape_list(pixel_values)[1]
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)

# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
Expand Down
31 changes: 18 additions & 13 deletions src/transformers/models/cvt/modeling_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,41 @@ class BaseModelOutputWithCLSToken(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.convnext.modeling_convnext.drop_path
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input, drop_prob: float = 0.0, training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
Connect' is a different form of dropout in a separate paper... See discussion:
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return x
return input
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
output = input.div(keep_prob) * random_tensor
return output


# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
class CvtDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob=None):
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob

def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)


class CvtEmbeddings(nn.Module):
"""
Expand Down
Loading