Need support for bounded weight, via optionally allowing F.relu() to fall through its gradient. #97452
Comments
Thanks! Yes, it could be related if implemented. I'm thinking of something simpler for bounded weights; see the modified post above.
I also tagged @lezcano to comment on the suitability of parametrizations for efficiently implementing this; I think its API is a good match. Implementing the sort of quasi-relu you're proposing might also be a separate feature request. Currently only two weight parametrizations are implemented, orthogonalization and spectral-norm bounding: https://pytorch.org/docs/stable/_modules/torch/nn/utils/parametrizations.html. WeightNorm was kept as a legacy implementation.
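For reference, a minimal sketch of how the parametrizations API could express a non-negativity constraint; the `NonNegative` module and the choice of softplus are illustrative assumptions (softplus is picked here so the gradient never vanishes, unlike relu):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class NonNegative(nn.Module):
    """Illustrative parametrization mapping an unconstrained tensor to a non-negative one."""
    def forward(self, X):
        # softplus keeps the output >= 0 while its gradient is never exactly zero
        return nn.functional.softplus(X)

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NonNegative())
# layer.weight is now always non-negative, while the underlying unconstrained
# parameter keeps receiving nonzero gradients during training.
```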
relu has well-defined gradient behavior. If you need a special, non-standard definition of the gradient, you should use a custom autograd function or gradient hooks.
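A minimal sketch of the custom autograd function route, assuming a straight-through ReLU whose backward passes the incoming gradient through unchanged (the class name is illustrative):

```python
import torch
import torch.nn.functional as F

class ReLUStraightThrough(torch.autograd.Function):
    """Clamps at zero in the forward pass, but uses the identity in backward."""

    @staticmethod
    def forward(ctx, x):
        return F.relu(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately do not zero out gradients where the input was negative.
        return grad_output

w = torch.tensor([-0.5, 0.3], requires_grad=True)
ReLUStraightThrough.apply(w).sum().backward()
print(w.grad)  # tensor([1., 1.]) -- the negative entry still receives a gradient
```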
@ngimel is right here. Also, as a pointer, you may want to use LeakyReLU or play around with other activation functions that do not zero out the inputs. FWIW, from a mathematical perspective, having incorrect gradients is not a good idea.
Understood, thanks!
🚀 The feature, motivation and pitch
Feature: to facilitate bounded weights in modules, allow `F.relu()` to behave like the identity during back-propagation.

Motivation: Sometimes modules need weights that are non-negative (or bounded by other values), and simply using `F.relu(w)` doesn't work: during training `w` could become negative, and in that situation `F.relu(w)` has zero gradient, so from then on `w` will no longer be updated and stays stuck there forever.

The current solution is to clamp `w` at 0 after each gradient update in the training script, which happens outside the module containing `w` and thus breaks encapsulation.

If we allow `F.relu()` to behave like the identity during back-propagation, then `F.relu(w)` can clamp `w` at 0 during the forward pass without trapping `w` below 0 during training. More generally, we can use `F.relu(w - c) + c` to realize any constant lower bound `c` on `w`, and `-F.relu(c - w) + c` to realize any constant upper bound, as in the sketch below.

Certainly we can implement our own custom version of relu() to do the trick; it's just more convenient to have it in `F.relu()`.
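A sketch of how the two compositions above could be packaged, assuming a pass-through relu implemented as a custom autograd function (both the helper and the `bounded` wrapper are hypothetical names, not an existing API):

```python
import torch

class _PassThroughReLU(torch.autograd.Function):
    # Stands in for the requested behavior: clamp in forward, identity in backward.
    @staticmethod
    def forward(ctx, x):
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def bounded(w, lower=None, upper=None):
    """Apply constant lower/upper bounds to w using the relu compositions above."""
    relu = _PassThroughReLU.apply
    if lower is not None:
        w = relu(w - lower) + lower    # F.relu(w - c) + c   enforces w >= c
    if upper is not None:
        w = upper - relu(upper - w)    # -F.relu(c - w) + c  enforces w <= c
    return w
```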
Alternatives
No response
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki