Skip to content

Commit

Permalink
Some RepVit tweaks
Browse files Browse the repository at this point in the history
* add head dropout to RepVit as all models have that arg
* default train to non-distilled head output via distilled_training flag (set_distilled_training) so fine-tune works by default w/o distillation script
* camel case naming tweaks to match other models
  • Loading branch information
rwightman committed Aug 9, 2023
1 parent f677190 commit c692715
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions timm/models/repvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Adapted from official impl at https://github.com/jameslahm/RepViT
"""

__all__ = ['RepViT']
__all__ = ['RepVit']

import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
Expand Down Expand Up @@ -81,7 +81,7 @@ def fuse(self):
return m


class RepVGGDW(nn.Module):
class RepVggDw(nn.Module):
def __init__(self, ed, kernel_size):
super().__init__()
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
Expand Down Expand Up @@ -115,7 +115,7 @@ def fuse(self):
return conv


class RepViTMlp(nn.Module):
class RepVitMlp(nn.Module):
def __init__(self, in_dim, hidden_dim, act_layer):
super().__init__()
self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
Expand All @@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
super(RepViTBlock, self).__init__()

self.token_mixer = RepVGGDW(in_dim, kernel_size)
self.token_mixer = RepVggDw(in_dim, kernel_size)
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

def forward(self, x):
x = self.token_mixer(x)
Expand All @@ -142,7 +142,7 @@ def forward(self, x):
return identity + x


class RepViTStem(nn.Module):
class RepVitStem(nn.Module):
def __init__(self, in_chs, out_chs, act_layer):
super().__init__()
self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
Expand All @@ -154,13 +154,13 @@ def forward(self, x):
return self.conv2(self.act1(self.conv1(x)))


class RepViTDownsample(nn.Module):
class RepVitDownsample(nn.Module):
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
super().__init__()
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)

def forward(self, x):
x = self.pre_block(x)
Expand All @@ -171,22 +171,25 @@ def forward(self, x):
return x + identity


class RepViTClassifier(nn.Module):
def __init__(self, dim, num_classes, distillation=False):
class RepVitClassifier(nn.Module):
def __init__(self, dim, num_classes, distillation=False, drop=0.):
super().__init__()
self.head_drop = nn.Dropout(drop)
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.distillation = distillation
self.num_classes=num_classes
self.distilled_training = False
self.num_classes = num_classes
if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()

def forward(self, x):
x = self.head_drop(x)
if self.distillation:
x1, x2 = self.head(x), self.head_dist(x)
if (not self.training) or torch.jit.is_scripting():
return (x1 + x2) / 2
else:
if self.training and self.distilled_training and not torch.jit.is_scripting():
return x1, x2
else:
return (x1 + x2) / 2
else:
x = self.head(x)
return x
Expand All @@ -207,11 +210,11 @@ def fuse(self):
return head


class RepViTStage(nn.Module):
class RepVitStage(nn.Module):
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
super().__init__()
if downsample:
self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
else:
assert in_dim == out_dim
self.downsample = nn.Identity()
Expand All @@ -230,7 +233,7 @@ def forward(self, x):
return x


class RepViT(nn.Module):
class RepVit(nn.Module):
def __init__(
self,
in_chans=3,
Expand All @@ -243,15 +246,16 @@ def __init__(
num_classes=1000,
act_layer=nn.GELU,
distillation=True,
drop_rate=0.,
):
super(RepViT, self).__init__()
super(RepVit, self).__init__()
self.grad_checkpointing = False
self.global_pool = global_pool
self.embed_dim = embed_dim
self.num_classes = num_classes

in_dim = embed_dim[0]
self.stem = RepViTStem(in_chans, in_dim, act_layer)
self.stem = RepVitStem(in_chans, in_dim, act_layer)
stride = self.stem.stride
resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

Expand All @@ -263,7 +267,7 @@ def __init__(
for i in range(num_stages):
downsample = True if i != 0 else False
stages.append(
RepViTStage(
RepVitStage(
in_dim,
embed_dim[i],
depth[i],
Expand All @@ -281,7 +285,8 @@ def __init__(
self.stages = nn.Sequential(*stages)

self.num_features = embed_dim[-1]
self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
self.head_drop = nn.Dropout(drop_rate)
self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)

@torch.jit.ignore
def group_matcher(self, coarse=False):
Expand All @@ -304,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False):
if global_pool is not None:
self.global_pool = global_pool
self.head = (
RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
)

@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.head.distilled_training = enable

def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
Expand All @@ -317,8 +326,9 @@ def forward_features(self, x):

def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
return x if pre_logits else self.head(x)
x = x.mean((2, 3), keepdim=False)
x = self.head_drop(x)
return self.head(x)

def forward(self, x):
x = self.forward_features(x)
Expand Down Expand Up @@ -373,7 +383,9 @@ def _cfg(url='', **kwargs):
def _create_repvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
RepVit, variant, pretrained,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs,
)
return model

Expand Down

0 comments on commit c692715

Please sign in to comment.