Remove GeGLU activation function and golden tests.
GeGLU is not a simple activation function, but a gated linear layer used in modern MLPs.
Our users are not well served by a baked-in implementation of a linear layer presented as
a simple activation function.
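
Users who relied on `nn.GeGLU` can reproduce it in a few lines of their own model code. Below is a minimal sketch of such a manual implementation (the class name and `features` attribute are illustrative, not part of Flax):

```python
import flax.linen as nn
import jax.numpy as jnp


class GeGLU(nn.Module):
  """Manual gated-GELU projection, one way to reproduce the removed layer."""

  features: int  # output width of the gated projection

  @nn.compact
  def __call__(self, x):
    # Project to twice the output width, then split into value and gate halves.
    x = nn.Dense(self.features * 2)(x)
    value, gate = jnp.split(x, 2, axis=-1)
    # Gate the linear half with the GELU-activated half.
    return value * nn.gelu(gate)
```

With `features` equal to the input width, this matches the default behavior of the removed layer.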

PiperOrigin-RevId: 686677143
levskaya authored and Flax Authors committed Oct 24, 2024
1 parent c12ba53 commit f982d57
Showing 4 changed files with 4 additions and 72 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ vNext
-
-
-
-
- Removed the simplistic GeGLU activation; it should be implemented manually.
-
-
-
1 change: 0 additions & 1 deletion flax/linen/__init__.py
@@ -32,7 +32,6 @@
  with_partitioning as with_partitioning,
)
from .activation import (
  GeGLU as GeGLU,
  PReLU as PReLU,
  celu as celu,
  elu as elu,
41 changes: 0 additions & 41 deletions flax/linen/activation.py
@@ -98,44 +98,3 @@ def __call__(self, inputs: Array) -> Array:
    return jnp.where(
      inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs
    )

class GeGLU(Module):
  """Gated Linear Unit with GELU (GeGLU) activation function.

  GeGLU is a Flax layer that combines a linear transformation with a GELU
  activation function in a gating mechanism. It is often used in Transformer
  models to provide non-linear capabilities while preserving a strong linear
  component.

  Example usage::

    >>> import flax.linen as nn
    >>> class TransformerBlock(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x):
    ...     x = nn.Dense(2)(x)
    ...     x = nn.GeGLU()(x)  # initialized
    ...     return x

  Attributes:
    output_dim: the number of output features; -1 (the default) keeps the
      input feature width.
  """

  output_dim: int = -1

  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies the GeGLU activation to the inputs.

    Args:
      inputs: the nd-array to apply the GeGLU activation function to.

    Returns:
      The transformed input.
    """
    if self.output_dim == -1:
      output_dim = inputs.shape[-1]
    else:
      output_dim = self.output_dim

    x = Dense(output_dim * 2)(inputs)
    x, gate = x[..., :output_dim], x[..., output_dim:]
    return x * gelu(gate)
32 changes: 3 additions & 29 deletions tests/linen/linen_activation_test.py
@@ -14,13 +14,13 @@

"""Tests for flax.linen.activation."""

from absl.testing import absltest
from flax import linen as nn
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest
from jax import random

from flax import linen as nn

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()
@@ -44,32 +44,6 @@ def test_prelu(self):
    np.testing.assert_array_almost_equal(expected_y, y)
    np.testing.assert_array_equal(init_negative_slope, expected_negative_slope)

  def test_geglu(self):
    rng = random.key(0)
    x = jnp.array([[0.123, 0.234], [0.456, 0.789]])
    act = nn.GeGLU()
    expected_result = jnp.array([[0.00024275, -0.00208032],
                                 [0.00336634, -0.02307648]])
    y, _ = act.init_with_output(rng, x)
    np.testing.assert_array_almost_equal(y, expected_result)

  def test_geglu_with_dim_expansion(self):
    rng = random.key(0)
    x = jnp.array([[0.123, 0.234], [0.456, 0.789]])
    act = nn.GeGLU(3)
    expected_result = jnp.array([[-0.02157649, -0.00018928, -0.01176354],
                                 [-0.08777858, 0.00258885, -0.18744925]])
    y, _ = act.init_with_output(rng, x)
    np.testing.assert_array_almost_equal(y, expected_result)

  def test_geglu_with_dim_contraction(self):
    rng = random.key(0)
    x = jnp.array([[0.123, 0.234], [0.456, 0.789]])
    act = nn.GeGLU(1)
    expected_result = jnp.array([[0.00224223], [0.0307451]])
    y, _ = act.init_with_output(rng, x)
    np.testing.assert_array_almost_equal(y, expected_result)


if __name__ == '__main__':
absltest.main()
