RF weight dropout and variational noise #1518
(Note: I made a separate issue just for the gradient checkpointing aspect in PyTorch: #1552. So this issue here can just focus on the RF-specific question of how to implement weight dropout (or also weight noise / variational noise).)

Edit: This is implemented now. See #1559.
So, I tend to reimplement something very similar to the PyTorch parametrization API, also following some of its internal design choices.

Some open questions: …

I also thought about deriving or extending …
Currently we don't have weight dropout in the RF. We should add it.
(I thought there was already an issue about it, but I can't find it.)
Related:

- `returnn_common.nn.utils.weight_dropout.weight_dropout`. Related issues: "How to define the API for parameter initialization, regularization (L2, weight dropout, etc), maybe updater opts per-param" (returnn_common#59), "Allow `__init__` logic to work equally for graph-based and eager-based backends, specifically re-parameterization like weight norm" (returnn_common#250).
- #1264 (comment)

In general, there is a whole class of similar features on parameter reparameterizations which would require a similar mechanism, like weight norm.
Regarding implementation:

I think we could follow the PyTorch implementation of similar logic (e.g. weight norm) by using a forward-pre-hook. We already have support for hooks via `rf.Module.register_forward_hook` / `rf.hooks.setup_post_hook_on_method`.
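To make the forward-pre-hook idea concrete, here is a minimal plain-PyTorch sketch (not the RF implementation; `apply_weight_dropout` is a hypothetical helper) that applies dropout to a weight in the same style as the old `torch.nn.utils.weight_norm`:

```python
import torch
from torch import nn
from torch.nn import functional as F


def apply_weight_dropout(module: nn.Module, name: str = "weight", p: float = 0.1) -> nn.Module:
    """Sketch: keep the real parameter under `<name>_raw` and recompute the
    dropped-out weight in a forward pre-hook, like the old weight_norm did."""
    raw = getattr(module, name)
    delattr(module, name)  # remove the original parameter from the module
    module.register_parameter(name + "_raw", raw)

    def _pre_hook(mod: nn.Module, args):
        raw_param = getattr(mod, name + "_raw")
        # Stored as a plain tensor attribute (not a Parameter), overwritten on every call.
        setattr(mod, name, F.dropout(raw_param, p=p, training=mod.training))

    module.register_forward_pre_hook(_pre_hook)
    _pre_hook(module, ())  # make sure the attribute exists before the first forward
    return module


lin = apply_weight_dropout(nn.Linear(8, 8), "weight", p=0.2)
out = lin(torch.randn(3, 8))  # the hook resamples the dropout mask before each forward
```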
Regarding PyTorch:

`torch.nn.utils.weight_norm` uses `register_forward_pre_hook`. However, the doc says it is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`, which uses the modern parametrization API, i.e. `torch.nn.utils.parametrize.register_parametrization`.
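For comparison, a sketch of how weight dropout might be expressed with that parametrization API (the `WeightDropout` module here is a hypothetical illustration, not an existing PyTorch or RF class):

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize


class WeightDropout(nn.Module):
    """Hypothetical parametrization: applies dropout to the wrapped weight
    each time the parametrized attribute is accessed (only in training mode)."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return F.dropout(weight, p=self.p, training=self.training)


lin = nn.Linear(8, 8)
parametrize.register_parametrization(lin, "weight", WeightDropout(p=0.2))

# Every access to lin.weight now goes through the injected property, i.e. through
# WeightDropout.forward; lin.eval() propagates to the parametrization and disables it.
out = lin(torch.randn(3, 8))
```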
On `register_parametrization` (it's somewhat ugly (although our own hooks mechanism is also a bit complicated), but this here looks even worse):

- It uses `_inject_new_class` to replace the original module class by a new dummy one, which is going to be extended by adding a property for the parameter. (We cannot add the property to the object; a property can only be added to a class. See the descriptor guide doc, and the small sketch after this list.)
- It removes the original parameter (via `delattr`) and stores it in a `ParametrizationList`.
- It injects a property (`_inject_property`) for the parameter. This will call the parametrization to get it.
- When leaving the `cached()` context, it will clear the (global) cache.
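For the first point in that list, a tiny standalone Python sketch of why the property has to live on a (new, per-instance) class rather than on the module object (`Linear` / `_ParametrizedLinear` are just illustrative stand-ins):

```python
class Linear:  # stand-in for the original module class
    pass


lin = Linear()

# Putting a property on the instance does nothing useful: descriptors such as
# property only take effect when they are found on the class (via type(obj)).
lin.weight = property(lambda self: "computed")
print(lin.weight)  # <property object at ...>; the getter is never invoked
del lin.weight

# So register_parametrization creates a new dummy subclass per module instance
# (_inject_new_class) and injects the property there, leaving all other
# instances of the original class untouched.
class _ParametrizedLinear(Linear):
    weight = property(lambda self: "computed")


lin.__class__ = _ParametrizedLinear
print(lin.weight)  # "computed"
```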
So, comparing the modern parametrization API `register_parametrization` to the way it was done in `torch.nn.utils.weight_norm` via `register_forward_pre_hook`:
- `register_parametrization` also works for any other module methods, while `register_forward_pre_hook` only works for `forward`.
- `register_forward_pre_hook` will create a temporary buffer to store the calculated parameter. This buffer is not freed after `forward`; it is just updated in the next pre-hook. Via `register_parametrization`, the calculated parameter is freed after its usage (and with caching enabled, after you leave the `cached()` context scope). So `register_parametrization` requires less memory.
- `register_parametrization` will recompute the parameter several times when it is accessed multiple times (whenever the property is accessed), unless the user uses the parametrization `cached` mechanism explicitly. You usually would want to avoid the redundant computation, i.e. you want to use the `cached` context. But that means you also need to modify the way you call your model by installing the `cached` context somewhere; this is not automatic (see the sketch after this list). On the other side, `register_forward_pre_hook` is always automatic. (Although, if you call `forward` multiple times, it would also cause redundant recomputations of the parameter.)
- `register_forward_pre_hook` is conceptually much simpler and less ugly than `register_parametrization`.
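For illustration, installing the `cached` context could look roughly like this (hypothetical `train_step`, with a placeholder loss):

```python
import torch
from torch.nn.utils import parametrize


def train_step(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Without this context, each access to a parametrized weight inside the forward
    # pass recomputes the parametrization (e.g. samples a fresh dropout mask).
    with parametrize.cached():
        out = model(batch)
    return out.sum()  # placeholder loss
```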
Concluding from that, I'm a bit unsure which way to go for RF... Using `register_forward_pre_hook` looks too error-prone (it only covers the hooked function, nothing else)... but the other approach looks too complicated? But maybe still better. The caching mechanism is maybe also not so important for now? For all use cases, I think it would not matter (e.g. `rf.Linear`, `rf.SelfAttention`, `rf.Conv`, etc.).

Further, we should also support this with gradient checkpointing, such that the weights are not stored twice in memory. In our existing TF implementation of variational noise, we already use gradient checkpointing, where only the random number generator state is stored and not the dropout mask nor the weight. Thus there is almost no memory overhead. See `gradient_checkpoint_scope` and co. For PyTorch, it is currently unclear how to do this. I moved this over to a separate issue: #1552.