
Commit

RF gating
albertz committed Apr 18, 2023
1 parent ea30a18 commit c1550d2
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion in returnn/frontend/math_.py
@@ -4,9 +4,10 @@

from __future__ import annotations
import typing
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Union, Tuple
import numpy
from returnn.tensor import Tensor, Dim
import returnn.frontend as rf
from .types import RawTensorTypes as _RawTensorTypes

__all__ = [
@@ -56,6 +57,7 @@
    "swish",
    "softmax",
    "log_softmax",
    "gating",
]


@@ -445,3 +447,25 @@ def log_softmax(a: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
    """log_softmax"""
    # noinspection PyProtectedMember
    return a._raw_backend.log_softmax(a, axis=axis, use_mask=use_mask)


def gating(
    x: Tensor, *, axis: Optional[Dim] = None, gate_func=sigmoid, act_func=identity, out_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
"""
Like in gated linear unit (GLU): https://arxiv.org/abs/1612.08083
GLU refers also to the linear transformation before the gating -- this is why this function is not called GLU.
GLU uses gate_func=sigmoid and act_func=identity (the defaults here).
There are other potential gating variants you might be interested at.
See for example: https://arxiv.org/abs/2002.05202, e.g. gate_func=gelu.
"""
    if axis is None:
        assert x.feature_dim is not None, f"gating {x}: need tensor with feature dim set, or explicit `axis`"
        axis = x.feature_dim
    assert axis.is_static() and axis.dimension % 2 == 0, f"gating {x}: need static dim, and even, got {axis}"
    if not out_dim:
        out_dim = axis.div_left(2)

    a, b = rf.split(x, axis=axis, out_dims=[out_dim, out_dim])
    return act_func(a) * gate_func(b), out_dim
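
For illustration only (this is not part of the diff above): a minimal usage sketch of the new function. It assumes x is a Tensor whose feature dim is static and even, and that rf.gelu is available in the frontend for the GEGLU-style variant referenced in the docstring; the names used here are placeholders, not taken from the commit.

# Illustrative sketch, not part of this commit.
# Assumes `x` is a Tensor with a static, even feature dim (e.g. 1024),
# and that `rf.gelu` exists in the frontend.
from typing import Tuple
import returnn.frontend as rf
from returnn.tensor import Tensor, Dim

def glu_example(x: Tensor) -> Tuple[Tensor, Dim]:
    # Defaults reproduce plain GLU: split the feature dim in half,
    # gate one half with sigmoid, keep the other half unchanged.
    return rf.gating(x)

def geglu_example(x: Tensor) -> Tuple[Tensor, Dim]:
    # GEGLU-style variant (https://arxiv.org/abs/2002.05202): gelu as the gate.
    return rf.gating(x, gate_func=rf.gelu)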
