diff --git a/megatron/model/activations.py b/megatron/model/activations.py new file mode 100644 index 000000000..82ccdf098 --- /dev/null +++ b/megatron/model/activations.py @@ -0,0 +1,40 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class _GLUBaseModule(nn.Module): + def __init__(self, activation_fn): + super().__init__() + self.activation_fn = activation_fn + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim-1)) + return x1 * self.activation_fn(x2) + + +class LiGLU(_GLUBaseModule): + def __init__(self): + super().__init__(nn.Identity()) + + +class GEGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.gelu) + + +class ReGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.relu) + + +class SwiGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.silu) + + +liglu = torch.jit.script(LiGLU()) +geglu = torch.jit.script(GEGLU()) +reglu = torch.jit.script(ReGLU()) +swiglu = torch.jit.script(SwiGLU()) diff --git a/tests/test_activations.py b/tests/test_activations.py new file mode 100644 index 000000000..6d1f8ab1d --- /dev/null +++ b/tests/test_activations.py @@ -0,0 +1,43 @@ +import random +import unittest + +import torch +from torch.nn import functional as F + +from megatron.model.activations import liglu, geglu, reglu, swiglu + +from .utils import set_seed + + +class TestActivations(unittest.TestCase): + def setUp(self): + """setup an input of reasonable size""" + set_seed() + self.batch_size = random.randint(2, 64) + self.seq_len = random.randint(256, 1025) + self.num_channels = random.randint(1, 384) * 2 + self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels) + self.x1, self.x2 = self.x.chunk(2, dim=-1) + + def test_shapes(self): + # glu should halve the last dimension + output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] + for activation_fn in [liglu, geglu, reglu, swiglu]: + output = activation_fn(self.x) + self.assertEqual(list(output.shape), output_shape) + + def test_liglu(self): + expected = self.x1 * self.x2 + torch.testing.assert_equal(liglu(self.x), expected) + + def test_geglu(self): + expected = self.x1 * F.gelu(self.x2) + torch.testing.assert_equal(geglu(self.x), expected) + + def test_reglu(self): + expected = self.x1 * F.relu(self.x2) + torch.testing.assert_equal(reglu(self.x), expected) + + def test_swiglu(self): + expected = self.x1 * F.silu(self.x2) + torch.testing.assert_equal(swiglu(self.x), expected) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..623897aa8 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,10 @@ +import random + +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) \ No newline at end of file