Adding ConvNeXt architecture in prototype #5197

Merged: 22 commits merged into pytorch:main on Jan 20, 2022

Conversation

@datumbox (Contributor) commented Jan 14, 2022

Experimental implementation of ConvNeXt from https://arxiv.org/pdf/2201.03545.pdf

We managed to fully reproduce the paper using the original recipe described in the paper (Acc@1 82.064, Acc@5 95.858). Nevertheless, using TorchVision's new recipe along with a switch from SGD to AdamW (--opt adamw --lr 1e-3 --weight-decay 0.05), we were able to improve the accuracy by 0.45 points:

PYTHONPATH=$PYTHONPATH:`pwd` python -u run_with_submitit.py --ngpus 8 --nodes 2 \
--model convnext_tiny --batch-size 64 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4

Model Validation:

torchrun --nproc_per_node=1 train.py --test-only --weights ConvNeXt_Tiny_Weights.ImageNet1K_V1 --model convnext_tiny
Acc@1 82.520 Acc@5 96.146

cc @datumbox @vfdev-5 @bjuncek

@facebook-github-bot commented Jan 14, 2022

💊 CI failures summary and remediations

As of commit 2edbd8d (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 2 ongoing upstream failures:

These were probably caused by upstream breakages that are not fixed yet.


This comment was automatically generated by Dr. CI.

@datumbox marked this pull request as draft on January 15, 2022 01:32
@datumbox changed the title from "[NOMERGE] Adding ConvNeXt architecture" to "[WIP] Adding ConvNeXt architecture" on Jan 16, 2022
@datumbox (Contributor, Author) left a comment

@HannaMao @s9xie @liuzhuang13 Thanks for your work on ConvNeXt.

Last week we added an experimental TorchVision-compatible implementation of your architecture based on your code. This was done to facilitate experiments on how TorchVision's new recipe works on new architectures.

Below I've left a couple of comments for you, could you please check and let me know your thoughts?

Also, we managed to reproduce the reported accuracies using TorchVision's components for the tiny variant (the other variants require more time to complete). While experimenting with the recipe, I noticed that it's very difficult to train the architecture with standard SGD. AdamW certainly does the trick, but I just wanted to know if that's something you observed as well during your research.

torchvision/models/convnext.py

def forward(self, x):
    if not self.channels_last:
        x = x.permute(0, 2, 3, 1)
@datumbox (Contributor, Author):

In TorchVision, we try to reuse standard PyTorch components as much as possible. I understand that your original implementation provides a custom channels-first implementation. Could you talk about the performance degradation you experienced, and how large it was, that led you to reimplement it?

Reply:

The thing is, there isn't a standard PyTorch component for channels-first LayerNorm. Here are some discussions with Ross Wightman (https://twitter.com/wightmanr/status/1481383509142818817?s=20). In Ross's timm implementation, he has a presumably better LN implementation for channels-first. See here: https://github.com/rwightman/pytorch-image-models/blob/b669f4a5881d17fe3875656ec40138c1ef50c4c9/timm/models/convnext.py#L109
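For reference, a minimal sketch contrasting the two options discussed here: reusing nn.LayerNorm via permutes versus a hand-rolled channels-first LayerNorm. The module names are illustrative and are not taken from the PR or from timm:

import torch
import torch.nn.functional as F
from torch import nn


class LayerNorm2d(nn.LayerNorm):
    # Reuses the standard LayerNorm kernel by permuting NCHW -> NHWC and back.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)


class ManualChannelsFirstLN(nn.Module):
    # Normalizes over the channel dim of an NCHW tensor directly, without permutes.
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]

Both compute the same normalization for NCHW inputs; which one runs faster depends on shapes and hardware, which is what the benchmarking discussed here is about.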

@datumbox (Contributor, Author):

Thanks for the references. I wasn't aware of the concurrent discussions. I'll have a look and measure on our side. Ideally I would like to reuse the existing kernels as much as possible, unless there is a big gap in performance to justify a custom implementation.

@liuzhuang13 commented Jan 17, 2022

Hi @datumbox,
In this post I try to explain why we use linear layers instead of conv layers for the 1x1 convs in residual blocks, in case it is of any help: facebookresearch/ConvNeXt#18 (comment)

As for why we use the custom LN in the downsampling layers instead of permuting -> PyTorch LN -> permuting back, the reason is similar, we observe the former is slightly faster when used in downsampling layers.

Thanks for your work on incorporating ConvNeXt!
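To make the trade-off concrete, here is a rough sketch of the two equivalent formulations of the pointwise (1x1) layers in the block; this is illustrative only and not the PR's exact code:

import torch
from torch import nn

dim = 96

# Conv formulation: stay in NCHW and use 1x1 convolutions.
pointwise_conv = nn.Sequential(
    nn.Conv2d(dim, 4 * dim, kernel_size=1),
    nn.GELU(),
    nn.Conv2d(4 * dim, dim, kernel_size=1),
)

# Linear formulation: permute to NHWC so nn.Linear acts on the channel dim.
pointwise_linear = nn.Sequential(
    nn.Linear(dim, 4 * dim),
    nn.GELU(),
    nn.Linear(4 * dim, dim),
)

x = torch.randn(1, dim, 56, 56)  # NCHW
y_conv = pointwise_conv(x)
y_linear = pointwise_linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(y_conv.shape, y_linear.shape)  # both torch.Size([1, 96, 56, 56])

With matching weights the two compute the same function; the discussion above is purely about which one is faster on a given platform.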

@datumbox (Contributor, Author):

@liuzhuang13 Thanks for the reply. Is it fair to say that the approach we follow here is expected to be 0-5% slower than the optimum? I haven't had the chance to run benchmarks, but that's what I understand from your note.

FYI, the issue is that TorchVision has a common API across all models that accepts the norm_layer as a parameter, which is going to be tricky to support if I switch the 1x1 convs to linear. Whether or not we will do this depends on the speed impact.

@liuzhuang13 commented Jan 17, 2022

I just tried switching to conv layers in residual blocks and using permutation + LN in all cases. I found that, combined, they cause a 20-30% slowdown in inference at 224 resolution for ConvNeXt-T/S, compared to our released implementation. However, for ConvNeXt-B at 224, or any model at 384 resolution, it seems as fast as our released implementation. I only tried ConvNeXt T/S/B. This is on V100s, and I cannot say much about other platforms. It is indeed a bit strange to me.

@datumbox (Contributor, Author):

Thanks a lot for looking into it @liuzhuang13. A 20-30% slowdown sounds very large, and we would probably want to make it faster. I'll run benchmarks later and share any findings with you.

@s9xie commented Jan 16, 2022

> Also, we managed to reproduce the reported accuracies using TorchVision's components for the tiny variant (the other variants require more time to complete). While experimenting with the recipe, I noticed that it's very difficult to train the architecture with standard SGD. AdamW certainly does the trick, but I just wanted to know if that's something you observed as well during your research.

Regarding SGD vs AdamW, this is something we are particularly interested in and plan to explore ourselves. Thanks for the data point; we also find ConvNeXt behaves very similarly to ViT/Swin Transformer here. There are a couple of hypotheses: 1) the hyper-parameter choices need to be revisited for SGD; 2) based on the ConvStem paper (https://arxiv.org/abs/2106.14881), one source of instability w.r.t. SGD might be the initial patchify stem. We plan to try ConvStem with SGD and hopefully this will address the optimization problem (the patchify layer did not bring a substantial gain in accuracy anyway).

@datumbox (Contributor, Author) commented Jan 17, 2022

> Regarding SGD vs AdamW, this is something we are particularly interested in and plan to explore ourselves.

@s9xie Sounds great, please do let me know about any of your findings. In the meantime, I changed the optimizer in our recipe and we got the following accuracies for the tiny variant: Acc@1 82.508 Acc@5 96.196

The above is achieved by making the minimum possible modifications to our recipe to swap SGD for AdamW, so here are only the overrides related to the optimizer change: --opt adamw --lr 1e-3 --weight-decay 0.05. The values are adopted straight from the paper without any grid search (with the LR adjusted for our batch size), so it's very likely that we can do better.
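For clarity, the flags above map roughly to the following optimizer construction (a sketch with a placeholder model; the full recipe additionally sets --norm-weight-decay 0.0 via a parameter split):

import torch

model = torch.nn.Linear(8, 8)  # placeholder; the recipe trains convnext_tiny

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,            # --lr 1e-3
    weight_decay=0.05,  # --weight-decay 0.05
)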

@xiaohu2015 (Contributor):
@datumbox (Contributor, Author):

@xiaohu2015 It's not in the immediate plans but you can add a proposal at #2707 and we can discuss it. :)

@vadimkantorov:
It may be useful to derive ConvNeXt from nn.Sequential (#3331). It should be very simple: just call super().__init__(collections.OrderedDict(features=features, avgpool=avgpool, classifier=classifier)), and then the custom forward is probably not needed.
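A minimal sketch of that suggestion; the module names and shapes are illustrative and not the actual ConvNeXt definition:

import collections

import torch
from torch import nn


class TinyNet(nn.Sequential):
    # Toy classifier built by passing named children to nn.Sequential,
    # so no custom forward() is needed: the input flows through in order.
    def __init__(self, num_classes: int = 10) -> None:
        features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        avgpool = nn.AdaptiveAvgPool2d(1)
        classifier = nn.Sequential(nn.Flatten(1), nn.Linear(16, num_classes))
        super().__init__(
            collections.OrderedDict(features=features, avgpool=avgpool, classifier=classifier)
        )


x = torch.randn(2, 3, 32, 32)
print(TinyNet()(x).shape)  # torch.Size([2, 10])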

@datumbox changed the title from "[WIP] Adding ConvNeXt architecture" to "Adding ConvNeXt architecture in prototype" on Jan 20, 2022
@datumbox marked this pull request as ready for review on January 20, 2022 12:02
    super().__init__(*args, **kwargs)

def forward(self, x):
    # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
@datumbox (Contributor, Author):

Benchmarking is necessary, and potentially a rewrite, before moving this out of prototype.

@jdsgomes (Contributor) left a comment

LGTM, I just made some minor comments.

references/classification/README.md
torchvision/prototype/models/convnext.py
    self.num_layers = num_layers

def __repr__(self) -> str:
    s = self.__class__.__name__ + "("
@jdsgomes (Contributor):

Missing the f-string prefix f.

nit: if you want to remove multiple assignments you can write something like

s = (
    self.__class__.__name__
    + f"(input_channels={self.input_channels}, out_channels={self.out_channels}, num_layers={self.num_layers})"
)

or if you rename input_channels to in_channels you can get everything in one line
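For instance, the one-line variant could look roughly like this (a sketch; in_channels follows the renaming suggested above):

def __repr__(self) -> str:
    return f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, num_layers={self.num_layers})"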

@datumbox (Contributor, Author):

This is a quite common pattern in TorchVision that I'm repeating here (see this). We could perhaps change all instances in a separate issue?

Also, good call on input_channels vs in_channels. Here I maintain it for consistency with other models such as the shufflenets, mobilenetv3, efficientnets, vit, etc.

@jdsgomes (Contributor):

👍 Makes sense to leave it as it is for now. I will create an issue to investigate whether it makes sense to change this everywhere.
I prefer using f-strings and explicitly naming the variables we need, but I can see that this pattern with s.format(**self.__dict__) is quite generic.


@datumbox merged commit afda28a into pytorch:main on Jan 20, 2022
@datumbox deleted the models/convnext branch on January 20, 2022 15:49
@datumbox linked an issue on Jan 21, 2022 that may be closed by this pull request
facebook-github-bot pushed a commit that referenced this pull request Jan 26, 2022
Summary:
* Adding CNBlock and skeleton architecture

* Completed implementation.

* Adding model in prototypes.

* Add test and minor refactor for JIT.

* Fix mypy.

* Fixing naming conventions.

* Fixing tests.

* Fix stochastic depth percentages.

* Adding stochastic depth to tiny variant.

* Minor refactoring and adding comments.

* Adding weights.

* Update default weights.

* Fix transforms issue

* Move convnext to prototype.

* linter fix

* fix docs

* Addressing code review comments.

Reviewed By: jdsgomes, prabhat00155

Differential Revision: D33739375

fbshipit-source-id: 9df87bff1030cb629faf7d056957d1153a58af42
@datumbox mentioned this pull request on Feb 11, 2022
Successfully merging this pull request may close these issues:

Are new models planned to be added?