Skip to content

Commit

Permalink
[Kernel] add kernel for FATReLU (vllm-project#9610)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: qishuai <[email protected]>
  • Loading branch information
jeejeelee authored and FerdinandZhong committed Oct 29, 2024
1 parent e6c92d9 commit 9f4c844
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 8 deletions.
42 changes: 42 additions & 0 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]

namespace vllm {

template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
const float f = (float)x;
return (T)(f > threshold ? f : 0.0f);
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
}
}

} // namespace vllm

#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});

void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
namespace vllm {

// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
Expand Down
3 changes: 3 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
double threshold);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
Expand Down
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

// FATReLU implementation.
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
Expand Down
23 changes: 16 additions & 7 deletions tests/kernels/test_activation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import random
from typing import Type

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, NewGELU,
QuickGELU, SiluAndMul)
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol
Expand All @@ -20,7 +21,8 @@
]


@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("activation",
["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
Expand All @@ -47,16 +49,23 @@ def test_act_and_mul(
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
fn = torch.ops._C.gelu_tanh_and_mul
elif activation == "fatrelu":
threshold = random.uniform(0, 1)
layer = FatreluAndMul(threshold)
fn = torch.ops._C.fatrelu_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
if activation == "fatrelu":
opcheck(fn, (out, x, threshold))
else:
opcheck(fn, (out, x))


@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
Expand Down
6 changes: 6 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)


def fatrelu_and_mul(out: torch.Tensor,
x: torch.Tensor,
threshold: float = 0.0) -> None:
torch.ops._C.fatrelu_and_mul(out, x, threshold)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_fast(out, x)

Expand Down
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return x1 * x2

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.fatrelu_and_mul(out, x, self.threshold)
return out


@CustomOp.register("silu_and_mul")
Expand Down

0 comments on commit 9f4c844

Please sign in to comment.