Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose GLU activations as arguments #69

Merged
merged 15 commits into from
Aug 22, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ def _add_network_size_args(parser):
default=PositionEmbeddingType.absolute,
help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.'
)
group.add_argument('--glu-activation', type=str,
choices=["liglu", "geglu", "reglu", "swiglu"],
jaketae marked this conversation as resolved.
Show resolved Hide resolved
help='GLU activations to use.'
)

return parser

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ class _GLUBaseModule(nn.Module):
def __init__(self, activation_fn):
super().__init__()
self.activation_fn = activation_fn

def forward(self, x):
# dim=-1 breaks in jit for pt<1.10
x1, x2 = x.chunk(2, dim=(x.ndim-1))
x1, x2 = x.chunk(2, dim=(x.ndim - 1))
return x1 * self.activation_fn(x2)


Expand Down Expand Up @@ -38,3 +38,11 @@ def __init__(self):
geglu = torch.jit.script(GEGLU())
reglu = torch.jit.script(ReGLU())
swiglu = torch.jit.script(SwiGLU())


# Lookup table from a GLU activation name (the same strings offered as
# `--glu-activation` choices) to the corresponding module-level callable.
GLU_ACTIVATIONS = dict(
    geglu=geglu,
    liglu=liglu,
    reglu=reglu,
    swiglu=swiglu,
)
5 changes: 4 additions & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import deepspeed

from .glu_activations import GLU_ACTIVATIONS
from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb

# flags required to enable jit fusion kernels
Expand Down Expand Up @@ -76,7 +77,9 @@ def __init__(self, init_method, output_layer_init_method):

self.bias_gelu_fusion = args.bias_gelu_fusion
self.activation_func = F.gelu
if args.openai_gelu:
if args.glu_activation:
self.activation_func = GLU_ACTIVATIONS[args.glu_activation]
elif args.openai_gelu:
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
Expand Down
26 changes: 26 additions & 0 deletions megatron/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import random
from distutils.util import strtobool
from io import StringIO
from packaging import version
from pathlib import Path
from typing import Iterator, Union
from unittest import mock
Expand Down Expand Up @@ -212,6 +213,31 @@ def torch_assert_equal(actual, expected):
torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0)


def is_torch_bf16_available():
    """Report whether the current environment passes every bf16 prerequisite.

    Returns:
        bool: ``True`` only if ``is_torch_available()`` is truthy, CUDA is
        usable and torch was built with CUDA, the current device's
        ``major`` property is at least 8, the CUDA major version is at
        least 11, and the torch version is at least ``"1.09"``.
        ``False`` otherwise.
    """
    # from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301
    if not is_torch_available():
        return False

    # Imported lazily so this helper can be called when torch is absent.
    import torch

    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False

    # Checks run in the same order as before and short-circuit on the
    # first failure: device capability, CUDA toolkit major, torch version.
    # NOTE(review): "1.09" should parse the same as "1.9" under PEP 440
    # normalisation - confirm against the packaging docs.
    return (
        torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8
        and int(torch.version.cuda.split(".")[0]) >= 11
        and version.parse(torch.__version__) >= version.parse("1.09")
    )


def require_torch_bf16(test_case):
    """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9.

    Args:
        test_case: The test function or class being decorated.

    Returns:
        ``test_case`` unchanged when ``is_torch_bf16_available()`` is true;
        otherwise ``test_case`` wrapped by ``unittest.skip`` so it is
        reported as skipped.
    """
    if is_torch_bf16_available():
        return test_case

    skip_marker = unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")
    return skip_marker(test_case)


def get_tests_dir(append_path=None):
"""
Args:
Expand Down
18 changes: 13 additions & 5 deletions tests/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch.nn import functional as F

from megatron.model.activations import liglu, geglu, reglu, swiglu
from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu
from megatron.testing_utils import set_seed, torch_assert_equal


Expand All @@ -17,13 +17,13 @@ def setUp(self):
self.num_channels = random.randint(1, 384) * 2
self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels)
self.x1, self.x2 = self.x.chunk(2, dim=-1)
# glu should halve the last dimension
self.output_shape = [self.batch_size, self.seq_len, self.num_channels // 2]

def test_shapes(self):
# glu should halve the last dimension
output_shape = [self.batch_size, self.seq_len, self.num_channels // 2]
for activation_fn in [liglu, geglu, reglu, swiglu]:
for activation_fn in GLU_ACTIVATIONS.values():
output = activation_fn(self.x)
self.assertEqual(list(output.shape), output_shape)
self.assertEqual(list(output.shape), self.output_shape)

def test_liglu(self):
expected = self.x1 * self.x2
Expand All @@ -40,3 +40,11 @@ def test_reglu(self):
def test_swiglu(self):
    """swiglu(x) must exactly equal x1 * silu(x2) for the two halves of x."""
    gated = F.silu(self.x2)
    reference = self.x1 * gated
    torch_assert_equal(swiglu(self.x), reference)

# from megatron.testing_utils import require_torch_bf16
# @require_torch_bf16
# def test_bf16_jit(self):
# x_bf16 = self.x.to(torch.bfloat16)
# for activation_fn in GLU_ACTIVATIONS.values():
# output = activation_fn(x_bf16)
# self.assertEqual(list(output.shape), self.output_shape)
1 change: 1 addition & 0 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_training_all(self):
--eval-interval 10
--eval-iters 5
--checkpoint-activations
--glu-activation geglu
--exit-interval {exit_interval}

--merge-file {data_dir}/gpt2-tiny-merges.txt
Expand Down