forked from opendatahub-io/vllm
Commit
[Kernel] FP8 support for MoE kernel / Mixtral (vllm-project#4244)
This PR is the first step towards fixing vllm-project#3208. It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset, and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument.

You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1",
          tensor_parallel_size=2,
          quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance); there are a bunch of optimizations that we will submit as a follow-up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows:

(Performance results screenshot: "Screenshot 2024-04-21 at 1 31 50 PM", https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03)

**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```

This compares favorably with the fp16 results, which are:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
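To make the scheme concrete, here is a minimal PyTorch sketch of what dynamic per-tensor scaling computes. It is an illustration only, not the vLLM code path, and it assumes a PyTorch build that exposes `torch.float8_e4m3fn`:

```python
import torch

def dynamic_fp8_quant(x: torch.Tensor):
    # Derive the scale from the tensor itself (no calibration dataset):
    # scale = amax(|x|) / fp8_e4m3_max, then quantize q = x / scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scale = x.abs().max().float() / fp8_max
    q = (x / scale).to(torch.float8_e4m3fn)
    return q, scale

x = torch.randn(4, 8, dtype=torch.float16)
q, scale = dynamic_fp8_quant(x)
# Dequantize to check the round trip: x is approximately q.float() * scale.
print((q.float() * scale - x.float()).abs().max())
```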
Showing 10 changed files with 385 additions and 21 deletions.
New file, 103 additions:

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                       __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
  float* __restrict__ scale,
  const scalar_t* __restrict__ input,
  int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store maximum for all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
  c10::Float8_e4m3fn* __restrict__ out,
  const scalar_t* __restrict__ input,
  const float* __restrict__ scale,
  int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace vllm

void scaled_fp8_quant(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input,    // [..., d]
  torch::Tensor& scale)    // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "scaled_fp8_quant_kernel",
    [&] {
      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
        scale.data_ptr<float>(),
        input.data_ptr<scalar_t>(),
        num_elems);
      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<c10::Float8_e4m3fn>(),
        input.data_ptr<scalar_t>(),
        scale.data_ptr<float>(),
        num_elems);
    });
}
```
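The comment on `segmented_max_reduction` implies a calling convention: the one-element `scale` tensor has to start at a value <= 0.0 (typically 0.0) before launch, because the kernel only ever raises it via `atomicMaxFloat`. Below is a hedged sketch of the host-side buffer preparation; the Python binding name is assumed for illustration and is not part of this diff:

```python
import torch

x = torch.randn(16, 7168, dtype=torch.float16, device="cuda")

# FP8 output buffer plus a single-element, zero-initialized scale.
# Starting at 0.0 satisfies the "<= 0.0" precondition, so the
# atomicMax-based reduction ends up holding amax(|x|) / 448.
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scale = torch.zeros(1, dtype=torch.float32, device="cuda")

# ops.scaled_fp8_quant(out, x, scale)  # hypothetical binding name

# Eager-mode reference for comparing against the kernel's result:
ref_scale = x.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
ref_out = (x / ref_scale).to(torch.float8_e4m3fn)
```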
...r/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json (146 changes: 146 additions & 0 deletions)
```json
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 4
    }
}
```
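For context on how a tuned-config file like this is consumed: the top-level keys are benchmarked batch sizes (numbers of tokens routed to the MoE layer), and at runtime the entry closest to the actual batch size supplies the Triton launch parameters. The lookup below is an illustrative sketch only; vLLM's actual selection logic in `fused_moe` may differ in detail:

```python
import json

def load_moe_config(path: str) -> dict:
    # Keys in the JSON are batch sizes stored as strings; convert to int.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Use the tuned parameters benchmarked at the closest batch size.
    closest = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[closest]

# Hypothetical usage (the path is assumed for illustration):
# cfg = load_moe_config("E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json")
# params = pick_config(cfg, num_tokens=40)
# params["BLOCK_SIZE_M"], params["num_warps"], params["num_stages"]
```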