
[Kernel] FP8 support for MoE kernel / Mixtral #4244

Merged 53 commits on Apr 24, 2024. Changes shown below are from all commits.

Commits (53)
ab7963e  add initial single block kernel (pcmoritz, Apr 17, 2024)
45225aa  update (pcmoritz, Apr 17, 2024)
69b52cc  use blocks (pcmoritz, Apr 17, 2024)
dd6f680  fix (pcmoritz, Apr 17, 2024)
cb89c0f  update (pcmoritz, Apr 17, 2024)
4351703  port fp8 code (pcmoritz, Apr 17, 2024)
c303674  config (pcmoritz, Apr 17, 2024)
267f856  update (pcmoritz, Apr 17, 2024)
d85fb1a  custom ops (pcmoritz, Apr 17, 2024)
96e3f8b  update (pcmoritz, Apr 17, 2024)
0690411  update (pcmoritz, Apr 18, 2024)
130899b  fix initialization (pcmoritz, Apr 18, 2024)
0a10737  add fp8_silu_and_mul_kernel (pcmoritz, Apr 18, 2024)
ab9fec4  update (pcmoritz, Apr 18, 2024)
10a5697  fix (pcmoritz, Apr 18, 2024)
c89d2a8  fix (pcmoritz, Apr 18, 2024)
9435467  convert in kernel (pcmoritz, Apr 18, 2024)
609f493  cleanup (pcmoritz, Apr 19, 2024)
d790697  conversion (pcmoritz, Apr 19, 2024)
400a7e1  update (pcmoritz, Apr 20, 2024)
4b2c8f4  update (pcmoritz, Apr 20, 2024)
cc2a488  update (pcmoritz, Apr 20, 2024)
dc6add9  update (pcmoritz, Apr 20, 2024)
0af9edc  update (pcmoritz, Apr 20, 2024)
f2a934d  update (pcmoritz, Apr 20, 2024)
ce663ec  update (pcmoritz, Apr 20, 2024)
4047a93  Merge branch 'main' into mixtral-fp8-final (pcmoritz, Apr 20, 2024)
77bdc3e  update (pcmoritz, Apr 20, 2024)
bb123dd  Use MoE for fp8 quant (pcmoritz, Apr 20, 2024)
d212d2d  fix (pcmoritz, Apr 20, 2024)
88c02ea  clean up (pcmoritz, Apr 20, 2024)
11e1f01  fix (pcmoritz, Apr 20, 2024)
5fa1dcf  update (pcmoritz, Apr 21, 2024)
a0e4003  update (pcmoritz, Apr 21, 2024)
c0bfdba  format (pcmoritz, Apr 21, 2024)
7c4ee35  spelling (pcmoritz, Apr 21, 2024)
d4ea8b7  Update vllm/model_executor/layers/fused_moe/fused_moe.py (pcmoritz, Apr 22, 2024)
57235c5  Update vllm/model_executor/models/mixtral.py (pcmoritz, Apr 22, 2024)
188314d  update (pcmoritz, Apr 22, 2024)
d20e5e9  add fixme (pcmoritz, Apr 22, 2024)
aedd33d  update (pcmoritz, Apr 22, 2024)
4aa77c9  keep fused_moe interface (pcmoritz, Apr 22, 2024)
69ad2dc  typo (pcmoritz, Apr 22, 2024)
bae81d3  fix loading config file (pcmoritz, Apr 22, 2024)
b733cea  update (pcmoritz, Apr 23, 2024)
d53b1fc  update (pcmoritz, Apr 23, 2024)
5ef2ee9  update (pcmoritz, Apr 23, 2024)
8807300  fix (pcmoritz, Apr 23, 2024)
a15a7b5  update (pcmoritz, Apr 23, 2024)
8fd40c1  format (pcmoritz, Apr 23, 2024)
0f93811  align (pcmoritz, Apr 23, 2024)
fbbfc61  rerun ci (pcmoritz, Apr 23, 2024)
725270e  rerun ci (pcmoritz, Apr 23, 2024)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
  "csrc/layernorm_kernels.cu"
  "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/moe_align_block_size_kernels.cu"
  "csrc/pybind.cpp")
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -131,6 +131,11 @@ void gptq_shuffle(
  torch::Tensor q_perm,
  int bit);

void scaled_fp8_quant(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& scale);

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
1 change: 1 addition & 0 deletions csrc/pybind.cpp
@@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
  ops.def(
    "moe_align_block_size",
    &moe_align_block_size,
103 changes: 103 additions & 0 deletions csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -0,0 +1,103 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// CUDA has no atomicMax overload for float. For non-negative IEEE-754 floats,
// the signed-integer bit pattern orders the same way as the float value, so an
// integer atomicMax works; negative floats order in reverse under the unsigned
// view, so atomicMin on the unsigned view is used instead.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
    __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
  float* __restrict__ scale,
  const scalar_t* __restrict__ input,
  int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform a tree reduction within the thread block
  // (assumes blockDim.x is a power of two; the launch below uses 1024)
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
  c10::Float8_e4m3fn* __restrict__ out,
  const scalar_t* __restrict__ input,
  const float* __restrict__ scale,
  int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void scaled_fp8_quant(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input,  // [..., d]
  torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "scaled_fp8_quant_kernel",
    [&] {
      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
        scale.data_ptr<float>(),
        input.data_ptr<scalar_t>(),
        num_elems);
      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<c10::Float8_e4m3fn>(),
        input.data_ptr<scalar_t>(),
        scale.data_ptr<float>(),
        num_elems);
    });
}
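
Together, the two kernels implement dynamic per-tensor FP8 quantization: a single scale is reduced from the global absolute maximum, then every element is divided by that scale and cast to FP8. For example, if max|x| is 8.96, the scale is 8.96 / 448 = 0.02, and an element with value 4.48 is stored as the FP8 value 224. Below is a minimal PyTorch sketch of the same math, intended as a reference for checking the kernels rather than as the CUDA path (it assumes a non-zero input so the scale is positive):

import torch

def scaled_fp8_quant_reference(x: torch.Tensor):
    # Per-tensor dynamic scale: max|x| / Float8_e4m3fn::max() (= 448.0),
    # the value segmented_max_reduction accumulates into *scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().max().float() / fp8_max
    # Elementwise quantization, as in scaled_fp8_quant_kernel.
    out = (x.float() / scale).to(torch.float8_e4m3fn)
    return out, scale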

10 changes: 9 additions & 1 deletion vllm/_custom_ops.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple

import torch

@@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)


# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # The scale must start at a value <= 0.0: segmented_max_reduction
    # updates it with an atomic max rather than writing it outright.
    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    vllm_ops.scaled_fp8_quant(output, input, scale)
    return output, scale


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
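A hedged usage sketch of the new wrapper (assuming a CUDA build of vLLM with this extension compiled; the round-trip check at the end is illustrative only):

import torch
from vllm import _custom_ops as ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
x_fp8, scale = ops.scaled_fp8_quant(x)
# Dequantize to estimate the quantization error: x is approximately x_fp8 * scale.
max_err = (x.float() - x_fp8.float() * scale).abs().max()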
146 changes: 146 additions & 0 deletions (new fused MoE kernel tuning configuration; file path not shown)
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}
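
The keys of this tuning table are token counts, and each value holds Triton launch parameters (tile sizes, group size, warps, pipeline stages) tuned for that batch size. A minimal sketch of how such a table can be consumed, assuming, as the fused MoE kernel does, that the entry nearest the actual token count is selected; choose_moe_config is a hypothetical helper, not part of this PR:

import json

def choose_moe_config(config_path: str, num_tokens: int) -> dict:
    # Keys are tuned token counts; values are Triton launch parameters
    # (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages).
    with open(config_path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    # Use the tuned entry whose token count is closest to the actual batch.
    nearest = min(table, key=lambda m: abs(m - num_tokens))
    return table[nearest]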