
Adding multi-layer perceptron in ops #6053

Merged: 9 commits from ops/mlp merged into pytorch:main on May 19, 2022

Conversation

@datumbox (Contributor) commented on May 19, 2022:

We should avoid using ViT's MLP block from Swin:

from .vision_transformer import MLPBlock

The specific layer is very common and has been previously requested at #4333

This PR:

  • Adds a generic MLP block to TorchVision (usage sketch below).
  • Handles BC on ViT.
  • Replaces MLPBlock with MLP on Swin and patches the unreleased weights (weights uploaded on S3 and manifold).
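
For illustration, a minimal usage sketch of the new generic block. This is hedged: the parameter names follow this PR's docstring and the code excerpts below, but the exact defaults may differ from the final API.

import torch
from torchvision.ops import MLP

# A ViT/Swin-style 2-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.
mlp = MLP(
    in_channels=96,
    hidden_channels=[384, 96],
    activation_layer=torch.nn.GELU,
    inplace=None,  # GELU takes no inplace argument, so no inplace kwarg is forwarded to it
    dropout=0.0,
)
out = mlp(torch.randn(1, 56 * 56, 96))  # output shape: (1, 3136, 96)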

References:

Proof that the new API doesn't break Swin. The minor differences are expected due to the known non-deterministic behaviour of some kernels:

srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py --model swin_t --test-only -b 1 --weights="Swin_T_Weights.IMAGENET1K_V1"
Test:  Acc@1 81.476 Acc@5 95.780

srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py --model swin_s --test-only -b 1 --weights="Swin_S_Weights.IMAGENET1K_V1"
Test:  Acc@1 83.182 Acc@5 96.366

srun -p dev --cpus-per-task=96 -t 24:00:00 --gpus-per-node=1 torchrun --nproc_per_node=1 train.py --model swin_b --test-only -b 1 --weights="Swin_B_Weights.IMAGENET1K_V1"
Test:  Acc@1 83.584 Acc@5 96.636

cc @ankitade

@datumbox mentioned this pull request on May 19, 2022
@datumbox requested a review from NicolasHug on May 19, 2022, 15:12
Review thread on torchvision/models/vision_transformer.py (resolved).

Review thread on torchvision/ops/misc.py (outdated), docstring excerpt:
hidden_channels (List[int]): List of the hidden channel dimensions
out_channels (int): Number of channels of the output
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``None``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``

A Member commented:
Annotations in docstring make me sad :'(

Contributor Author (@datumbox) replied:
I know how you feel about this. :( I think all the callables across TorchVision are annotated like that to provide info on what they are supposed to return.
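
For context, these annotated callables are factories that build a module for each hidden layer. Below is a minimal sketch of how they are typically passed, assuming the parameter names from the docstring excerpt above (functools.partial pre-binds any extra arguments):

import functools
import torch
from torchvision.ops import MLP

mlp = MLP(
    in_channels=128,
    hidden_channels=[256, 64],
    # called as norm_layer(hidden_dim) to build each normalization module
    norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
    activation_layer=torch.nn.ReLU,
    dropout=0.1,
)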

Review thread on torchvision/ops/misc.py (outdated, resolved).
@NicolasHug (Member) left a comment:
API LGTM, thanks @datumbox

@jdsgomes (Contributor) left a comment:

LGTM, feel free to merge after the same changes are done in Swin. If you prefer, I can have a second look once that's done.

@datumbox changed the title from "[WIP] Adding multi-layer perceptron in ops" to "Adding multi-layer perceptron in ops" on May 19, 2022

 from ..ops.stochastic_depth import StochasticDepth
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
 from ._api import WeightsEnum, Weights
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param
-from .convnext import Permute
-from .vision_transformer import MLPBlock
+from .convnext import Permute  # TODO: move Permute on ops
Contributor Author (@datumbox) commented:

This is a straight move from convnext to ops (no weight patching needed), but to avoid doing everything in a single PR I plan to do it in a follow-up.

@datumbox merged commit 77cad12 into pytorch:main on May 19, 2022
@datumbox deleted the ops/mlp branch on May 19, 2022, 18:15
facebook-github-bot pushed a commit that referenced this pull request Jun 1, 2022
Summary:
* Adding an MLP block.

* Adding documentation

* Update typos.

* Fix inplace for Dropout.

* Apply recommendations from code review.

* Making changes on pre-trained models.

* Fix linter

Reviewed By: datumbox, NicolasHug

Differential Revision: D36760914

fbshipit-source-id: 331d2ebbf9bb1782695c14bb6ee5e158847ba356

Code excerpt under discussion (from the MLP layer construction):

in_dim = hidden_dim

# After the loop over the hidden dimensions: project to the final dimension (hidden_channels[-1]).
# Note that Dropout is applied after this last Linear as well, which is what the thread below discusses.
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
layers.append(torch.nn.Dropout(dropout, **params))

@thomasbbrunner commented on Aug 4, 2022:
It is not very clear to me why there's a Dropout layer after the last layer. I saw that it was present in the previous MLPBlock class, but no other implementation of MLP with dropout (that I could find) has a dropout layer on the output, including the one in the multimodal package.

Maybe this was something specific to the use case of MLPBlock? If so, it should not be in this class.

@datumbox (Contributor, Author) commented on Aug 4, 2022:

You are right, there are various implementations of MLP: some don't have dropout at all, some have it in the middle but not at the end, and some have it everywhere. If you check the references, you will see that all of these patterns exist. Our implementation looks the way it does because it replaces the MLP layers used in existing models such as ViT and Swin. We also try to support more complex variations with more than 2 linear layers. Your observation is correct, though: if one wanted to avoid having dropout at the end, the current implementation wouldn't let them. Since that variant is also valid, perhaps it's worth making this update in a non-BC-breaking way, with a new boolean that controls whether Dropout appears at the end. WDYT?

@thomasbbrunner replied:

I think your suggestion sounds very nice. What would the default value be for the boolean? I guess that setting it to True (with dropout) would cause no breaking changes. At the same time, I would say that not having dropout in the last layer is the more common (default) configuration? Also, I'd be interested in working on this, whichever option is chosen.

Contributor Author (@datumbox) replied:

> I guess that setting it to True (with dropout) would cause no breaking changes.

Yes, you are right, we will need to maintain BC. Note that using True is the "default" setup in TorchVision at the moment, as literally all existing models require dropout everywhere.

> I'd be interested in working on this

Sounds great, let me recommend the following. Could you start an issue, summarizing what you said here and providing a few references of the usage of MLP with a middle dropout but without the final one? Providing a few examples from real-world vision architectures will help build a stronger case. Once we clarify the details on the issue, we can discuss a potential PR. 😃
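
For illustration only, a minimal sketch of what such a flag could look like. The name dropout_last and its placement are assumptions for discussion, not an agreed or existing API:

# hypothetical tail of the MLP layer construction, not the actual torchvision code
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
if dropout_last:  # would default to True to keep backwards compatibility
    layers.append(torch.nn.Dropout(dropout, **params))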

@thomasbbrunner replied:

Ok! I am a bit short on time at the moment, but will have more time in the upcoming weeks. Nevertheless, I'm interested in this and will be working on it!
