Add GLU variants (#47)
* feat: add glu variant activations

* fix: rm extraneous parentheses

* feat: rm bias to support jit

* fix: replace negative dim with explicit dim

* fix: use `x.ndim` for generic dim handling (see the sketch after this list)

* docs: add note on version for posterity

Co-authored-by: Stas Bekman <[email protected]>

* docs: specify jit in `x.ndim` comment

Co-authored-by: Stas Bekman <[email protected]>

* test: add simple tests to check activations

* fix: use `torch.testing` for tensor checks

* test: use seed-controlled random batch inputs

Co-authored-by: Stas Bekman <[email protected]>
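
The explicit-dim change above can be illustrated with a minimal sketch (shapes and variable names are illustrative, not from the commit): for any tensor, `x.ndim - 1` names the same axis as `dim=-1`, but as a non-negative index, which sidesteps the TorchScript issue with negative dims on PyTorch < 1.10 noted in the code below.

import torch

# Illustrative 3-D input: (batch, seq, channels).
x = torch.randn(2, 8, 6)

a, b = x.chunk(2, dim=-1)          # the usual eager-mode spelling
c, d = x.chunk(2, dim=x.ndim - 1)  # explicit non-negative last dim, as in the commit

assert torch.equal(a, c) and torch.equal(b, d)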
jaketae and stas00 authored Aug 8, 2021
1 parent 5e3963d commit effb2fb
Showing 3 changed files with 93 additions and 0 deletions.
40 changes: 40 additions & 0 deletions megatron/model/activations.py
@@ -0,0 +1,40 @@
import torch
from torch import nn
from torch.nn import functional as F


class _GLUBaseModule(nn.Module):
    def __init__(self, activation_fn):
        super().__init__()
        self.activation_fn = activation_fn

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)


class LiGLU(_GLUBaseModule):
    def __init__(self):
        super().__init__(nn.Identity())


class GEGLU(_GLUBaseModule):
    def __init__(self):
        super().__init__(F.gelu)


class ReGLU(_GLUBaseModule):
    def __init__(self):
        super().__init__(F.relu)


class SwiGLU(_GLUBaseModule):
    def __init__(self):
        super().__init__(F.silu)


liglu = torch.jit.script(LiGLU())
geglu = torch.jit.script(GEGLU())
reglu = torch.jit.script(ReGLU())
swiglu = torch.jit.script(SwiGLU())
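
For reference, a minimal usage sketch of the scripted activations defined above (the shapes are illustrative and the import assumes the repository root is on the path): each variant splits the last dimension in half and gates one half with the other, so an input with an even number of channels comes out with half as many.

import torch

from megatron.model.activations import liglu, geglu, reglu, swiglu

# Illustrative input: batch of 4, sequence length 16, 128 channels (must be even).
x = torch.randn(4, 16, 128)

for act in (liglu, geglu, reglu, swiglu):
    y = act(x)
    # Half of the channels gate the other half, so the last dim is halved.
    assert y.shape == (4, 16, 64)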
43 changes: 43 additions & 0 deletions tests/test_activations.py
@@ -0,0 +1,43 @@
import random
import unittest

import torch
from torch.nn import functional as F

from megatron.model.activations import liglu, geglu, reglu, swiglu

from .utils import set_seed


class TestActivations(unittest.TestCase):
    def setUp(self):
        """Set up an input of reasonable size."""
        set_seed()
        self.batch_size = random.randint(2, 64)
        self.seq_len = random.randint(256, 1025)
        self.num_channels = random.randint(1, 384) * 2
        self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels)
        self.x1, self.x2 = self.x.chunk(2, dim=-1)

    def test_shapes(self):
        # GLU variants should halve the last dimension.
        output_shape = [self.batch_size, self.seq_len, self.num_channels // 2]
        for activation_fn in [liglu, geglu, reglu, swiglu]:
            output = activation_fn(self.x)
            self.assertEqual(list(output.shape), output_shape)

    def test_liglu(self):
        expected = self.x1 * self.x2
        torch.testing.assert_close(liglu(self.x), expected)

    def test_geglu(self):
        expected = self.x1 * F.gelu(self.x2)
        torch.testing.assert_close(geglu(self.x), expected)

    def test_reglu(self):
        expected = self.x1 * F.relu(self.x2)
        torch.testing.assert_close(reglu(self.x), expected)

    def test_swiglu(self):
        expected = self.x1 * F.silu(self.x2)
        torch.testing.assert_close(swiglu(self.x), expected)
10 changes: 10 additions & 0 deletions tests/utils.py
@@ -0,0 +1,10 @@
import random

import numpy as np
import torch


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
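
As a sketch of what this helper buys the seed-controlled tests above (standalone, not part of the commit; assumes `tests` is importable as a package): re-seeding reproduces both the randomly drawn batch dimensions and the random input values.

import random

import torch

from tests.utils import set_seed

set_seed()
shape_a = (random.randint(2, 64), random.randint(256, 1025))
x_a = torch.randn(*shape_a)

set_seed()
shape_b = (random.randint(2, 64), random.randint(256, 1025))
x_b = torch.randn(*shape_b)

# Identical seeds give identical shapes and identical values.
assert shape_a == shape_b
assert torch.equal(x_a, x_b)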
