Add quickgelu vit clip variants, simplify get_norm_layer and allow string args in vit norm/act. Add metaclip CLIP weights
rwightman committed Nov 3, 2023
1 parent c55bc41 commit a2e4a4c
Showing 4 changed files with 163 additions and 42 deletions.
14 changes: 14 additions & 0 deletions timm/layers/activations.py
@@ -157,3 +157,17 @@ def __init__(self, inplace: bool = False):

def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input, approximate='tanh')


def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)


class QuickGELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(QuickGELU, self).__init__()

def forward(self, input: torch.Tensor) -> torch.Tensor:
return quick_gelu(input)
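
Quick GELU is the sigmoid-based approximation used by the original OpenAI CLIP checkpoints, so reproducing it matters when loading those weights. A minimal sketch comparing it against exact GELU (values are illustrative only):

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=101)
quick = x * torch.sigmoid(1.702 * x)   # same formula as quick_gelu() above
exact = F.gelu(x)
print((quick - exact).abs().max())     # small but non-zero gap, largest near |x| ~ 2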
6 changes: 4 additions & 2 deletions timm/layers/create_act.py
@@ -29,6 +29,7 @@
selu=F.selu,
gelu=gelu,
gelu_tanh=gelu_tanh,
quick_gelu=quick_gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
@@ -42,7 +43,7 @@
mish=F.mish if _has_mish else mish_jit,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
hard_mish=hard_mish_jit
hard_mish=hard_mish_jit,
)

_ACT_FN_ME = dict(
@@ -73,6 +74,7 @@
selu=nn.SELU,
gelu=GELU,
gelu_tanh=GELUTanh,
quick_gelu=QuickGELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
@@ -87,7 +89,7 @@
mish=nn.Mish if _has_mish else MishJit,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
hard_mish=HardMishJit
hard_mish=HardMishJit,
)

_ACT_LAYER_ME = dict(
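
With quick_gelu registered in both the functional and module maps, it should be reachable by name through the existing activation factory. A minimal sketch, assuming the usual timm.layers entry points:

from timm.layers import get_act_layer, create_act_layer

act_cls = get_act_layer('quick_gelu')    # resolves the name to the QuickGELU module class
act = create_act_layer('quick_gelu')     # instantiates it directly
print(act_cls.__name__, act)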
24 changes: 12 additions & 12 deletions timm/layers/create_norm.py
@@ -4,12 +4,14 @@
Copyright 2022 Ross Wightman
"""
import types
import functools
import types
from typing import Type

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from torchvision.ops import FrozenBatchNorm2d

_NORM_MAP = dict(
batchnorm=nn.BatchNorm2d,
@@ -19,6 +21,8 @@
groupnorm1=GroupNorm1,
layernorm=LayerNorm,
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}

@@ -30,7 +34,10 @@ def create_norm_layer(layer_name, num_features, **kwargs):


def get_norm_layer(norm_layer):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
if not norm_layer:
# None or '' should return None
return None
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
norm_kwargs = {}

# unbind partial fn, so args can be rebound later
@@ -40,16 +47,9 @@ def get_norm_layer(norm_layer):

if isinstance(norm_layer, str):
layer_name = norm_layer.replace('_', '')
norm_layer = _NORM_MAP.get(layer_name, None)
elif norm_layer in _NORM_TYPES:
norm_layer = norm_layer
elif isinstance(norm_layer, types.FunctionType):
# if function type, assume it is a lambda/fn that creates a norm layer
norm_layer = norm_layer
norm_layer = _NORM_MAP[layer_name]
else:
type_name = norm_layer.__name__.lower().replace('_', '')
norm_layer = _NORM_MAP.get(type_name, None)
assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
norm_layer = norm_layer

if norm_kwargs:
norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
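
The simplified get_norm_layer now returns None for falsy input, resolves strings through _NORM_MAP (including the new rmsnorm and frozenbatchnorm2d entries), and still rebinds partial kwargs. A small sketch of the expected behavior, assuming the public timm.layers export:

import functools
import torch.nn as nn
from timm.layers import get_norm_layer

assert get_norm_layer(None) is None                              # None / '' short-circuit to None
ln2d = get_norm_layer('layer_norm_2d')                           # underscores stripped -> LayerNorm2d
rms = get_norm_layer('rmsnorm')                                  # new RmsNorm entry
ln = get_norm_layer(functools.partial(nn.LayerNorm, eps=1e-6))   # partial kwargs unbound then rebound
print(ln2d, rms, ln)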
161 changes: 133 additions & 28 deletions timm/models/vision_transformer.py
@@ -27,18 +27,20 @@
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final


from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -414,10 +416,10 @@ def __init__(
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None,
block_fn: Callable = Block,
mlp_layer: Callable = Mlp,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
):
"""
Args:
@@ -450,8 +452,8 @@ def __init__(
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU

self.num_classes = num_classes
self.global_pool = global_pool
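
Because get_norm_layer and get_act_layer return None for falsy input, the constructor keeps its old defaults while also accepting strings; a standalone sketch of the two changed lines above:

from functools import partial
import torch.nn as nn
from timm.layers import get_act_layer, get_norm_layer

norm_layer = get_norm_layer(None) or partial(nn.LayerNorm, eps=1e-6)  # None -> previous default
act_layer = get_act_layer('quick_gelu') or nn.GELU                    # string resolves via the act factory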
@@ -1415,46 +1417,75 @@ def _cfg(url='', **kwargs):
hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_giant_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),

'vit_base_patch32_clip_224.datacompxl': _cfg(
hf_hub_id='laion/',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch32_clip_256.datacompxl': _cfg(
hf_hub_id='laion/',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
'vit_base_patch16_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),

'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_224.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_378.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),

'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_giant_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),

'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/',
@@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/32 CLIP image tower @ 224x224
"""
model_args = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/16 CLIP image tower w/ QuickGELU act
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
"""
model_args = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
"""
model_args = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
return model


# Experimental models below

@register_model
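
Taken together, the quickgelu entrypoints above can be built by name, and plain string norm/act args now pass through create_model into VisionTransformer. A small usage sketch (model names taken from this diff; no pretrained weights are assumed here):

import timm

# QuickGELU CLIP image tower registered in this commit
model = timm.create_model('vit_base_patch16_clip_quickgelu_224', pretrained=False, num_classes=0)

# String norm/act args resolve inside VisionTransformer.__init__
model2 = timm.create_model('vit_base_patch16_224', pretrained=False, act_layer='quick_gelu')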
