Optimize GeGLU layer in Gemma (vllm-project#2975)
WoosukKwon authored Feb 22, 2024
1 parent 93dc5a2 commit fd5dcc5
Showing 6 changed files with 108 additions and 77 deletions.
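For context, the GeGLU feed-forward block being optimized here applies GELU to a gate projection and multiplies the result elementwise with an up projection; this commit fuses the two projections into one matmul and adds a fused GELU-and-multiply CUDA kernel. Below is a minimal PyTorch sketch of the underlying math, under the assumption that the gate and up weights are stacked along the output dimension (names are illustrative, not vLLM APIs):

import torch
import torch.nn.functional as F

def geglu_mlp_reference(x, w_gate_up, w_down):
    # w_gate_up: [hidden, 2 * intermediate] with the gate columns first, then up.
    gate_up = x @ w_gate_up
    d = gate_up.shape[-1] // 2
    gate, up = gate_up[..., :d], gate_up[..., d:]
    return (F.gelu(gate) * up) @ w_down  # GELU(x W_gate) * (x W_up), then down-projection
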
73 changes: 48 additions & 25 deletions csrc/activation_kernels.cu
@@ -2,52 +2,75 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
// Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
out[token_idx * d + idx] = ACT_FN(x) * y;
}
}

template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
const float f = (float) x;
constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

} // namespace vllm

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});

void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"silu_and_mul_kernel",
[&] {
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
d);
});
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void gelu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

namespace vllm {
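The fused kernel treats the input as [..., 2, d]: the first d elements of each token feed the activation and the last d are the multiplier, with one CUDA block per token and threads striding over d. A hedged PyTorch reference of the same elementwise computation, useful for cross-checking the kernel (assumes a [num_tokens, 2*d] tensor, as in the tests below):

import torch

def act_and_mul_reference(x: torch.Tensor, act: str) -> torch.Tensor:
    # Mirrors act_and_mul_kernel: out[t, i] = ACT(x[t, i]) * x[t, d + i].
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    if act == "silu":
        a = a * torch.sigmoid(a)                           # silu_kernel: x * sigmoid(x)
    elif act == "gelu":
        a = a * 0.5 * (1.0 + torch.erf(a * (0.5 ** 0.5)))  # gelu_kernel: erf-based GELU
    else:
        raise ValueError(f"unknown activation: {act}")
    return a * b
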
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -57,6 +57,10 @@ void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU.");
ops.def(
"gelu_new",
&gelu_new,
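Once bound, the op is callable from Python in the same out-of-place style as silu_and_mul: the caller preallocates the output and the kernel writes into it. A small usage sketch; the vllm._C import path is an assumption based on how the activation layers load the extension and may differ between versions:

import torch
from vllm._C import ops  # assumed import path for the compiled extension

num_tokens, d = 4, 128
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, d, dtype=x.dtype, device="cuda")
ops.gelu_and_mul(out, x)  # out <- GELU(x[:, :d]) * x[:, d:]
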
50 changes: 15 additions & 35 deletions tests/kernels/test_activation.py
@@ -1,7 +1,10 @@
from typing import Type

import pytest
import torch

from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, SiluAndMul)
from allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@
]


@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
def test_act_and_mul(
activation: Type[torch.nn.Module],
num_tokens: int,
d: int,
dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = SiluAndMul()
layer = activation()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)


@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gelu_fast(
def test_activation(
activation: Type[torch.nn.Module],
num_tokens: int,
d: int,
dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = FastGELU()
layer = activation()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out,
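The exact comparison (atol=0, rtol=0) relies on the fused kernels evaluating the same formulas as PyTorch's native ops. Here is a small sanity check of that premise for the GELU path; bitwise equality of a hand-composed formula is not guaranteed across backends, so this sketch uses a tight tolerance instead of zero:

import torch
import torch.nn.functional as F

x = torch.randn(4096, dtype=torch.float32)
manual = x * 0.5 * (1.0 + torch.erf(x * (0.5 ** 0.5)))          # erf-based GELU, as in gelu_kernel
assert torch.allclose(F.gelu(x), manual, atol=1e-6, rtol=1e-6)  # F.gelu defaults to approximate='none'
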
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -37,6 +37,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out


class GeluAndMul(nn.Module):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""

def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d]) * x[..., d:]

def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)
return out


class NewGELU(nn.Module):

def _forward(self, x: torch.Tensor) -> torch.Tensor:
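A short usage sketch of the new layer, mirroring what the kernel test exercises (a CUDA device is assumed because forward() dispatches to the custom op):

import torch
from vllm.model_executor.layers.activation import GeluAndMul

layer = GeluAndMul()
x = torch.randn(16, 2 * 256, dtype=torch.float16, device="cuda")
out = layer(x)           # fused CUDA path via ops.gelu_and_mul
ref = layer._forward(x)  # PyTorch-native reference
assert out.shape == (16, 256)
assert torch.allclose(out, ref, atol=0.0, rtol=0.0)
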
31 changes: 14 additions & 17 deletions vllm/model_executor/models/gemma.py
@@ -21,10 +21,11 @@
from transformers import GemmaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ def __init__(
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=False,
linear_method=linear_method)
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
self.act_fn = nn.GELU()
self.act_fn = GeluAndMul()

def forward(self, x):
gate, _ = self.gate_proj(x)
gate = self.act_fn(gate)
up, _ = self.up_proj(x)
fuse = gate * up
outputs, _ = self.down_proj(fuse)
return outputs
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x


class GemmaAttention(nn.Module):
@@ -294,6 +289,8 @@ def load_weights(self,
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params = set()
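The two new stacked_params_mapping entries pack the checkpoint's separate gate_proj and up_proj weights into shard 0 and shard 1 of the merged gate_up_proj parameter, so the fused projection's output splits exactly along the [..., :d] / [..., d:] boundary that GeluAndMul expects. Conceptually, this is the packing (a simplified sketch that ignores tensor parallelism and quantized linear methods):

import torch

def merge_gate_up(w_gate: torch.Tensor, w_up: torch.Tensor) -> torch.Tensor:
    # w_gate, w_up: [intermediate_size, hidden_size] checkpoint weights.
    # Rows [0:I] feed the GELU half, rows [I:2I] the multiplier half.
    return torch.cat([w_gate, w_up], dim=0)  # [2 * intermediate_size, hidden_size]
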
