parametrizations: weight dropout
albertz committed Jul 11, 2024
1 parent 12a40bb · commit 262a06b
Showing 1 changed file with 39 additions and 0 deletions.
returnn/frontend/parametrizations.py: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
"""
Parametrizations using the parametrization API (:func:`register_parametrization`).

Also see:
https://github.com/rwth-i6/returnn/issues/1518
https://pytorch.org/tutorials/intermediate/parametrizations.html
"""

from __future__ import annotations
from returnn.tensor import Tensor
import returnn.frontend as rf


__all__ = ["weight_dropout"]


def weight_dropout(module: rf.Module, param_name: str, *, drop_prob: float) -> rf.Module:
"""
    Apply weight dropout to a parameter of a module:
    dropout is applied directly on the parameter (the weights) during training.

    :param module: module which owns the parameter
    :param param_name: name of the parameter in the module
    :param drop_prob: dropout probability
    :return: the same module, with the parametrization registered on it
"""
rf.register_parametrization(module, param_name, _WeightDropout(drop_prob))
return module


class _WeightDropout:
    """Parametrization which applies dropout on the parameter when training."""

    def __init__(self, drop_prob: float):
        self.drop_prob = drop_prob

    def __call__(self, param: Tensor) -> Tensor:
        def _on_train() -> Tensor:
            # Recompute the dropped-out weights in backprop rather than keeping
            # an extra dropped-out copy of the parameter in memory.
            with rf.gradient_checkpoint_scope():
                return rf.dropout(param, drop_prob=self.drop_prob, on_forward=True)

        # Apply dropout only under the train flag; otherwise pass the param through.
        return rf.cond(rf.get_run_ctx().train_flag, _on_train, lambda: param)

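A minimal usage sketch (an editor-added illustration, not part of the commit): this assumes an rf.Linear module, which keeps its weights in a parameter named "weight"; the dims and the drop_prob value are likewise made up for the example.

from returnn.tensor import Dim
import returnn.frontend as rf
from returnn.frontend.parametrizations import weight_dropout

in_dim = Dim(128, name="in")
out_dim = Dim(256, name="out")
linear = rf.Linear(in_dim, out_dim)
# After registration, any access to linear.weight goes through the
# parametrization, i.e. yields a dropped-out view while training.
weight_dropout(linear, "weight", drop_prob=0.1)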