This repository has been archived by the owner on Mar 15, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 557
/
resmlp_models.py
197 lines (154 loc) · 6.39 KB
/
resmlp_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed , _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
# Public model constructors exposed by ``from <module> import *``.
__all__ = [
    "resmlp_12",
    "resmlp_24",
    "resmlp_36",
    "resmlpB_24",
]
class Affine(nn.Module):
    """Learnable per-channel affine map ``y = alpha * x + beta``.

    ``alpha`` is initialised to ones and ``beta`` to zeros, so a freshly built
    instance is the identity map. Both parameters are created with
    ``torch.ones(dim)`` / ``torch.zeros(dim)`` and broadcast against ``x``
    under the usual PyTorch broadcasting rules (i.e. over its trailing axis).

    Args:
        dim: number of channels, i.e. the size of the last axis of ``x``.
    """

    def __init__(self, dim):
        super().__init__()
        scale = torch.ones(dim)
        shift = torch.zeros(dim)
        # Registration order (alpha, then beta) fixes the state_dict key order.
        self.alpha = nn.Parameter(scale)
        self.beta = nn.Parameter(shift)

    def forward(self, x):
        scaled = self.alpha * x
        return scaled + self.beta
class layers_scale_mlp_blocks(nn.Module):
    """One ResMLP residual block: a linear token-mixing step, then a channel MLP.

    Each branch is pre-normalised by an ``Affine`` layer, multiplied by a
    learnable per-channel vector (``gamma_1`` / ``gamma_2``, both initialised
    to ``init_values``), passed through ``drop_path`` and added back to the
    input.

    Args:
        dim: channel width (last axis of the input).
        drop: dropout probability forwarded to ``Mlp``.
        drop_path: stochastic-depth probability; values ``<= 0`` use ``nn.Identity``.
        act_layer: activation class forwarded to ``Mlp``.
        init_values: initial value of both layer-scale vectors.
        num_patches: size of the axis mixed by ``attn``.
    """

    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, init_values=1e-4, num_patches=196):
        super().__init__()
        # Attribute names are the state_dict keys loaded by the factory
        # functions; keep names and registration order stable.
        self.norm1 = Affine(dim)
        # Despite its name, ``attn`` is a plain Linear layer of size num_patches.
        self.attn = nn.Linear(num_patches, num_patches)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = Affine(dim)
        hidden_width = int(4.0 * dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        # ``attn`` runs on x with axes 1 and 2 swapped, so it mixes along axis 1
        # (which must therefore have length num_patches); the gammas and the MLP
        # act on the last axis (length dim).
        token_mixed = self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.drop_path(self.gamma_1 * token_mixed)
        channel_mixed = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma_2 * channel_mixed)
        return x
class resmlp_models(nn.Module):
    """ResMLP classifier.

    Pipeline: ``Patch_layer`` embedding -> ``depth`` x ``layers_scale_mlp_blocks``
    -> ``Affine`` norm -> mean over the patch axis -> linear ``head``.

    Args:
        img_size: image size forwarded to ``Patch_layer``.
        patch_size: patch size forwarded to ``Patch_layer``.
        in_chans: input channels, cast to ``int`` and forwarded to ``Patch_layer``.
        num_classes: output size of ``head``; ``<= 0`` makes ``head`` an
            ``nn.Identity`` so ``forward`` returns the pooled features.
        embed_dim: channel width of every block, the final norm and the head.
        depth: number of residual blocks.
        drop_rate: dropout probability passed to each block.
        Patch_layer: patch-embedding class; its instance must expose ``num_patches``.
        act_layer: activation class passed to each block.
        drop_path_rate: stochastic-depth probability, the same for every block.
        init_scale: initial value of each block's layer-scale vectors.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, drop_rate=0.,
                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
                 drop_path_rate=0.0, init_scale=1e-4):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Every block receives the same drop_path_rate (no per-depth schedule).
        self.blocks = nn.ModuleList([
            layers_scale_mlp_blocks(
                dim=embed_dim, drop=drop_rate, drop_path=drop_path_rate,
                act_layer=act_layer, init_values=init_scale,
                num_patches=num_patches)
            for _ in range(depth)])

        self.norm = Affine(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule (called on every module via ``self.apply``).

        ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
        ``nn.LayerNorm``: weight 1, bias 0. This file does not build LayerNorm
        directly; the branch only matters if an injected submodule contains one.
        Other modules, e.g. ``Affine``, keep their own initialisation.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace ``head`` with a fresh ``Linear(embed_dim, num_classes)``.

        ``num_classes <= 0`` installs ``nn.Identity`` instead. ``global_pool`` is
        accepted but unused (presumably to match timm's calling convention).
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed, run all blocks, normalise and average over the patch axis."""
        x = self.patch_embed(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Mean over axis 1 (patches). A reshape(B, 1, -1) round-trip here would
        # raise on an empty batch, since -1 cannot be inferred from 0 elements.
        return x.mean(dim=1)

    def forward(self, x):
        """Return ``head(forward_features(x))``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
@register_model
def resmlp_12(pretrained=False, dist=False, **kwargs):
    """Build ResMLP-12: patch size 16, width 384, 12 blocks, layer-scale init 0.1.

    Args:
        pretrained: if true, download a checkpoint (tensors mapped to CPU) and
            load it into the model.
        dist: choose the ``resmlp_12_dist`` checkpoint instead of
            ``resmlp_12_no_dist``.
        **kwargs: forwarded unchanged to ``resmlp_models``.

    Returns:
        The ``resmlp_models`` instance, with ``default_cfg`` set from ``_cfg()``.
    """
    model = resmlp_models(
        patch_size=16,
        embed_dim=384,
        depth=12,
        Patch_layer=PatchEmbed,
        init_scale=0.1,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    checkpoint_url = (
        "https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth"
        if dist
        else "https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth"
    )
    state_dict = torch.hub.load_state_dict_from_url(
        url=checkpoint_url,
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state_dict)
    return model
@register_model
def resmlp_24(pretrained=False, dist=False, dino=False, **kwargs):
    """Build ResMLP-24: patch size 16, width 384, 24 blocks, layer-scale init 1e-5.

    Args:
        pretrained: if true, download a checkpoint (tensors mapped to CPU) and
            load it into the model.
        dist: choose the ``resmlp_24_dist`` checkpoint; checked before ``dino``.
        dino: choose the ``resmlp_24_dino`` checkpoint when ``dist`` is false.
        **kwargs: forwarded unchanged to ``resmlp_models``.

    Returns:
        The ``resmlp_models`` instance, with ``default_cfg`` set from ``_cfg()``.
    """
    model = resmlp_models(
        patch_size=16,
        embed_dim=384,
        depth=24,
        Patch_layer=PatchEmbed,
        init_scale=1e-5,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    # ``dist`` wins over ``dino``; with neither set the no_dist weights are used.
    if dist:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth"
    elif dino:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth"
    else:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth"

    state_dict = torch.hub.load_state_dict_from_url(
        url=checkpoint_url,
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state_dict)
    return model
@register_model
def resmlp_36(pretrained=False, dist=False, **kwargs):
    """Build ResMLP-36: patch size 16, width 384, 36 blocks, layer-scale init 1e-6.

    Args:
        pretrained: if true, download a checkpoint (tensors mapped to CPU) and
            load it into the model.
        dist: choose the ``resmlp_36_dist`` checkpoint instead of
            ``resmlp_36_no_dist``.
        **kwargs: forwarded unchanged to ``resmlp_models``.

    Returns:
        The ``resmlp_models`` instance, with ``default_cfg`` set from ``_cfg()``.
    """
    model = resmlp_models(
        patch_size=16,
        embed_dim=384,
        depth=36,
        Patch_layer=PatchEmbed,
        init_scale=1e-6,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    checkpoint_url = (
        "https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth"
        if dist
        else "https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth"
    )
    state_dict = torch.hub.load_state_dict_from_url(
        url=checkpoint_url,
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state_dict)
    return model
@register_model
def resmlpB_24(pretrained=False, dist=False, in_22k=False, **kwargs):
    """Build ResMLP-B24: patch size 8, width 768, 24 blocks, layer-scale init 1e-6.

    Args:
        pretrained: if true, download a checkpoint (tensors mapped to CPU) and
            load it into the model.
        dist: choose the ``resmlpB_24_dist`` checkpoint; checked before ``in_22k``.
        in_22k: choose the ``resmlpB_24_22k`` checkpoint when ``dist`` is false.
        **kwargs: forwarded unchanged to ``resmlp_models``.

    Returns:
        The ``resmlp_models`` instance, with ``default_cfg`` set from ``_cfg()``.
    """
    model = resmlp_models(
        patch_size=8,
        embed_dim=768,
        depth=24,
        Patch_layer=PatchEmbed,
        init_scale=1e-6,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if not pretrained:
        return model

    # ``dist`` wins over ``in_22k``; with neither set the no_dist weights are used.
    if dist:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth"
    elif in_22k:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth"
    else:
        checkpoint_url = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"

    state_dict = torch.hub.load_state_dict_from_url(
        url=checkpoint_url,
        map_location="cpu",
        check_hash=True,
    )
    model.load_state_dict(state_dict)
    return model