From f982d57591c8158dc7b61c89a9bce54e109a7e49 Mon Sep 17 00:00:00 2001
From: Anselm Levskaya
Date: Wed, 16 Oct 2024 16:14:14 -0700
Subject: [PATCH] Remove GeGLU activation function and golden tests.

GeGLU is not a simple activation function, but a gated linear layer used in
modern MLPs. Our users are not well served by a baked-in implementation of a
linear layer presented as a simple activation function.

PiperOrigin-RevId: 686677143
---
 CHANGELOG.md                         |  2 +-
 flax/linen/__init__.py               |  1 -
 flax/linen/activation.py             | 41 ----------------------------
 tests/linen/linen_activation_test.py | 32 ++--------------------
 4 files changed, 4 insertions(+), 72 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a298007a8..edeee4155f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ vNext
 -
 -
 -
--
+- Removed the `GeGLU` activation; it is a gated linear layer and should be implemented manually.
 -
 -
 -
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index 9b80ca3c18..ff4e384acd 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -32,7 +32,6 @@
   with_partitioning as with_partitioning,
 )
 from .activation import (
-  GeGLU as GeGLU,
   PReLU as PReLU,
   celu as celu,
   elu as elu,
diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index 3f36bdfcd1..8ccfff0d31 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -98,44 +98,3 @@ def __call__(self, inputs: Array) -> Array:
     return jnp.where(
       inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs
     )
-
-class GeGLU(Module):
-  """Gated Linear Unit with GELU (GeGLU) activation function.
-
-  GeGLU is a Flax layer that combines a linear transformation with a GELU
-  activation function in a gating mechanism. It is often used in Transformer models
-  to provide non-linear capabilities while preserving a strong linear component.
-
-  Example usage::
-    >>> import flax.linen as nn
-
-    >>> class TransformerBlock(nn.Module):
-    ...   @nn.compact
-    ...   def __call__(self, x):
-    ...     x = nn.Dense(2)(x)
-    ...     x = nn.GeGLU()(x) # initialized
-    ...     return x
-
-  Attributes:
-    features: the number of output features (default: None).
-  """
-  output_dim: int = -1
-
-  @compact
-  def __call__(self, inputs: Array) -> Array:
-    """Applies the GeGLU activation to the inputs.
-
-    Args:
-      inputs: the nd-array to apply the GeGLU activation function to.
-
-    Returns:
-      The transformed input.
-    """
-    if self.output_dim == -1:
-      output_dim = inputs.shape[-1]
-    else:
-      output_dim = self.output_dim
-
-    x = Dense(output_dim * 2)(inputs)
-    x, gate = x[..., : output_dim], x[..., output_dim :]
-    return x * gelu(gate)
\ No newline at end of file
diff --git a/tests/linen/linen_activation_test.py b/tests/linen/linen_activation_test.py
index 5f8369c205..6d4d0eb4f1 100644
--- a/tests/linen/linen_activation_test.py
+++ b/tests/linen/linen_activation_test.py
@@ -14,13 +14,13 @@

 """Tests for flax.linen.activation."""

+from absl.testing import absltest
+from flax import linen as nn
 import jax
+from jax import random
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest
-from jax import random
-from flax import linen as nn


 # Parse absl flags test_srcdir and test_tmpdir.
 jax.config.parse_flags_with_absl()
@@ -44,32 +44,6 @@ def test_prelu(self):
     np.testing.assert_array_almost_equal(expected_y, y)
     np.testing.assert_array_equal(init_negative_slope, expected_negative_slope)

-  def test_geglu(self):
-    rng = random.key(0)
-    x = jnp.array([[0.123,0.234], [0.456,0.789]])
-    act = nn.GeGLU()
-    expected_result = jnp.array([[0.00024275, -0.00208032],
-                                 [0.00336634, -0.02307648]])
-    y, _ = act.init_with_output(rng, x)
-    np.testing.assert_array_almost_equal(y, expected_result)
-
-  def test_geglu_with_dim_expansion(self):
-    rng = random.key(0)
-    x = jnp.array([[0.123,0.234], [0.456,0.789]])
-    act = nn.GeGLU(3)
-    expected_result = jnp.array([[-0.02157649, -0.00018928, -0.01176354],
-                                 [-0.08777858, 0.00258885, -0.18744925]])
-    y, _ = act.init_with_output(rng, x)
-    np.testing.assert_array_almost_equal(y, expected_result)
-
-  def test_geglu_with_dim_contraction(self):
-    rng = random.key(0)
-    x = jnp.array([[0.123,0.234], [0.456,0.789]])
-    act = nn.GeGLU(1)
-    expected_result = jnp.array([[0.00224223], [0.0307451 ]])
-    y, _ = act.init_with_output(rng, x)
-    np.testing.assert_array_almost_equal(y, expected_result)
-

 if __name__ == '__main__':
   absltest.main()
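
Note for downstream users (not part of the patch above): since `nn.GeGLU` is removed and the commit
message recommends implementing the gated linear layer manually, the following is a minimal sketch of
one possible user-side replacement built on `flax.linen`. The class name `GeGLU`, the `output_dim=-1`
convention, and the toy shapes are assumptions carried over from the removed code for illustration,
not an officially supported API.

import jax
import jax.numpy as jnp
import flax.linen as nn


class GeGLU(nn.Module):
  """User-side GeGLU sketch: one Dense projection, split in half, gate with GELU."""
  output_dim: int = -1  # -1 keeps the input feature size (assumed convention)

  @nn.compact
  def __call__(self, inputs: jax.Array) -> jax.Array:
    dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim
    x = nn.Dense(dim * 2)(inputs)           # project to twice the target width
    value, gate = jnp.split(x, 2, axis=-1)  # split into value and gate halves
    return value * nn.gelu(gate)            # gate the linear half with GELU


# Example usage: initialize and apply on a toy batch.
y, variables = GeGLU(3).init_with_output(jax.random.key(0), jnp.ones((2, 2)))
print(y.shape)  # (2, 3)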