
Need support for bounded weight, via optionally allowing F.relu() to fall through its gradient. #97452

Closed
mega-optimus opened this issue Mar 23, 2023 · 6 comments
Labels
module: nn Related to torch.nn

Comments

@mega-optimus
Contributor

mega-optimus commented Mar 23, 2023

🚀 The feature, motivation and pitch

  • Feature: to facilitate bounded weights in modules, allow F.relu() to optionally behave like the identity during back-propagation.

  • Motivation:

Sometimes a module needs weights that are non-negative (or bounded by some other value), and simply using F.relu(w) doesn't work:
during training w can become negative, and once that happens F.relu(w) has zero gradient with respect to w, so w is never updated again and stays stuck below zero forever.

The current solution is to clamp w at 0 after each gradient update in the training script, which sits outside the module containing w and thus breaks encapsulation.
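
For concreteness, a minimal sketch of that training-loop clamp, assuming a toy nn.Linear whose weight must stay non-negative and a plain SGD optimizer (the module and names are illustrative only):

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in for a module whose weight must stay >= 0
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for x in torch.randn(8, 4).split(1):
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The clamp lives in the training script, not in the module: this is what breaks encapsulation.
    with torch.no_grad():
        model.weight.clamp_(min=0)
```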

If we allow F.relu() to behave like the identity during back-propagation, then F.relu(w) can clamp w at 0 during the forward pass without trapping w below 0 during training.
More generally, F.relu(w - c) + c realizes any constant lower bound c on w, and -F.relu(c - w) + c realizes any constant upper bound.

Certainly we can implement a custom version of relu() to do the trick; it would just be more convenient to have the option in F.relu().
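
For reference, here is a minimal sketch of such a custom version: a torch.autograd.Function whose forward clamps at 0 and whose backward passes the gradient through unchanged (the class and helper names below are mine, not an existing PyTorch API):

```python
import torch

class ReLUWithIdentityGrad(torch.autograd.Function):
    """Forward: relu. Backward: identity (don't zero the gradient where the input < 0)."""

    @staticmethod
    def forward(ctx, w):
        return w.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # pass the upstream gradient through unchanged


def bounded_below(w, c=0.0):
    # F.relu(w - c) + c with a fall-through gradient: a constant lower bound c on w.
    return ReLUWithIdentityGrad.apply(w - c) + c
```

Used as, e.g., weight = bounded_below(self.raw_weight) inside a module's forward, this keeps both the constraint and the trick inside the module.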

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@vadimkantorov
Contributor

vadimkantorov commented Mar 23, 2023

I wonder also if #7313 can be used for this...

cc @lezcano

@mega-optimus mega-optimus changed the title Need support for non-negative weight. Need support for bounded weight, via optionally allowing relu (or generally any function) to fall through its gradient. Mar 23, 2023
@mega-optimus mega-optimus changed the title Need support for bounded weight, via optionally allowing relu (or generally any function) to fall through its gradient. Need support for bounded weight, via optionally allowing relu to fall through its gradient. Mar 23, 2023
@mega-optimus mega-optimus changed the title Need support for bounded weight, via optionally allowing relu to fall through its gradient. Need support for bounded weight, via optionally allowing F.relu() to fall through its gradient. Mar 23, 2023
@mega-optimus
Contributor Author

I wonder also if #7313 can be used for this...

Thanks! Yes, it could be related if implemented. I'm thinking of something simpler for bounded weights; see the updated post above.

@vadimkantorov
Contributor

vadimkantorov commented Mar 23, 2023

I also tagged @lezcano to comment on the suitability of parametrizations for efficiently implementing this. I think its API is a good match. The sort of quasi-relu you're proposing might also warrant a separate feature request...

I think only two weight parametrizations are currently implemented: orthogonalization and spectral-norm bounding: https://pytorch.org/docs/stable/_modules/torch/nn/utils/parametrizations.html

WeightNorm was kept as a legacy implementation.
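
For illustration, a minimal sketch of how a non-negativity constraint could be expressed with that API; the NonNegative module and the softplus mapping are my own choices, not something shipped in torch.nn.utils.parametrizations:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class NonNegative(nn.Module):
    # Maps the stored, unconstrained tensor to a positive one, so the gradient never dies.
    def forward(self, x):
        return nn.functional.softplus(x)

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NonNegative())
# layer.weight is now non-negative by construction; the optimizer updates the
# unconstrained tensor layer.parametrizations.weight.original instead.
```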

@ngimel
Collaborator

ngimel commented Mar 23, 2023

relu has well-defined differentiability behavior. If you need a special, non-standard definition of the gradient, you should use a custom autograd function or gradient hooks.
Reparametrization functionality for parameters also exists, so I'm closing this issue; it doesn't look like this feature should live in core.

@ngimel ngimel closed this as completed Mar 23, 2023
@ngimel ngimel added the module: nn Related to torch.nn label Mar 23, 2023
@lezcano
Collaborator

lezcano commented Mar 23, 2023

@ngimel is right here. Also, as a pointer, you may want to use LeakyReLU or play around with other activation functions that do not zero out the inputs. FWIW, from a mathematical perspective, having incorrect gradients is not a good idea.

@mega-optimus
Contributor Author

understood, thanks!
