[RFC] API For Common Layers In Torchvision #4333
Comments
Don't worry @oke-aditya, that is what issues are for. You've put quite some work into compiling this, so even if we totally disagree (not saying that we do), you don't need to be sorry at all. Thanks a lot for your work! I'm all for reducing duplication. For my own projects I have written
For the naming issue I would go for
I don't see why not. Not sure if we want to push it as a feature, i.e. make it a priority to add something there, but if someone wants to contribute a loss I think we can accept it. So I would only put the general model blocks in there at first and see what the community wants.
Since
so copy-pasting the code is already not possible.
Yes, there is always a trade-off between customizability and ease of use. Given that we implement quite a few important models, IMO we can use them as the "limit" for customizability. That means all of our models should be covered without hacky workarounds or special casing. Anything beyond that probably warrants a custom implementation by the user. We can always make it more general in the future if there is a need.
The standard way is to "mark" (in the documentation as well as in the release notes) a module in a prototype state. The aim of this is twofold:
Thanks @oke-aditya for bringing up this topic! For example, I would treat layers like ConvNormAct as a developer API since it's just a shortcut for composing three operators together. Can we put this kind of API under
For layers like SqueezeExcitation that are backed by papers and referenced in many model architectures, we could make them public to the community. Putting them in
Hi @kazhang, I agree with your thoughts. We should expose only those layers that are referenced in model architectures. It would be nice to reduce internal code duplication by keeping them under
I would prefer
Picks up from discussion in #4293 (comment)
🚀 Feature
API for commonly used layers in building models.
Motivation
A lot of code is duplicated when building very basic blocks for large neural networks. Some of these blocks can be standardized and reused. They could also be offered to end users as an API so that downstream libraries can build models easily.
Examples of duplicated blocks are SqueezeExcitation, ConvBNRelu, etc.
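For reference, here is a minimal sketch of the kind of Squeeze-and-Excitation block that gets re-implemented across model files, written to be generic over its activations rather than hard-coding ReLU and Sigmoid. The class name and the argument names (activation, scale_activation) are illustrative assumptions, not an existing or proposed torchvision signature.

```python
from typing import Callable

import torch
from torch import nn


class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation block, generic over its activations.

    Hypothetical sketch: argument names are illustrative only.
    """

    def __init__(
        self,
        in_channels: int,
        squeeze_channels: int,
        activation: Callable[..., nn.Module] = nn.ReLU,
        scale_activation: Callable[..., nn.Module] = nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # squeeze: global spatial pooling
        self.fc1 = nn.Conv2d(in_channels, squeeze_channels, kernel_size=1)
        self.fc2 = nn.Conv2d(squeeze_channels, in_channels, kernel_size=1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.avgpool(x)  # (N, C, 1, 1)
        scale = self.activation(self.fc1(scale))
        scale = self.scale_activation(self.fc2(scale))  # one weight per channel
        return x * scale  # excitation: rescale each input channel


x = torch.randn(2, 32, 16, 16)
se = SqueezeExcitation(in_channels=32, squeeze_channels=8)
se_hard = SqueezeExcitation(32, 8, scale_activation=nn.Hardsigmoid)
out = se_hard(x)  # same shape as x
```

Passing scale_activation=nn.Hardsigmoid gives a MobileNetV3-style variant without a second copy of the class.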
Pitch
Create an API called torchvision.nn or torchvision.layers.
Our implementations need to be generic, not locked to particular activation functions, channel counts, etc.
These can simply be classes based on nn.Module. An example:
Users can use this as:
Layers can also be mixed into new custom models, e.g.:
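A minimal sketch of such a layer, assuming a Conv2d-norm-activation block (the ConvNormAct idea mentioned in the comments) that takes the norm and activation as callables so it isn't locked to BatchNorm2d or ReLU. ConvNormActivation and TinyClassifier are hypothetical names used only for illustration; the sketch shows direct use of the block and then the block mixed into a custom model.

```python
from typing import Callable, Optional

import torch
from torch import nn


class ConvNormActivation(nn.Sequential):
    """Conv2d -> optional norm -> optional activation (hypothetical sketch)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        norm_layer: Optional[Callable[[int], nn.Module]] = nn.BatchNorm2d,
        activation_layer: Optional[Callable[[], nn.Module]] = nn.ReLU,
    ) -> None:
        # (kernel_size - 1) // 2 keeps the spatial size for odd kernels at stride 1
        padding = (kernel_size - 1) // 2
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                # the conv bias is usually dropped when a norm layer follows
                bias=norm_layer is None,
            )
        ]
        if norm_layer is not None:
            layers.append(norm_layer(out_channels))
        if activation_layer is not None:
            layers.append(activation_layer())
        super().__init__(*layers)


# Direct use: a strided block with a non-default activation.
block = ConvNormActivation(3, 16, stride=2, activation_layer=nn.Hardswish)
y = block(torch.randn(1, 3, 32, 32))  # spatial size halves: (1, 16, 16, 16)


# Mixing the layer into a custom model alongside stock torch.nn modules.
class TinyClassifier(nn.Module):  # hypothetical model for illustration
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.features = nn.Sequential(
            ConvNormActivation(3, 16, stride=2),
            ConvNormActivation(16, 32, stride=2, activation_layer=nn.SiLU),
            ConvNormActivation(32, 32, kernel_size=1, norm_layer=None),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


logits = TinyClassifier(num_classes=5)(torch.randn(4, 3, 64, 64))  # shape (4, 5)
```

Taking norm_layer and activation_layer as plain callables keeps the class small while still covering swaps such as Hardswish or SiLU, or dropping the norm entirely.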
Points to Consider
We have torchvision.ops, so why layers?
Ops are transforms that perform pre-processing and post-processing on structures such as boxes, masks, anchors, etc. They are not used in "model building" but are optional steps for specific models.
Also these are
E.g. NMS, IoU, RoI, etc.
One doesn't need ops for every model.
E.g. you don't need RoI Align for DETR, and you don't compute IoU for segmentation masks.
With a separate API there would be a clear distinction between layers for models and operators for tasks such as detection and segmentation.
Should torchvision.nn contain losses?
This is tricky, and for now I see no clear winner.
PyTorch does not differentiate the API for losses and layers.
E.g. we use nn.Conv2d to build a convolutional layer, and we also use nn.CrossEntropyLoss or nn.MSELoss to build a loss function.
I'm not sure whether layers should be torchvision.layers or torchvision.nn (if implemented, of course).
Users don't need to worry about colliding namespaces. They can do.
Note that nn seems to be the convention adopted by torchtext.
Other points to consider.
Portability:
Currently most of the torchvision models are easy to copy and paste. E.g. we can copy the mobilenetv2.py file and edit it on the go to customize the model.
By adding such an API we can reduce internal code duplication, but these files would no longer be standalone, single-file model definitions.
Layer customization:
Layer customization has far too many options to consider.
E.g. there are several possible implementations of the ResNet BasicBlock, or slight modifications of the inverted residual layer.
No single implementation can suit everyone's needs. If one tries, the API becomes significantly more complicated.
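To illustrate the trade-off, here is a sketch assuming a ResNet-style BasicBlock that exposes only two hooks, norm_layer and activation_layer. The class and its arguments are illustrative, not torchvision's actual definition; each further variant (pre-activation ordering, an SE branch, a different downsampling path) would need yet another argument.

```python
from typing import Callable

import torch
from torch import nn


class BasicBlock(nn.Module):
    """ResNet-style basic block with two customization hooks (hypothetical sketch)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        norm_layer: Callable[[int], nn.Module] = nn.BatchNorm2d,
        activation_layer: Callable[[], nn.Module] = nn.ReLU,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.norm1 = norm_layer(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.norm2 = norm_layer(out_channels)
        self.act = activation_layer()
        # 1x1 projection on the skip path only when the output shape changes
        if stride != 1 or in_channels != out_channels:
            self.downsample: nn.Module = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                norm_layer(out_channels),
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.downsample(x)
        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        return self.act(out + identity)


# Swapping BatchNorm for GroupNorm via the norm_layer hook.
gn_block = BasicBlock(16, 32, stride=2, norm_layer=lambda c: nn.GroupNorm(8, c))
out = gn_block(torch.randn(2, 16, 32, 32))  # shape (2, 32, 16, 16)
```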
TorchScript:
We shouldn't hamper the TorchScript compatibility of any model while implementing the above API.
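One way to guard this, sketched under the assumption that every shared layer would get a scripting check in the test suite: script the module with torch.jit.script and compare it against eager mode on a sample input. ConvBNReLU and check_scriptable are hypothetical helpers written for this example.

```python
import torch
from torch import nn


class ConvBNReLU(nn.Sequential):
    """Minimal stand-in layer (hypothetical) used to demonstrate the check."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )


def check_scriptable(module: nn.Module, example: torch.Tensor) -> bool:
    """Script `module` and check it matches eager mode on `example`."""
    module.eval()  # note: mutates `module`; fixes BatchNorm to running statistics
    scripted = torch.jit.script(module)  # raises if the module is not scriptable
    with torch.no_grad():
        return torch.allclose(module(example), scripted(example))


ok = check_scriptable(ConvBNReLU(3, 8), torch.randn(1, 3, 16, 16))
```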
Additional context
Some candidates for layers
Not sure why it is under ops; it doesn't fit there.
Also quantizable versions of these!
Quantizable versions will allow users to fuse the models directly.
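As an illustration of what fusion means here, assuming PyTorch's eager-mode quantization utilities (torch.ao.quantization.fuse_modules in recent PyTorch releases): a plain Conv-BN-ReLU stack can have its BatchNorm folded into the convolution. How torchvision's quantizable layers would expose this is not specified in the issue; the sketch only shows the fusion step itself.

```python
import torch
from torch import nn
from torch.ao.quantization import fuse_modules

# A Conv -> BN -> ReLU stack laid out as a plain nn.Sequential, so its
# submodules are addressable by the names "0", "1", "2".
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()  # eval-mode fusion folds the BN statistics into the conv weights

# inplace=False (the default) returns a fused copy and leaves `block` untouched;
# the fused BN and ReLU slots are replaced by nn.Identity.
fused = fuse_modules(block, [["0", "1", "2"]])

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same_output = torch.allclose(block(x), fused(x), atol=1e-5)
print(fused)
```

Keeping the block's submodules in a predictable order is what makes a one-line fusion call like this possible.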
Additionally, I would recommend not rushing this feature. We could create torchvision.experimental.nn and start working things out (or I could try it in a fork).
Linking plans #3911 #4187
P.S. I'm a junior developer and my thoughts are probably a step too far, so please forgive me if I'm wrong.
cc @datumbox @pmeier @NicolasHug @fmassa