From dcfdba1f5f61c23213f8e93b95b774b7407bcdf4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 3 Nov 2023 08:07:52 -0700 Subject: [PATCH] Make quickgelu models appear in listing --- timm/models/vision_transformer.py | 50 ++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 3f08b98f72..ab4b5084f1 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -35,7 +35,6 @@ import torch.utils.checkpoint from torch.jit import Final - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ @@ -1043,7 +1042,7 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_default_cfgs({ +default_cfgs = { # re-finetuned augreg 21k FT on in1k weights 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( @@ -1459,49 +1458,60 @@ def _cfg(url='', **kwargs): 'vit_large_patch14_clip_224.dfn2b': _cfg( hf_hub_id='apple/DFN2B-CLIP-ViT-L-14', hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_huge_patch14_clip_224.dfn5b': _cfg( hf_hub_id='apple/DFN5B-CLIP-ViT-H-14', hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_huge_patch14_clip_378.dfn5b': _cfg( hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + notes=('natively QuickGELU, use quickgelu model variant for original results',), crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024), 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg( hf_hub_id='facebook/metaclip-b32-fullcc2.5b', hf_hub_filename='metaclip_b32_fullcc2.5b.bin', license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg( hf_hub_id='facebook/metaclip-b16-fullcc2.5b', hf_hub_filename='metaclip_b16_fullcc2.5b.bin', license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg( hf_hub_id='facebook/metaclip-l14-fullcc2.5b', hf_hub_filename='metaclip_l14_fullcc2.5b.bin', license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg( hf_hub_id='facebook/metaclip-h14-fullcc2.5b', hf_hub_filename='metaclip_h14_fullcc2.5b.bin', license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_base_patch32_clip_224.openai': _cfg( hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.openai': _cfg( hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_large_patch14_clip_224.openai': _cfg( hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_large_patch14_clip_336.openai': _cfg( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', + notes=('natively QuickGELU, use quickgelu model variant for original results',), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), num_classes=768), @@ -1677,7 +1687,25 @@ def _cfg(url='', **kwargs): 'vit_medium_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), 'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)), +} + +_quick_gelu_cfgs = [ + 'vit_large_patch14_clip_224.dfn2b', + 'vit_huge_patch14_clip_224.dfn5b', + 'vit_huge_patch14_clip_378.dfn5b', + 'vit_base_patch32_clip_224.metaclip_2pt5b', + 'vit_base_patch16_clip_224.metaclip_2pt5b', + 'vit_large_patch14_clip_224.metaclip_2pt5b', + 'vit_huge_patch14_clip_224.metaclip_2pt5b', + 'vit_base_patch32_clip_224.openai', + 'vit_base_patch16_clip_224.openai', + 'vit_large_patch14_clip_224.openai', + 'vit_large_patch14_clip_336.openai', +] +default_cfgs.update({ + n.replace('_clip_', '_clip_quickgelu_'): default_cfgs[n] for n in _quick_gelu_cfgs }) +default_cfgs = generate_default_cfgs(default_cfgs) def _create_vision_transformer(variant, pretrained=False, **kwargs): @@ -2133,8 +2161,7 @@ def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_base_patch32_clip_224', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2146,8 +2173,7 @@ def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_base_patch16_clip_224', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2160,8 +2186,7 @@ def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTr patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_large_patch14_clip_224', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2173,8 +2198,7 @@ def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTr patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_large_patch14_clip_336', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2186,8 +2210,7 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_huge_patch14_clip_224', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2199,8 +2222,7 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTra patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer='quick_gelu') model = _create_vision_transformer( - 'vit_huge_patch14_clip_378', # map to non quickgelu pretrained_cfg intentionally - pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs)) return model