-
Okay, I finally did some serious tests, not with timm but with the official torchvision codebase. I didn't train on ImageNet but on Tiny ImageNet (200 classes, 64x64 images, which may be too small). For their version of ResNet50 as the baseline, I used their training script (and hyperparameters), except training for only 60 epochs (and batch_size=64), and got
if I interpret the log correctly. With my suggestion above (depthwise with a x4 multiplier), plus one extra operation between the middle 3x3 and the last 1x1 that I call "(Convolutional Layer with) a Twist" (an elementwise modulation of the depthwise output by learned linear functions of the normalized x, y coordinates; see `self.mix` below), I got
with the exact same training scheme (60 epochs). If anyone thinks this is a significant improvement, I'd be very happy to see it trained on the full ImageNet, with all the tricks from ResNet Strikes Back. Here are all the relevant changes:

```python
from typing import Callable, Optional

import numpy as np
import torch
from torch import Tensor, nn
# This class is meant as a drop-in replacement for the Bottleneck in torchvision's resnet.py,
# so it reuses torchvision's own conv helpers.
from torchvision.models.resnet import conv1x1, conv3x3


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # self.conv2 = conv3x3(width, width, stride, groups, dilation)  # This is the original
        # self.bn2 = norm_layer(width)                                  # This is the original
        # self.conv3 = conv1x1(width, planes * self.expansion)          # This is the original
        self.conv2 = conv3x3(width, width * 4, stride, width, dilation)  # Modified: depthwise 3x3, depth multiplier 4
        self.XY = None                                                   # Modified: cached coordinate grid for the "twist"
        self.mix = conv1x1(3, width * 4)                                 # Modified: per-channel linear function of (1, x, y)
        self.bn2 = norm_layer(width * 4)                                 # Modified
        self.conv3 = conv1x1(width * 4, planes * self.expansion)         # Modified
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        # The "twist": rebuild the normalized coordinate grid whenever the spatial size changes,
        # then modulate each channel elementwise by a learned linear function of (1, x, y).
        if self.XY is None or self.XY.size()[2] != out.size()[2]:              # Added
            N, C, H, W = out.size()                                             # Added
            XX = torch.from_numpy(np.indices((1, 1, H, W))[3] * 2 / W - 1)      # Added
            YY = torch.from_numpy(np.indices((1, 1, H, W))[2] * 2 / H - 1)      # Added
            ones = torch.from_numpy(np.ones((1, 1, H, W)))                      # Added
            self.XY = torch.cat([ones, XX, YY
                                 # , XX*XX, XX*YY, YY*YY
                                 ], dim=1).type(out.dtype).to(out.device)       # Added
        out = out * self.mix(self.XY)                                           # Added
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
```

Update: I've written a blog post about this: https://wandb.ai/liuyao12/posts/ConvNets-from-the-PDE-perspective--VmlldzoxNzY2NDE2
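For completeness, here is a minimal sketch of how the modified block can be dropped into torchvision's ResNet builder (the name `resnet50_twist` is just an illustrative label, and 200 classes assumes Tiny ImageNet):

```python
from torchvision.models.resnet import ResNet

def resnet50_twist(num_classes: int = 200) -> ResNet:
    # Same [3, 4, 6, 3] stage layout as ResNet-50, but built from the modified Bottleneck above.
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

model = resnet50_twist()
```

This works because the modified Bottleneck keeps the same constructor signature and `expansion` attribute that `ResNet._make_layer` expects.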
-
A year later, I finally got to run timm and can answer my own questions: a smaller dataset (and fewer training epochs) isn't indicative of performance on the full ImageNet, which is what matters. (I understand that this kind of research is no longer fashionable, though there's ConvNeXt, and more recently InternImage.) With timm's official training script for ResNeXt50 (240 epochs), I was able to hit
compared to the baseline of 78.112 on my machine (official timm acc1=79.762).
-
I've had this (newbie) question for a long time, but didn't know where to ask. After seeing ResNet Strikes Back, I think some of the followers of this repo (if not Ross himself) may be able to answer.
ResNet50 uses Bottleneck blocks, which go like this:
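Concretely (a schematic sketch; the 256/64 channel counts are those of a ResNet-50 first-stage block, and batch norm, ReLU, and the skip connection are omitted):

```python
import torch.nn as nn

# 1x1 reduce -> 3x3 -> 1x1 expand
conv1 = nn.Conv2d(256, 64, kernel_size=1, bias=False)            # 1x1: 256 -> 64
conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)  # 3x3: 64 -> 64, the layer in question
conv3 = nn.Conv2d(64, 256, kernel_size=1, bias=False)            # 1x1: 64 -> 256
```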
First observe that, with 3x3 kernels going into more than 9 output channels, there is surely some waste of computation. What I mean is that, for any single input channel, there can only be 9 different kernels (or more precisely, 9 linearly independent kernels). We could replace the middle 3x3 layer by a two-layer process, a depthwise 3x3 (with depth multiplier 9) followed by a 1x1 convolution, which achieves the exact same thing (if there is no activation in between), am I right? That's the so-called "depthwise convolution" with depth multiplier > 1, which seems to have been neglected in architectural design. It does seem to cost more parameters this way, but we could do x4 instead of x9, so that we save about half the parameters (and the number of channels remains a power of 2). Plus, my gut feeling is that the 64 kernels coming out of one channel won't actually fill up the 9-dimensional space of kernels, but will rather lie in a lower-dimensional subspace. (That should be easy to verify with any pretrained weights.) Also, the extra 1x1 could just be merged with the next 1x1 in the original Bottleneck, or the block could be redesigned in some other way.
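To make the exact-equivalence claim concrete, here is a minimal numerical check (toy channel sizes, not taken from the post): an ordinary 3x3 convolution is reproduced by a depthwise 3x3 whose 9 kernels per input channel are the one-hot "shift" stencils, followed by a 1x1 carrying the original kernel coefficients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C_in, C_out = 4, 32  # toy sizes, just for the check

# An ordinary 3x3 convolution (the middle layer of the Bottleneck)
full = nn.Conv2d(C_in, C_out, kernel_size=3, padding=1, bias=False)

# Depthwise 3x3 with depth multiplier 9: each input channel is convolved with the
# 9 one-hot "shift" kernels, so its 9 outputs are just shifted copies of that channel.
dw = nn.Conv2d(C_in, C_in * 9, kernel_size=3, padding=1, groups=C_in, bias=False)
basis = torch.zeros(9, 1, 3, 3)
for b in range(9):
    basis[b, 0, b // 3, b % 3] = 1.0
dw.weight.data.copy_(basis.repeat(C_in, 1, 1, 1))

# 1x1 that recombines the shifted copies using the original 3x3 coefficients
pw = nn.Conv2d(C_in * 9, C_out, kernel_size=1, bias=False)
pw.weight.data.copy_(full.weight.data.reshape(C_out, C_in * 9, 1, 1))

x = torch.randn(2, C_in, 16, 16)
print(torch.allclose(full(x), pw(dw(x)), atol=1e-5))  # True: same linear map, factorized
```

The same reasoning shows why the extra 1x1 can be folded into the Bottleneck's next 1x1: two consecutive 1x1 convolutions with no nonlinearity in between compose into a single 1x1.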
Would that be something that could be quickly tested on ImageNet with timm?
That's certainly not specific to ResNet: any CNN that doesn't already use depthwise convolutions would have the same issue (and the same fix), but maybe it went unnoticed because people got used to seeing 5x5 and 7x7 kernels in the early days? To me, 3x3 is very natural because it acts like a differential operator (of order <= 2) in the spatial/image dimensions, from the perspective of the continuous formulation of neural nets.
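As a small illustration of that last point (standard central-difference stencils, nothing specific to this repo): the identity, the two first derivatives, and the Laplacian are all particular 3x3 kernels, i.e. a few directions inside that 9-dimensional space.

```python
import numpy as np

# Classical finite-difference stencils, written as 3x3 kernels
identity  = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=float)
d_dx      = np.array([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], dtype=float) / 2  # first order in x
d_dy      = np.array([[0, -1, 0], [0, 0, 0], [0, 1, 0]], dtype=float) / 2  # first order in y
laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=float)      # second order

# Stacked as vectors in R^9, they are linearly independent
stack = np.stack([k.ravel() for k in (identity, d_dx, d_dy, laplacian)])
print(np.linalg.matrix_rank(stack))  # 4
```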
I'd also be happy to be proven wrong: either people have been using this, or it fails to outperform for whatever reason. Thanks for any feedback!