RF weight dropout and variational noise #1518

Closed
albertz opened this issue May 24, 2024 · 9 comments · Fixed by #1578

albertz commented May 24, 2024

Currently we don't have weight dropout in the RF. We should add it.

(I thought there was an issue about this already, but I can't find it.)

Related:

In general, there is a whole class of similar features based on parameter reparameterization which would require a similar mechanism, e.g. weight norm.

Regarding implementation:

I think we could follow the PyTorch implementation of similar logic (e.g. weight norm) by using a forward-pre-hook. We already have support for hooks via rf.Module.register_forward_hook/rf.hooks.setup_post_hook_on_method.
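
For illustration, here is a minimal sketch of the forward-pre-hook approach in plain PyTorch, modeled after the (deprecated) torch.nn.utils.weight_norm. The names apply_weight_dropout and the "_raw" suffix are made up for this sketch; the RF variant would go through the RF hook mechanism instead.

    import torch
    import torch.nn.functional as F

    def apply_weight_dropout(module: torch.nn.Module, name: str = "weight", p: float = 0.1):
        weight = getattr(module, name)
        # Keep the real parameter under a separate name and remove it from the
        # parameter dict, so the pre-hook can overwrite `name` with a plain tensor.
        del module._parameters[name]
        module.register_parameter(name + "_raw", weight)

        def pre_hook(mod, args):
            raw = getattr(mod, name + "_raw")
            setattr(mod, name, F.dropout(raw, p=p, training=mod.training))

        pre_hook(module, None)  # make sure `module.weight` exists before the first forward
        module.register_forward_pre_hook(pre_hook)
        return module

    lin = apply_weight_dropout(torch.nn.Linear(8, 8))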

Regarding PyTorch:

  • torch.nn.utils.weight_norm uses register_forward_pre_hook. However, the doc says:

    This function is deprecated. Use torch.nn.utils.parametrizations.weight_norm() which uses the modern parametrization API.

  • torch.nn.utils.parametrizations.weight_norm uses the modern parametrization API, i.e. torch.nn.utils.parametrize.register_parametrization. Some notes on register_parametrization (it is somewhat ugly; our own hooks mechanism is also a bit complicated, but this looks even worse; a minimal usage sketch follows after this list):
    • It uses _inject_new_class to replace the original module class by a new dummy one, which is going to be extended by adding a property for the parameter.
      (We cannot add the property to the object. A property can only be added to a class. See the descriptor guide doc.)
    • It deletes the parameter (delattr) and stores it in a ParametrizationList.
    • It injects a property (_inject_property) for the parameter. This will call the parametrization to get it.
    • Note that this implementation of parametrization does not work with scripting. (I'm not exactly sure why, though.)
    • Even tracing does not work together with caching? (Again, not sure why.)
    • There is a caching mechanism, which is disabled by default. The user needs to enable it explicitly and temporarily, like so:
      import torch.nn.utils.parametrize as P
      ...
      with P.cached():
          output = model(inputs)
      When leaving the cached() context, it will clear the (global) cache.
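
For illustration, a minimal sketch of weight dropout via this parametrization API (WeightDropout is a made-up parametrization module for this sketch, not an existing PyTorch class):

    import torch
    import torch.nn.functional as F
    import torch.nn.utils.parametrize as P

    class WeightDropout(torch.nn.Module):
        def __init__(self, p: float = 0.1):
            super().__init__()
            self.p = p

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            # Called whenever `module.weight` is accessed (unless cached).
            return F.dropout(weight, p=self.p, training=self.training)

    lin = torch.nn.Linear(8, 8)
    P.register_parametrization(lin, "weight", WeightDropout(p=0.1))

    with P.cached():  # optional: compute the dropped-out weight only once here
        out = lin(torch.randn(2, 8))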

So, comparing the modern parametrization API register_parametrization to the way it was done in torch.nn.utils.weight_norm via register_forward_pre_hook:

  • register_parametrization also works for any other module method (the parametrized value is computed on any attribute access), while register_forward_pre_hook only works for forward.
  • register_forward_pre_hook will create a temporary buffer to store the calculated parameter. This buffer is not freed after forward; it is just updated in the next pre-hook. Via register_parametrization, the calculated parameter is freed after its usage (and with caching enabled, after you leave the cached() context scope). So register_parametrization requires less memory.
  • register_parametrization will recompute the parameter several times when it is accessed multiple times (whenever the property is accessed), unless the user explicitly uses the parametrization caching mechanism. You usually would want to avoid the redundant computation, i.e. you want to use the cached context. But that means you also need to modify the way you call your model by installing the cached context somewhere; this is not automatic. On the other hand, register_forward_pre_hook is always automatic. (Although, if you call forward multiple times, it would also cause redundant recomputations of the parameter.)
  • register_forward_pre_hook is conceptually much simpler and less ugly than register_parametrization.
  • PyTorch parametrization API introduction and discussions: PR "Parametrization Functionality" (pytorch/pytorch#33344), issue "A class to perform constrained optimization through a parametrization" (pytorch/pytorch#28937), issue "[proposal] [discussion] Refactor pruning/weight_norm using new Reparametrization functionality + actually deprecate old impl of SpectralNorm" (pytorch/pytorch#7313).

Concluding from that, I'm a bit unsure which way to go for RF... Using register_forward_pre_hook looks too error-prone (it only covers the hooked function, nothing else), but the other approach looks quite complicated? Maybe still better, though. The caching mechanism is maybe also not so important for now; for all use cases (e.g. rf.Linear, rf.SelfAttention, rf.Conv, etc.), I think it would not matter.

Further, we should also support this together with gradient checkpointing, such that the weights are not stored twice in memory. In our existing TF implementation of variational noise, we already use gradient checkpointing, where only the random number generator state is stored, and neither the dropout mask nor the weight. Thus there is almost no memory overhead. See gradient_checkpoint_scope and co. For PyTorch, it is currently unclear how to do this. I moved this over to a separate issue: #1552

NeoLegends self-assigned this Jun 7, 2024

albertz commented Jun 26, 2024

(Note, I made a separate issue just for the gradient checkpointing aspect in PyTorch: #1552. So this issue here can just focus on the RF-specific question of how to implement weight dropout (or also weight noise / variational noise). Edit: This is implemented now; see #1559, gradient_checkpoint_scope.)

albertz changed the title from "RF weight dropout" to "RF weight dropout and variational noise" on Jun 26, 2024

albertz commented Jul 9, 2024

So, I tend to reimplement something very similar to the PyTorch parametrization API, also following some of its internal design choices.

  • I don't want to extend rf.Module. That's also not easily possible anyway, as we don't have a _parameters dict as in torch.nn.Module; our parameters are simply normal attributes.
  • I want to have it as a property, not as a buffer. A buffer would take additional memory, which I want to avoid at all cost. (Also, all our effort for the gradient checkpointing is to save memory.)
  • -> We also must inject a dummy class into the object, following similar logic as in _inject_new_class and then _inject_property (see the sketch after this list).
    • Note that PyTorch does not allow this to be serialized and uses a custom deepcopy function. I think we don't need this; serialization of such classes can probably also work. I need to play around with that later, but I don't think we need to care about it for now.
  • I'm not sure yet about the caching logic. I think we can skip this for the beginning. In any case, I'm not sure I would follow the same PyTorch API for this; it could be more automatic in RF. But anyway, to save memory, maybe we don't want this.
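
A plain-Python sketch of this class/property injection (generic, independent of RF specifics; the function names and the "_raw" suffix are just illustrative):

    def inject_new_class(obj):
        # Replace the instance's class by a fresh per-instance subclass,
        # so we can add properties without affecting other instances.
        cls = obj.__class__
        obj.__class__ = type("Parametrized" + cls.__name__, (cls,), {})

    def inject_property(obj, name, transform):
        # Assumes inject_new_class(obj) was called first, and that the
        # parameter is a plain instance attribute (as in rf.Module).
        # Move the original attribute aside and expose it via a property
        # which applies `transform` (e.g. weight dropout) on every read.
        orig = getattr(obj, name)
        delattr(obj, name)
        setattr(obj, name + "_raw", orig)  # underlying param stays a normal attribute

        def getter(self):
            return transform(getattr(self, name + "_raw"))

        setattr(type(obj), name, property(getter))

Keeping the underlying parameter around as a normal attribute (the "_raw" name here) is what should keep it visible for named_parameters, see the open question below.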

Some open questions:

  • In case the parametrization doesn't change the underlying vars, but is just something optional applied in training, like weight dropout or variational noise, I would not want the module parameter list to change. But is this possible?
    • I think yes. Currently RFModuleAsPTModule uses named_parameters, and that iterates through vars(module).items(), i.e. it should still see the underlying var as long as we do not remove it.

albertz commented Jul 10, 2024

I also thought about deriving from or extending rf.Parameter. I'm not exactly sure how, though. It is currently also a Tensor, and I don't think we can make it dynamically evaluate on reads without changing Tensor itself. So I think this does not work.
