Adding ConvNeXt architecture in prototype #5197
Conversation
💊 CI failures summary and remediations
As of commit 2edbd8d (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 2 ongoing upstream failures: these were probably caused by upstream breakages that are not fixed yet.
This comment was automatically generated by Dr. CI.
@HannaMao @s9xie @liuzhuang13 Thanks for your work on ConvNeXt.
Last week we added an experimental TorchVision-compatible implementation of your architecture, based on your code. This was done to facilitate experiments on how TorchVision's new training recipe works on new architectures.
Below I've left a couple of comments for you; could you please check them and let me know your thoughts?
Also, we managed to reproduce the reported accuracies using TorchVision's components for the tiny version (the other variants require more time to complete). While experimenting with the recipe, I noticed that it's very difficult to train the architecture with standard SGD. AdamW certainly does the trick, but I just wanted to know if that's something you observed as well during your research.
torchvision/models/convnext.py (Outdated)

    def forward(self, x):
        if not self.channels_last:
            x = x.permute(0, 2, 3, 1)
At TorchVision we try to reuse standard PyTorch components as much as possible. I understand that in your original implementation you provide a custom channels-first implementation. Could you tell us about the performance degradation you experienced and how large it was, given that it led you to reimplement it?
The thing is, we don't have a standard PyTorch component for channels_first LN. Here are some discussions with Ross Wightman (https://twitter.com/wightmanr/status/1481383509142818817?s=20). In Ross's timm implementation he has a presumably better LN implementation for channels_first; see here: https://github.com/rwightman/pytorch-image-models/blob/b669f4a5881d17fe3875656ec40138c1ef50c4c9/timm/models/convnext.py#L109
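For context, here is a minimal sketch of the two approaches under discussion, assuming a standard NCHW input; the class and function names below are illustrative and are not the actual timm or TorchVision code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormChannelsFirst(nn.Module):
    """Illustrative LayerNorm over the channel dim of an NCHW tensor (not the official implementation)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over dim 1 without ever leaving the NCHW layout.
        mean = x.mean(1, keepdim=True)
        var = x.var(1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]


def layer_norm_via_permute(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reuse the standard kernel: NCHW -> NHWC, F.layer_norm, then back to NCHW.
    x = x.permute(0, 2, 3, 1)
    x = F.layer_norm(x, (x.shape[-1],), weight, bias, eps)
    return x.permute(0, 3, 1, 2)
```

The question in this thread is essentially whether the extra permutes in the second variant cost enough to justify shipping the first.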
Thanks for the references. I wasn't aware of the concurrent discussions. I'll have a look and measure on our side. Ideally I would like to reuse the existing kernels as much as possible, unless there is a big gap in performance that justifies a custom implementation.
Hi @datumbox,
In this post I try to explain why we use linear layers instead of conv layers for the 1x1 convs in residual blocks, in case it is of any help: facebookresearch/ConvNeXt#18 (comment)
As for why we use the custom LN in the downsampling layers instead of permuting -> PyTorch LN -> permuting back, the reason is similar: we observe the former is slightly faster when used in the downsampling layers.
Thanks for your work on incorporating ConvNeXt!
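To make the trade-off concrete, here is a small sketch (not the released ConvNeXt code) showing that an nn.Linear applied to a channels-last tensor computes the same function as a 1x1 nn.Conv2d on the NCHW tensor, so the choice between them is purely a question of which kernel runs faster:

```python
import torch
import torch.nn as nn

in_ch, out_ch = 96, 384  # example widths, not tied to a specific ConvNeXt variant
conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
linear = nn.Linear(in_ch, out_ch)

# Copy the conv weights into the linear layer so both modules compute the same function.
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, in_ch, 56, 56)                          # NCHW input
y_conv = conv(x)                                           # 1x1 conv in NCHW
y_lin = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # linear in NHWC, permuted back
print(torch.allclose(y_conv, y_lin, atol=1e-5))            # True up to floating-point error
```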
@liuzhuang13 Thanks for the reply. Is it fair to say that the approach that we follow here is expected to be 0-5% slower than the optimum? I haven't had the chance to run benchmarks but that's what I understand from your note.
FYI, the issue is that TorchVision has a common API across all models that accepts the norm_layer as a parameter, which is going to be tricky to support if I switch the 1x1 convs to linear. Whether or not we will do this depends on the speed impact.
I just tried switching to conv layers in residual blocks and using permutation + LN in all cases. I found that, combined, they cause a 20-30% slowdown in inference at 224 resolution for ConvNeXt-T/S, compared to our released implementation. However, for ConvNeXt-B at 224, or any model at 384 resolution, it seems as fast as our released implementation. I only tried ConvNeXt-T/S/B. This is on V100s, so I cannot say much about other platforms. It is indeed a bit strange to me.
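For anyone who wants to reproduce this kind of measurement, here is a hedged micro-benchmark sketch using torch.utils.benchmark; the shapes, epsilon and iteration counts are arbitrary examples, and the relative numbers will vary by hardware, resolution and model size, as noted above.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 96, 56, 56, device=device)  # example NCHW activation
weight = torch.ones(96, device=device)
bias = torch.zeros(96, device=device)


def ln_permute(x):
    # Reuse the standard LayerNorm kernel: NCHW -> NHWC -> layer_norm -> NCHW.
    y = F.layer_norm(x.permute(0, 2, 3, 1), (96,), weight, bias, 1e-6)
    return y.permute(0, 3, 1, 2)


def ln_channels_first(x):
    # Normalize over dim 1 directly, as the custom implementations do.
    mean = x.mean(1, keepdim=True)
    var = x.var(1, keepdim=True, unbiased=False)
    y = (x - mean) / torch.sqrt(var + 1e-6)
    return y * weight[:, None, None] + bias[:, None, None]


for name, fn in [("permute + layer_norm kernel", ln_permute), ("manual channels-first", ln_channels_first)]:
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(name, timer.timeit(50))
```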
Thanks a lot for looking into it @liuzhuang13. A 20-30% slowdown sounds very large and we would probably want to make it faster. I'll run benchmarks later and share any findings with you.
Regarding SGD vs AdamW, this is something we are particularly interested in and plan to explore ourselves. Thanks for the data point; we also find ConvNeXt behaves very similarly to ViT/Swin Transformer. There are a couple of hypotheses: 1) the hyper-parameter choices need to be revisited for SGD; 2) based on the ConvStem paper (https://arxiv.org/abs/2106.14881), one source of instability w.r.t. SGD might be the initial patchify stem. We plan to try a ConvStem with SGD and hopefully this will address the optimization problem (the patchify layer did not bring a substantial gain in accuracy anyway).
@s9xie Sounds great, please do let me know about any of your findings. In the meantime, I changed the optimizer in our recipe and we got the following accuracies for the tiny variant:
The above is achieved with the minimum possible modifications to our recipe to replace SGD with AdamW, so here are only the overwrites related to the optimizer change:
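For clarity, the optimizer overwrite amounts to something like the snippet below. This is only a sketch with a stand-in module rather than the actual TorchVision training recipe; the hyper-parameter values are the ones quoted in the PR description (--opt adamw --lr 1e-3 --weight-decay 0.05).

```python
import torch
import torch.nn as nn

# Stand-in module for illustration; the real run uses the ConvNeXt model.
model = nn.Conv2d(3, 96, kernel_size=4, stride=4)

# Swap SGD for AdamW with the quoted learning rate and weight decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
```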
@datumbox Good! Do you plan to add another CNN: https://github.com/facebookresearch/deit/blob/main/patchconvnet_models.py ?
@xiaohu2015 It's not in the immediate plans, but you can add a proposal at #2707 and we can discuss it. :)
It may be useful to derive ConvNeXt from nn.Sequential (#3331); it should be very simple by just calling
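For illustration, here is a minimal sketch of the nn.Sequential-derived pattern being suggested; the class name and layers are placeholders, not the actual ConvNeXt implementation.

```python
import torch.nn as nn


class TinyBackbone(nn.Sequential):
    """Toy example of deriving a model from nn.Sequential; forward() comes for free."""

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__(
            nn.Conv2d(3, 96, kernel_size=4, stride=4),  # patchify-style stem
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1),
            nn.Linear(96, num_classes),
        )
```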
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
Benchmarking is necessary, and potentially a rewrite, before moving this out of prototype.
LGTM, just made some minor comments
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
missing the f-string prefix f
nit: if you want to avoid the multiple assignments you can write something like:
s = (
    self.__class__.__name__ +
    f"(input_channels={self.input_channels}, out_channels={self.out_channels}, num_layers={self.num_layers})"
)
or if you rename input_channels to in_channels you can get everything in one line
This is a quite common pattern in TorchVision that I'm repeating here. See this. We could change it in all instances, perhaps in a separate issue?
Also, good call on input_channels vs in_channels. Here I maintain it for consistency with other models such as shufflenets, mobilenetv3, efficientnets, vit, etc.
👍 Makes sense to leave it as it is now. I will create an issue to investigate whether it makes sense to change it everywhere.
I prefer using f-strings and explicitly using the variables we need, but I can see that this pattern with s.format(**self.__dict__) is quite generic.
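For readers unfamiliar with the pattern being referenced, here is a small sketch contrasting the two __repr__ styles; the class and attribute names are illustrative and this is not the actual ConvNeXt configuration class.

```python
class BlockConfig:
    """Toy config class used only to illustrate the two __repr__ styles discussed above."""

    def __init__(self, input_channels: int, out_channels: int, num_layers: int) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        # Generic TorchVision-style pattern: build a template, then fill it from the instance dict.
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)

    def repr_fstring(self) -> str:
        # Explicit f-string alternative suggested in the review.
        return (
            f"{self.__class__.__name__}(input_channels={self.input_channels}, "
            f"out_channels={self.out_channels}, num_layers={self.num_layers})"
        )


print(BlockConfig(96, 384, 3))  # BlockConfig(input_channels=96, out_channels=384, num_layers=3)
```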
Summary:
* Adding CNBlock and skeleton architecture
* Completed implementation.
* Adding model in prototypes.
* Add test and minor refactor for JIT.
* Fix mypy.
* Fixing naming conventions.
* Fixing tests.
* Fix stochastic depth percentages.
* Adding stochastic depth to tiny variant.
* Minor refactoring and adding comments.
* Adding weights.
* Update default weights.
* Fix transforms issue
* Move convnext to prototype.
* linter fix
* fix docs
* Addressing code review comments.

Reviewed By: jdsgomes, prabhat00155
Differential Revision: D33739375
fbshipit-source-id: 9df87bff1030cb629faf7d056957d1153a58af42
Experimental implementation of ConvNeXt from https://arxiv.org/pdf/2201.03545.pdf
We managed to fully reproduce the paper using the original recipe described in the paper (Acc@1 82.064 Acc@5 95.858). Nevertheless, using TorchVision's new recipe along with a switch from SGD to AdamW (--opt adamw --lr 1e-3 --weight-decay 0.05), we are able to improve the accuracy by 0.45 points:
Model Validation:
cc @datumbox @vfdev-5 @bjuncek