diff --git a/CMakeLists.txt b/CMakeLists.txt
index b2d0cf3e568b7..4a99985d9abc4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
+  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
   "csrc/pybind.cpp")
diff --git a/csrc/ops.h b/csrc/ops.h
index a379c910d9cf3..ff7a3de1a0a8c 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -146,6 +146,11 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
+void scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
 void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 42e92e5382e8e..a5b16c5abc3ed 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,
diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu
new file mode 100644
index 0000000000000..c3337cede1282
--- /dev/null
+++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -0,0 +1,103 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cmath>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+namespace vllm {
+
+__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+  float old;
+  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
+    __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+
+  return old;
+}
+
+// Compute the absolute maximum m of the input tensor and store
+// m / float8_e4m3::max() in *scale. Each thread block performs a
+// reduction tree and the memory in scale is atomically updated.
+// So to get the right answer, *scale needs to be initialized to
+// a value <= 0.0 and we need to wait for all thread blocks to
+// finish before consuming *scale.
+template<typename scalar_t>
+__global__ void segmented_max_reduction(
+  float* __restrict__ scale,
+  const scalar_t* __restrict__ input,
+  int64_t num_elems) {
+  __shared__ float cache[1024];
+  int i = blockDim.x * blockIdx.x + threadIdx.x;
+
+  // First store maximum for all values processed by
+  // the current thread in cache[threadIdx.x]
+  scalar_t tmp = 0.0;
+  while (i < num_elems) {
+    float x = static_cast<float>(input[i]);
+    tmp = max(tmp, fabs(x));
+    i += blockDim.x * gridDim.x;
+  }
+  cache[threadIdx.x] = tmp;
+
+  __syncthreads();
+
+  // Now perform parallel reduction within the thread block
+  int ib = blockDim.x / 2;
+  while (ib != 0) {
+    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
+      cache[threadIdx.x] = cache[threadIdx.x + ib];
+    }
+    __syncthreads();
+    ib /= 2;
+  }
+  // Finally, since cache[0] contains the maximum for this thread block,
+  // atomically write the max to the target location
+  if (threadIdx.x == 0) {
+    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+  }
+}
+
+template<typename scalar_t>
+__global__ void scaled_fp8_quant_kernel(
+  c10::Float8_e4m3fn* __restrict__ out,
+  const scalar_t* __restrict__ input,
+  const float* __restrict__ scale,
+  int64_t num_elems) {
+  int i = blockDim.x * blockIdx.x + threadIdx.x;
+  while (i < num_elems) {
+    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
+    i += blockDim.x * gridDim.x;
+  }
+}
+
+} // namespace vllm
+
+void scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scale)    // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  dim3 grid(num_tokens);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
+        scale.data_ptr<float>(),
+        input.data_ptr<scalar_t>(),
+        num_elems);
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+      });
+}
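For reference, an eager-mode PyTorch sketch of what the two kernels jointly compute: a single dynamic per-tensor scale derived from the absolute maximum, followed by an element-wise divide and cast to FP8 E4M3. This is illustrative only and not part of the diff; the function name is hypothetical.

```python
from typing import Tuple

import torch


def scaled_fp8_quant_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # segmented_max_reduction: scale = absmax(x) / max finite value of float8_e4m3fn.
    scale = x.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
    # scaled_fp8_quant_kernel: divide every element by the shared scale and cast.
    out = (x.to(torch.float32) / scale).to(torch.float8_e4m3fn)
    return out, scale.reshape(1)
```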
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a0837a20875fe..e4b16ed918d1a 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import torch
 
@@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                 size_n, size_k)
 
 
+# fp8
+def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    vllm_ops.scaled_fp8_quant(output, input, scale)
+    return output, scale
+
+
 # moe
 def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
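A quick usage sketch for the new wrapper (illustrative, not part of the diff; it assumes a CUDA device, a PyTorch version that provides torch.float8_e4m3fn, and a vLLM build that includes the kernel above):

```python
import torch

from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
x_fp8, scale = ops.scaled_fp8_quant(x)
assert x_fp8.dtype == torch.float8_e4m3fn and scale.numel() == 1
# Dequantizing with the returned per-tensor scale approximately recovers the input.
x_deq = x_fp8.to(torch.float32) * scale
print((x.float() - x_deq).abs().max())
```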
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 0000000000000..2ad07bf79a25c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
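The file name encodes how the tuned configuration is looked up at runtime; with the dtype selector added to get_config_file_name in the fused_moe.py changes below, the lookup for this shape resolves as shown in this illustrative snippet (not part of the diff):

```python
import torch

E, N = 8, 7168  # e.g. Mixtral: 8 experts, intermediate size 14336 sharded across 2 GPUs
device_name = torch.cuda.get_device_name().replace(" ", "_")
dtype_selector = ",dtype=float8"  # empty string on the unquantized path
print(f"E={E},N={N},device_name={device_name}{dtype_selector}.json")
# On an H100 80GB this prints the file added above:
# E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
```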
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 377b6588dbf47..ac7c30e2a9727 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -21,6 +21,8 @@ def fused_moe_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
     topk_weights_ptr,
     sorted_token_ids_ptr,
     expert_ids_ptr,
@@ -49,6 +51,7 @@ def fused_moe_kernel(
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
+    use_fp8: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -111,6 +114,10 @@ def fused_moe_kernel(
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
 
+    if use_fp8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -129,7 +136,10 @@ def fused_moe_kernel(
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        if use_fp8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -140,7 +150,10 @@ def fused_moe_kernel(
                              other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(compute_type)
+    if use_fp8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -207,15 +220,24 @@ def moe_align_block_size(
 
 
 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
+                            topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any]) -> None:
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    if not use_fp8:
+        A_scale = None
+        assert B_scale is None
+    else:
+        A, A_scale = ops.scaled_fp8_quant(A)
+        assert B_scale is not None
+
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A,
         B,
         C,
+        A_scale,
+        B_scale,
         topk_weights,
         sorted_token_ids,
         expert_ids,
@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         C.stride(2),
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
-        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
         **config,
     )
 
 
-def get_config_file_name(E: int, N: int) -> str:
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
-    return f"E={E},N={N},device_name={device_name}.json"
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
 
     # First look up if an optimized configuration is available in the configs
    # directory
-    json_file_name = get_config_file_name(E, N)
+    json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -288,6 +315,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -305,6 +335,12 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -358,7 +394,8 @@ def fused_moe(
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2])
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -394,17 +431,37 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
 
-    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, False,
-                            topk_ids.shape[1], config)
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, True, 1,
-                            config)
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     if inplace:
         return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
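The FP8 path rescales the accumulator with `a_scale * b_scale` because quantization divided both operands by those scales. A tiny eager-mode check of that identity, ignoring FP8 rounding error (illustrative only, not part of the diff):

```python
import torch

torch.manual_seed(0)
a = torch.randn(16, 64)
b = torch.randn(64, 32)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
s_a = a.abs().max() / fp8_max  # per-tensor activation scale (a_scale)
s_b = b.abs().max() / fp8_max  # per-tensor weight scale (b_scale)
ref = a @ b
rescaled = ((a / s_a) @ (b / s_b)) * s_a * s_b
assert torch.allclose(ref, rescaled, atol=1e-4)
```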
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 64cd186506bdb..f75c35a69d4a9 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -232,6 +232,8 @@ def load_model(self, *, model_config: ModelConfig,
                 linear_method = getattr(module, "linear_method", None)
                 if linear_method is not None:
                     linear_method.process_weights_after_loading(module)
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
         return model.eval()
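The new loader hook is duck-typed: any submodule that defines a parameterless process_weights_after_loading() has it invoked once the checkpoint is fully loaded. A minimal sketch of a module opting into the hook (hypothetical class, not from this diff):

```python
import torch
from torch import nn


class PackedExpertWeights(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(8, 16), requires_grad=False)

    def process_weights_after_loading(self):
        # Runs once after all weights are loaded, e.g. to cast them to FP8
        # the way MixtralMoE does below.
        self.w = nn.Parameter(self.w.to(torch.float8_e4m3fn),
                              requires_grad=False)
```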
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index a0a3b2784614d..f7e0f56c1a46e 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -24,6 +24,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
+            and model_config.quantization != "fp8"
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 4d1755f2bbe63..a33b795d7088e 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -39,6 +39,8 @@
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
+                                                         per_tensor_quantize)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -47,6 +49,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
+from vllm.utils import print_warning_once
 
 
 class MixtralMoE(nn.Module):
@@ -66,6 +69,7 @@ def __init__(
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
         tp_size: Optional[int] = None,
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -73,6 +77,9 @@ def __init__(
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -97,6 +104,16 @@ def __init__(
                         device="cuda",
                         dtype=self.params_dtype))
 
+        # Scaling factors for FP8 weights
+        self.ws_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+        self.w2s_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,
         })
@@ -118,6 +135,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
+    def process_weights_after_loading(self):
+        if self.use_fp8:
+            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
+            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+                    self.ws.data[expert, :, :])
+                w2s[expert, :, :], self.w2s_scale[
+                    expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+            self.ws = nn.Parameter(ws, requires_grad=False)
+            self.w2s = nn.Parameter(w2s, requires_grad=False)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
@@ -129,7 +158,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
-                                        inplace=True)
+                                        inplace=True,
+                                        use_fp8=self.use_fp8,
+                                        w1_scale=self.ws_scale,
+                                        w2_scale=self.w2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -171,6 +203,13 @@ def __init__(self,
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
+        if isinstance(linear_method, Fp8LinearMethod):
+            print_warning_once(
+                "For Mixtral FP8 quantization, we currently do not quantize "
+                "the attention layers until their FP8 performance is improved."
+            )
+            linear_method = None
+
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@@ -238,7 +277,8 @@ def __init__(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size)
+            intermediate_size=config.intermediate_size,
+            linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,