diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 5ba9ab178d5a4..22b10f0571d1c 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -2,19 +2,16 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include <cmath>
+
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace vllm {
 
-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,          // [..., d]
   const scalar_t* __restrict__ input,  // [..., 2, d]
   const int d) {
@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace vllm
 
+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+    input.scalar_type(), \
+    "act_and_mul_kernel", \
+    [&] { \
+      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), \
+        input.data_ptr<scalar_t>(), \
+        d); \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
 namespace vllm {
diff --git a/csrc/ops.h b/csrc/ops.h
index 2bcd0c2efc5c6..dbdd2c2c57945 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index b36d259697167..24c22020131e8 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,
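
Note (illustration only, not part of the patch): the new vllm::gelu_kernel uses the exact, erf-based GELU, which is PyTorch's default 'none' approximation, rather than the tanh approximation. A minimal PyTorch sketch of the same formula, assuming only stock torch:

import math

import torch
import torch.nn.functional as F

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Same math as vllm::gelu_kernel: x * 0.5 * (1 + erf(x / sqrt(2))).
    return x * 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))

x = torch.randn(8, dtype=torch.float32)
# F.gelu defaults to approximate='none', i.e. the erf formulation above.
assert torch.allclose(gelu_exact(x), F.gelu(x))
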
"Activation function used in GeGLU."); ops.def( "gelu_new", &gelu_new, diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 8e216c293f070..e0dec144eba11 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -1,7 +1,10 @@ +from typing import Type + import pytest import torch -from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul +from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, + NewGELU, SiluAndMul) from allclose_default import get_default_atol, get_default_rtol DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -13,13 +16,15 @@ ] +@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul]) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_silu_and_mul( +def test_act_and_mul( + activation: Type[torch.nn.Module], num_tokens: int, d: int, dtype: torch.dtype, @@ -31,48 +36,23 @@ def test_silu_and_mul( torch.cuda.manual_seed(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) - layer = SiluAndMul() + layer = activation() out = layer(x) ref_out = layer._forward(x) - assert torch.allclose(out, - ref_out, - atol=get_default_atol(out), - rtol=get_default_rtol(out)) + # The SiLU and GELU implementations are equivalent to the native PyTorch + # implementations, so we can do exact comparison. + assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0) +@pytest.mark.parametrize("activation", [FastGELU, NewGELU]) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gelu_new( - num_tokens: int, - d: int, - dtype: torch.dtype, - seed: int, - device: str, -) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.set_default_device(device) - x = torch.randn(num_tokens, d, dtype=dtype) - layer = NewGELU() - out = layer(x) - ref_out = layer._forward(x) - assert torch.allclose(out, - ref_out, - atol=get_default_atol(out), - rtol=get_default_rtol(out)) - - -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("d", D) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_gelu_fast( +def test_activation( + activation: Type[torch.nn.Module], num_tokens: int, d: int, dtype: torch.dtype, @@ -84,7 +64,7 @@ def test_gelu_fast( torch.cuda.manual_seed(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) - layer = FastGELU() + layer = activation() out = layer(x) ref_out = layer._forward(x) assert torch.allclose(out, diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 95902ae38e256..5a3a7b2dbaee7 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -37,6 +37,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out +class GeluAndMul(nn.Module): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 95902ae38e256..5a3a7b2dbaee7 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -37,6 +37,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 03bd149c001d3..d8b515993d8ff 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -21,10 +21,11 @@
 from transformers import GemmaConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ def __init__(
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(hidden_size,
-                                              intermediate_size,
-                                              bias=False,
-                                              linear_method=linear_method)
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=False,
-                                            linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = nn.GELU()
+        self.act_fn = GeluAndMul()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        gate = self.act_fn(gate)
-        up, _ = self.up_proj(x)
-        fuse = gate * up
-        outputs, _ = self.down_proj(fuse)
-        return outputs
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class GemmaAttention(nn.Module):
@@ -294,6 +289,8 @@ def load_weights(self,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
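
Note (illustration only, not part of the patch): the GemmaMLP change relies on a merged gate/up projection plus GeluAndMul being equivalent to the separate gate_proj/up_proj path it replaces. A rough single-device sketch with plain nn.Linear standing in for MergedColumnParallelLinear (names and sizes are hypothetical):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)

# Stack the two weight matrices along the output dim, mirroring how the
# ("gate_up_proj", "gate_proj", 0) / ("gate_up_proj", "up_proj", 1) entries
# place the checkpoint weights at load time.
gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
with torch.no_grad():
    gate_up.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

x = torch.randn(4, hidden_size)
ref = F.gelu(gate(x)) * up(x)                  # old forward() path
fused = gate_up(x)                             # merged projection
d = fused.shape[-1] // 2
out = F.gelu(fused[..., :d]) * fused[..., d:]  # what GeluAndMul computes
assert torch.allclose(ref, out, atol=1e-6)
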