[Kernel] add kernel for FATReLU #9610

Merged: 5 commits, Oct 24, 2024
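For orientation (not part of the diff): FATReLU keeps an activation only if it is strictly greater than a configurable threshold and zeroes it otherwise, and the fused `fatrelu_and_mul` op added here applies that gate to the first half of the last dimension and multiplies the result by the second half. A minimal, unfused PyTorch sketch of those semantics (the helper name `fatrelu_and_mul_ref` is illustrative only, not something defined in this PR):

import torch

def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Split [..., 2 * d] into the gate half and the value half.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # FATReLU: keep values strictly above the threshold, zero everything else.
    x1 = torch.where(x1 > threshold, x1, torch.zeros_like(x1))
    return x1 * x2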
42 changes: 42 additions & 0 deletions csrc/activation_kernels.cu
@@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]

namespace vllm {

template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float f = (float)x;
  return (T)(f > threshold ? f : 0.0f);
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x, param) * y;
  }
}

}  // namespace vllm

#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)         \
  int d = input.size(-1) / 2;                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                  \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
  VLLM_DISPATCH_FLOATING_TYPES(                                         \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {       \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),      \
                                         input.data_ptr<scalar_t>(), d, \
                                         PARAM);                        \
      });

void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                     torch::Tensor& input,  // [..., 2 * d]
                     double threshold) {
  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
namespace vllm {

// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
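As a side note (not stated in the PR itself), the macro above launches one CUDA block per token, with up to 1024 threads striding over the d output elements. A small Python sketch of that launch math, using an arbitrary example shape:

import torch

# Mirrors the launch math in LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM; the shape is made up.
inp = torch.empty(2, 16, 2 * 4096)
d = inp.size(-1) // 2                         # 4096 outputs per token
num_tokens = inp.numel() // inp.size(-1)      # 32 tokens -> grid of 32 blocks
grid, block = (num_tokens,), (min(d, 1024),)  # threads stride over d when d > 1024
print(grid, block)                            # (32,) (1024,)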
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -48,6 +48,9 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -60,6 +60,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

// FATReLU implementation.
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
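Once the extension is rebuilt, the newly registered op can be called directly through its schema. A rough usage sketch with made-up shapes and threshold, assuming the compiled `_C` library has been loaded (importing vllm does this):

import torch
import vllm  # noqa: F401  (loads the compiled _C ops)

d = 512
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(4, d, dtype=x.dtype, device=x.device)
torch.ops._C.fatrelu_and_mul(out, x, 0.5)  # threshold is passed as a Python float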
23 changes: 16 additions & 7 deletions tests/kernels/test_activation.py
@@ -1,12 +1,13 @@
+import random
from typing import Type

import pytest
import torch

from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
-                                                    NewGELU, QuickGELU,
-                                                    SiluAndMul)
+from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
+                                                    GeluAndMul, NewGELU,
+                                                    QuickGELU, SiluAndMul)
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol
@@ -20,7 +21,8 @@
]


@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("activation",
["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -47,16 +49,23 @@ def test_act_and_mul(
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
fn = torch.ops._C.gelu_tanh_and_mul
elif activation == "fatrelu":
threshold = random.uniform(0, 1)
layer = FatreluAndMul(threshold)
fn = torch.ops._C.fatrelu_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
if activation == "fatrelu":
opcheck(fn, (out, x, threshold))
else:
opcheck(fn, (out, x))


@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
6 changes: 6 additions & 0 deletions vllm/_custom_ops.py
@@ -79,6 +79,12 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def fatrelu_and_mul(out: torch.Tensor,
                    x: torch.Tensor,
                    threshold: float = 0.0) -> None:
    torch.ops._C.fatrelu_and_mul(out, x, threshold)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)

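Worth noting (an observation, not something stated in the diff): because the kernel uses a strict comparison, the wrapper's default threshold of 0.0 reduces FATReLU to a plain ReLU gate. A hedged sketch with arbitrary shapes:

import torch

from vllm import _custom_ops as ops

x = torch.randn(4, 2 * 256, dtype=torch.float16, device="cuda")
out = torch.empty(4, 256, dtype=x.dtype, device=x.device)
ops.fatrelu_and_mul(out, x)  # threshold defaults to 0.0

d = x.shape[-1] // 2
expected = torch.relu(x[..., :d]) * x[..., d:]  # ReLU-gated product
torch.testing.assert_close(out, expected)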
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -39,7 +39,13 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        return self.forward_native(x)
+        from vllm import _custom_ops as ops
+
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.fatrelu_and_mul(out, x, self.threshold)
+        return out


@CustomOp.register("silu_and_mul")
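Finally, a rough end-to-end sketch of using the layer once this lands (threshold value and shapes are made up): on CUDA devices forward_cuda now dispatches to the fused kernel, while forward_native stays the unfused reference that the test compares against.

import torch

from vllm.model_executor.layers.activation import FatreluAndMul

layer = FatreluAndMul(threshold=0.5)
x = torch.randn(8, 2 * 1024, dtype=torch.float16, device="cuda")
out = layer(x)                 # fused CUDA path via ops.fatrelu_and_mul
ref = layer.forward_native(x)  # unfused PyTorch reference
torch.testing.assert_close(out, ref, atol=0.0, rtol=0.0)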