From ab7963eb2b0eb3d1f3fe2fca92eb97923edf358f Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Wed, 17 Apr 2024 13:40:09 -0700
Subject: [PATCH 01/52] add initial single block kernel

---
 csrc/quantization/fp8/fp8_cuda_kernels.cu | 72 +++++++++++++++++++++++
 1 file changed, 72 insertions(+)
 create mode 100644 csrc/quantization/fp8/fp8_cuda_kernels.cu

diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu
new file mode 100644
index 0000000000000..cff5084042657
--- /dev/null
+++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -0,0 +1,72 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cmath>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+namespace vllm {
+
+template<typename scalar_t>
+__global__ void scaled_fp8_quant_kernel(
+  c10::Float8_e4m3fn* __restrict__ out,
+  const scalar_t* __restrict__ input,
+  const float* __restrict__ scale,
+  int64_t num_elems) {
+  __shared__ float cache[1024];
+  int i = blockDim.x * blockIdx.x + threadIdx.x;
+  int cacheIndex = threadIdx.x;
+
+  scalar_t tmp = 0.0;
+  while (i < num_elems) {
+    float x = static_cast<float>(input[i]);
+    tmp = max(tmp, fabs(x));
+    i += blockDim.x * gridDim.x;
+  }
+
+  cache[cacheIndex] = tmp;
+
+  __syncthreads();
+
+  // perform parallel reduction
+  int ib = blockDim.x / 2;
+  while (ib != 0) {
+    if (cacheIndex < ib && cache[cacheIndex + ib] > cache[cacheIndex]) {
+      cache[cacheIndex] = cache[cacheIndex + ib];
+    }
+    __syncthreads();
+    ib /= 2;
+  }
+  // now cache[0] contains the maximum, rescale the numbers
+  i = blockDim.x * blockIdx.x + threadIdx.x;
+  while (i < num_elems) {
+    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / cache[0]);
+    i += blockDim.x * gridDim.x;
+  }
+}
+
+} // namespace vllm
+
+void scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scales)   // [d]
+{
+  int64_t num_elems = input.numel();
+  dim3 grid(1);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES_FP8(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scales.data_ptr<float>(),
+        num_elems);
+    });
+}
\ No newline at end of file

From 45225aafc7903eceaa442ada44e9d5729ff94fb8 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Wed, 17 Apr 2024 13:43:33 -0700
Subject: [PATCH 02/52] update

---
 CMakeLists.txt                            | 1 +
 csrc/ops.h                                | 5 +++++
 csrc/pybind.cpp                           | 1 +
 csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +-
 4 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1845151181284..6ea9c5a829601 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
   "csrc/layernorm_kernels.cu"
   "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
   "csrc/quantization/gptq/q_gemm.cu"
+  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
   "csrc/pybind.cpp")
diff --git a/csrc/ops.h b/csrc/ops.h
index 41ecc1e89371b..3446318571528 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -131,6 +131,11 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
+void scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scales);
+
 void moe_align_block_size(
   torch::Tensor topk_ids,
   int num_experts,
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index de02afc162113..8c083c9e60fe5 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Scale tensor and quantize to FP8"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index cff5084042657..d489fe5e7fc4a 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -59,7 +59,7 @@ void scaled_fp8_quant( dim3 block(1024); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES_FP8( + VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "scaled_fp8_quant_kernel", [&] { From 69b52cc2a48f8489b839046c6dbfa70f39206765 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 14:07:48 -0700 Subject: [PATCH 03/52] use blocks --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 36 +++++++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index d489fe5e7fc4a..c191f08738bb7 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -9,9 +9,16 @@ namespace vllm { +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : + __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + template -__global__ void scaled_fp8_quant_kernel( - c10::Float8_e4m3fn* __restrict__ out, +__global__ void segmented_max_reduction( const scalar_t* __restrict__ input, const float* __restrict__ scale, int64_t num_elems) { @@ -39,10 +46,22 @@ __global__ void scaled_fp8_quant_kernel( __syncthreads(); ib /= 2; } - // now cache[0] contains the maximum, rescale the numbers - i = blockDim.x * blockIdx.x + threadIdx.x; + // now cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (cacheIndex == 0) { + atomicMaxFloat(scale, cache[0]) + } +} + +template +__global__ void scaled_fp8_quant_kernel( + c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { - out[i] = static_cast(input[i] / cache[0]); + out[i] = static_cast(input[i] / *scale); i += blockDim.x * gridDim.x; } } @@ -54,8 +73,9 @@ void scaled_fp8_quant( torch::Tensor& input, // [..., d] torch::Tensor& scales) // [d] { + int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); - dim3 grid(1); + dim3 grid(num_tokens); dim3 block(1024); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -63,6 +83,10 @@ void scaled_fp8_quant( input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::segmented_max_reduction<<>>( + input.data_ptr(), + scales.data_ptr(), + num_elems); vllm::scaled_fp8_quant_kernel<<>>( out.data_ptr(), input.data_ptr(), From dd6f680c4ccd835d2c20d0a30513ec8a173f9f6d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 14:14:39 -0700 Subject: [PATCH 04/52] fix --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 
2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index c191f08738bb7..502eba44c2254 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -49,7 +49,7 @@ __global__ void segmented_max_reduction( // now cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (cacheIndex == 0) { - atomicMaxFloat(scale, cache[0]) + atomicMaxFloat(scale, cache[0]); } } From cb89c0fd71536cd4581c0237307dedaec3202c19 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 14:22:41 -0700 Subject: [PATCH 05/52] update --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 502eba44c2254..eaaac496251e2 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -19,8 +19,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { template __global__ void segmented_max_reduction( + float* __restrict__ scale, const scalar_t* __restrict__ input, - const float* __restrict__ scale, int64_t num_elems) { __shared__ float cache[1024]; int i = blockDim.x * blockIdx.x + threadIdx.x; @@ -84,8 +84,8 @@ void scaled_fp8_quant( "scaled_fp8_quant_kernel", [&] { vllm::segmented_max_reduction<<>>( - input.data_ptr(), scales.data_ptr(), + input.data_ptr(), num_elems); vllm::scaled_fp8_quant_kernel<<>>( out.data_ptr(), From 43517030d959fd8eb984aadbe368e2f4dd54249a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 15:03:02 -0700 Subject: [PATCH 06/52] port fp8 code --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +- .../layers/fused_moe/fused_moe.py | 39 +++++++++++++++---- vllm/model_executor/models/mixtral.py | 31 ++++++++++++++- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index eaaac496251e2..3b71facd4724c 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -71,7 +71,7 @@ __global__ void scaled_fp8_quant_kernel( void scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] - torch::Tensor& scales) // [d] + torch::Tensor& scale) // [d] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 377b6588dbf47..345d23e28df31 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,6 +21,8 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + w_scale_ptr, + a_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, @@ -111,6 +113,9 @@ def fused_moe_kernel( b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + w_scale = tl.load(w_scale_ptr + off_experts) + a_scale = tl.load(a_scale_ptr) + # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -129,7 +134,7 @@ def fused_moe_kernel( mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. 
- accumulator += tl.dot(a, b) + accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -140,7 +145,7 @@ def fused_moe_kernel( other=0) accumulator = accumulator * moe_weight[:, None] - accumulator = accumulator.to(compute_type) + accumulator = (accumulator * w_scale * a_scale).to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -207,12 +212,13 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + w_scale: torch.Tensor, a_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, - config: Dict[str, Any]) -> None: + config: Dict[str, Any], compute_type: tl.dtype) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -223,6 +229,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, B, C, + w_scale, + a_scale, topk_weights, sorted_token_ids, expert_ids, @@ -240,7 +248,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, C.stride(2), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, - compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + compute_type=compute_type, **config, ) @@ -283,6 +291,8 @@ def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, @@ -381,30 +391,45 @@ def fused_moe( 'GROUP_SIZE_M': 1 } + intermediate_cache0 = torch.empty(hidden_states.shape, + device=hidden_states.device, + dtype=torch.float8_e4m3fn) intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) + intermediate_cache2_scaled = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=torch.float8_e4m3fn) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) + a_scale = torch.empty(1, device=hidden_states.device, dtype=torch.float32) + a2_scale = torch.empty(1, device=hidden_states.device, dtype=torch.float32) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + ops.scaled_fp8_quant(intermediate_cache0, hidden_states, a_scale) + + invoke_fused_moe_kernel(intermediate_cache0, w1, intermediate_cache1, + w1_scale, a_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], config) + topk_ids.shape[1], config, compute_type=tl.float16) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + ops.scaled_fp8_quant(intermediate_cache2_scaled, intermediate_cache2, a2_scale) + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + w2_scale, a2_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, - config) + config, compute_type=tl.float16) if inplace: return 
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 4d1755f2bbe63..608eec02b2217 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -89,13 +89,19 @@ def __init__( 2 * self.intermediate_size, self.hidden_size, device="cuda", - dtype=self.params_dtype)) + dtype=dtype=torch.float8_e4m3fn)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, device="cuda", - dtype=self.params_dtype)) + dtype=dtype=torch.float8_e4m3fn)) + + # Scaling factors for fp8 weights + self.ws_scale = nn.Parameter( + torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32)) + self.w2s_scale = nn.Parameter( + torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, @@ -104,6 +110,13 @@ def __init__( "weight_loader": self.weight_loader, }) + set_weight_attrs(self.ws_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s_scale, { + "weight_loader": self.weight_loader, + }) + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() @@ -117,6 +130,10 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + # For loading weight scales + if "scales" in weight_name: + param_data[expert_id] = loaded_weight + print("loaded scale", weight_name, loaded_weight.shape) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -126,6 +143,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = fused_moe(hidden_states, self.ws, self.w2s, + self.ws_scale, + self.s2s_scale, router_logits, self.top_k, renormalize=True, @@ -403,11 +422,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weights for the experts # (param_name, weight_name, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("ws_scale" if weight_name in ["w1", "w3"] else "w2s_scale", + f"scales.{expert_id}.{weight_name}", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) From c303674ac534196c6ee47dbafab540929284c1cc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 15:04:19 -0700 Subject: [PATCH 07/52] config --- ...168,device_name=NVIDIA_H100_80GB_HBM3.json | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json index e341a67917d51..2ad07bf79a25c 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json @@ -1,81 +1,81 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 4 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "8": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 4 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 4 }, "48": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 4 }, "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, @@ -83,63 +83,63 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 4 }, "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 4 }, "512": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4 }, "2048": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4 }, "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, 
"num_warps": 8, "num_stages": 4 } From 267f856af806889920e9c531bd07cec76e10a3e9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 15:16:45 -0700 Subject: [PATCH 08/52] update --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 3b71facd4724c..fbf9aa24dcc61 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -84,13 +84,13 @@ void scaled_fp8_quant( "scaled_fp8_quant_kernel", [&] { vllm::segmented_max_reduction<<>>( - scales.data_ptr(), + scale.data_ptr(), input.data_ptr(), num_elems); vllm::scaled_fp8_quant_kernel<<>>( out.data_ptr(), input.data_ptr(), - scales.data_ptr(), + scale.data_ptr(), num_elems); }); } \ No newline at end of file From d85fb1aaf979edcf4541f9694ba069e729f0f1c6 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 15:29:09 -0700 Subject: [PATCH 09/52] custom ops --- vllm/_custom_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a0837a20875fe..4eae031dac5d1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -152,6 +152,9 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k) +# fp8 +def scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): + return vllm_ops.scaled_fp8_quant(out, input, scale) # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, From 96e3f8b56f9b99823a73341c337163f58efdb579 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 15:36:51 -0700 Subject: [PATCH 10/52] update --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/models/mixtral.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 345d23e28df31..45fb0ba43d3bf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -425,7 +425,7 @@ def fused_moe( ops.scaled_fp8_quant(intermediate_cache2_scaled, intermediate_cache2, a2_scale) - invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + invoke_fused_moe_kernel(intermediate_cache2_scaled, w2, intermediate_cache3, w2_scale, a2_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 608eec02b2217..196546ac1e098 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -89,13 +89,13 @@ def __init__( 2 * self.intermediate_size, self.hidden_size, device="cuda", - dtype=dtype=torch.float8_e4m3fn)) + dtype=torch.float8_e4m3fn)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, device="cuda", - dtype=dtype=torch.float8_e4m3fn)) + dtype=torch.float8_e4m3fn)) # Scaling factors for fp8 weights self.ws_scale = nn.Parameter( @@ -144,7 +144,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.ws, self.w2s, self.ws_scale, - self.s2s_scale, + self.w2s_scale, router_logits, self.top_k, renormalize=True, From 0690411f898eab38509d58c3dfed088bb7ed3efd Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 18:46:48 -0700 
Subject: [PATCH 11/52] update --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index fbf9aa24dcc61..9106bf5f94c01 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -49,7 +49,7 @@ __global__ void segmented_max_reduction( // now cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (cacheIndex == 0) { - atomicMaxFloat(scale, cache[0]); + atomicMaxFloat(scale, cache[0] / 448.0); } } From 130899b834d018104b577da8d9bf0f739024f74a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 20:13:40 -0700 Subject: [PATCH 12/52] fix initialization --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 45fb0ba43d3bf..0f98666442d37 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -407,8 +407,8 @@ def fused_moe( device=hidden_states.device, dtype=hidden_states.dtype) - a_scale = torch.empty(1, device=hidden_states.device, dtype=torch.float32) - a2_scale = torch.empty(1, device=hidden_states.device, dtype=torch.float32) + a_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) + a2_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) From 0a10737380bc988982c07dc809e24993be8402ca Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 21:09:44 -0700 Subject: [PATCH 13/52] add fp8_silu_and_mul_kernel --- csrc/ops.h | 7 +++- csrc/pybind.cpp | 3 +- csrc/quantization/fp8/fp8_cuda_kernels.cu | 48 +++++++++++++++++++++-- vllm/_custom_ops.py | 3 ++ 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 3446318571528..831ef7934cc76 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -134,7 +134,12 @@ void gptq_shuffle( void scaled_fp8_quant( torch::Tensor& out, torch::Tensor& input, - torch::Tensor& scales); + torch::Tensor& scale); + +void fp8_silu_and_mul_kernel( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); void moe_align_block_size( torch::Tensor topk_ids, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8c083c9e60fe5..5048aef96ce22 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -71,7 +71,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Scale tensor and quantize to FP8"); + ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); + ops.def("fp8_silu_and_mul_kernel", &fp8_silu_and_mul_kernel, "Compute FP8 silu_and_mul and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 9106bf5f94c01..ff5b22fe4fe54 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -49,7 +49,7 @@ __global__ void 
segmented_max_reduction( // now cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (cacheIndex == 0) { - atomicMaxFloat(scale, cache[0] / 448.0); + atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); } } @@ -66,12 +66,27 @@ __global__ void scaled_fp8_quant_kernel( } } +template +__global__ void fp8_silu_and_mul_kernel( + c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const float x = (float) input[token_idx * 2 * d + idx]; + const float y = (float) input[token_idx * 2 * d + d + idx]; + float r = silu_kernel(x) * y; + out[token_idx * d + idx] = static_cast(r / *scale); + } +} + } // namespace vllm void scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -93,4 +108,31 @@ void scaled_fp8_quant( scale.data_ptr(), num_elems); }); -} \ No newline at end of file +} + +void fp8_silu_and_mul_kernel( + torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& scale) // [1] +{ + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + out.scalar_type(), + "scaled_silu_and_mul_kernel", + [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), + input.data_ptr(), + input.numel()); + vllm::scaled_silu_and_mul_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + scale.data_ptr(), + d); + }); +} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4eae031dac5d1..9480be2fbbde5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -156,6 +156,9 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): return vllm_ops.scaled_fp8_quant(out, input, scale) +def fp8_silu_and_mul_kernel(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): + return vllm_ops.fp8_silu_and_mul_kernel(out, input, scale) + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, From ab9fec433350fa53ad3287dcf44d3a335127fc3f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 21:17:16 -0700 Subject: [PATCH 14/52] update --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index ff5b22fe4fe54..bee075a965694 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -66,6 +66,12 @@ __global__ void scaled_fp8_quant_kernel( } } +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T) (((float) x) / (1.0f + expf((float) -x))); +} + template __global__ void fp8_silu_and_mul_kernel( c10::Float8_e4m3fn* __restrict__ out, From 10a5697ca712f9304e2691abbd9d15c2cf0b9c2e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 21:23:55 -0700 Subject: [PATCH 15/52] fix --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index bee075a965694..9e4354d70833c 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -129,13 +129,13 @@ void fp8_silu_and_mul_kernel( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( out.scalar_type(), - "scaled_silu_and_mul_kernel", + "fp8_silu_and_mul_kernel_kernel", [&] { vllm::segmented_max_reduction<<>>( scale.data_ptr(), input.data_ptr(), input.numel()); - vllm::scaled_silu_and_mul_kernel<<>>( + vllm::fp8_silu_and_mul_kernel<<>>( out.data_ptr(), input.data_ptr(), scale.data_ptr(), From c89d2a83d175105b7e34d056cd47a6d87e465cf9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Apr 2024 21:33:32 -0700 Subject: [PATCH 16/52] fix --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 9e4354d70833c..903aa8d924096 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -128,7 +128,7 @@ void fp8_silu_and_mul_kernel( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - out.scalar_type(), + input.scalar_type(), "fp8_silu_and_mul_kernel_kernel", [&] { vllm::segmented_max_reduction<<>>( From 9435467b4a8f5bf6dafd844befeb398dbc856ed4 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 18 Apr 2024 12:31:06 -0700 Subject: [PATCH 17/52] convert in kernel --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0f98666442d37..a40e0457a9c60 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -133,8 +133,9 @@ def fused_moe_kernel( b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + scaled_a = (a / a_scale).to(tl.float8e4nv) # We accumulate along the K dimension. - accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) + accumulator = tl.dot(scaled_a, b, acc=accumulator, allow_tf32=True) # Advance the ptrs to the next K block. 
a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk From 609f493802bde2c40c1f73969aa0d51e2c2c98b7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 12:54:09 -0700 Subject: [PATCH 18/52] cleanup --- csrc/ops.h | 5 -- csrc/pybind.cpp | 1 - csrc/quantization/fp8/fp8_cuda_kernels.cu | 69 ++++--------------- vllm/_custom_ops.py | 3 - .../layers/fused_moe/fused_moe.py | 3 +- 5 files changed, 15 insertions(+), 66 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 831ef7934cc76..5bd272f7c3ea3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -136,11 +136,6 @@ void scaled_fp8_quant( torch::Tensor& input, torch::Tensor& scale); -void fp8_silu_and_mul_kernel( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 5048aef96ce22..9c0979e0d701e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -72,7 +72,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); - ops.def("fp8_silu_and_mul_kernel", &fp8_silu_and_mul_kernel, "Compute FP8 silu_and_mul and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 903aa8d924096..7518e319fd702 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -17,6 +17,12 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finishe before consuming *scale. 
template __global__ void segmented_max_reduction( float* __restrict__ scale, @@ -24,31 +30,31 @@ __global__ void segmented_max_reduction( int64_t num_elems) { __shared__ float cache[1024]; int i = blockDim.x * blockIdx.x + threadIdx.x; - int cacheIndex = threadIdx.x; + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] scalar_t tmp = 0.0; while (i < num_elems) { float x = static_cast(input[i]); tmp = max(tmp, fabs(x)); i += blockDim.x * gridDim.x; } - - cache[cacheIndex] = tmp; + cache[threadIdx.x] = tmp; __syncthreads(); - // perform parallel reduction + // Now perform parallel reduction within the thread block int ib = blockDim.x / 2; while (ib != 0) { - if (cacheIndex < ib && cache[cacheIndex + ib] > cache[cacheIndex]) { - cache[cacheIndex] = cache[cacheIndex + ib]; + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; } __syncthreads(); ib /= 2; } - // now cache[0] contains the maximum for this thread block, + // Finally, since cache[0] contains the maximum for this thread block, // atomically write the max to the target location - if (cacheIndex == 0) { + if (threadIdx.x == 0) { atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); } } @@ -66,27 +72,6 @@ __global__ void scaled_fp8_quant_kernel( } } -template -__device__ __forceinline__ T silu_kernel(const T& x) { - // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); -} - -template -__global__ void fp8_silu_and_mul_kernel( - c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - const int d) { - const int64_t token_idx = blockIdx.x; - for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const float x = (float) input[token_idx * 2 * d + idx]; - const float y = (float) input[token_idx * 2 * d + d + idx]; - float r = silu_kernel(x) * y; - out[token_idx * d + idx] = static_cast(r / *scale); - } -} - } // namespace vllm void scaled_fp8_quant( @@ -116,29 +101,3 @@ void scaled_fp8_quant( }); } -void fp8_silu_and_mul_kernel( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& scale) // [1] -{ - int d = input.size(-1) / 2; - int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - dim3 block(1024); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "fp8_silu_and_mul_kernel_kernel", - [&] { - vllm::segmented_max_reduction<<>>( - scale.data_ptr(), - input.data_ptr(), - input.numel()); - vllm::fp8_silu_and_mul_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - d); - }); -} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9480be2fbbde5..4eae031dac5d1 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -156,9 +156,6 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): return vllm_ops.scaled_fp8_quant(out, input, scale) -def fp8_silu_and_mul_kernel(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): - return vllm_ops.fp8_silu_and_mul_kernel(out, input, scale) - # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 
a40e0457a9c60..0f98666442d37 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -133,9 +133,8 @@ def fused_moe_kernel( b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - scaled_a = (a / a_scale).to(tl.float8e4nv) # We accumulate along the K dimension. - accumulator = tl.dot(scaled_a, b, acc=accumulator, allow_tf32=True) + accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk From d790697c7c9b2eb0af53e7f4e9a309983b533c4b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 14:56:11 -0700 Subject: [PATCH 19/52] conversion --- vllm/model_executor/models/mixtral.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 196546ac1e098..fd057dc154459 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -89,13 +89,13 @@ def __init__( 2 * self.intermediate_size, self.hidden_size, device="cuda", - dtype=torch.float8_e4m3fn)) + dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, device="cuda", - dtype=torch.float8_e4m3fn)) + dtype=self.params_dtype)) # Scaling factors for fp8 weights self.ws_scale = nn.Parameter( @@ -121,6 +121,12 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() param_data = param.data + # First we check if the parameters in the checkpoint have a different + # dtype than the native dtype of this model -- this is for example + # the case if we want to use FP8 for the MoE layer and FP16 for the + # rest of the model. If this happens, we convert the dtype. 
+ if param_data.dtype != loaded_weight.dtype: + param = param.to(loaded_weight.dtype) shard_size = self.intermediate_size shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) if weight_name.endswith("w1.weight"): From 400a7e133ac7a24ad5a3830839db5c7eb71b2423 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 17:38:57 -0700 Subject: [PATCH 20/52] update --- vllm/model_executor/models/mixtral.py | 36 ++++++++------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fd057dc154459..77c9db25a03f5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -66,6 +66,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, + use_fp8: bool = True, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -73,6 +74,7 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.use_fp8 = use_fp8 if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -110,23 +112,10 @@ def __init__( "weight_loader": self.weight_loader, }) - set_weight_attrs(self.ws_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s_scale, { - "weight_loader": self.weight_loader, - }) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() param_data = param.data - # First we check if the parameters in the checkpoint have a different - # dtype than the native dtype of this model -- this is for example - # the case if we want to use FP8 for the MoE layer and FP16 for the - # rest of the model. If this happens, we convert the dtype. 
- if param_data.dtype != loaded_weight.dtype: - param = param.to(loaded_weight.dtype) shard_size = self.intermediate_size shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) if weight_name.endswith("w1.weight"): @@ -136,10 +125,15 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] - # For loading weight scales - if "scales" in weight_name: - param_data[expert_id] = loaded_weight - print("loaded scale", weight_name, loaded_weight.shape) + + def process_weights_after_loading(self): + if self.use_fp8: + qws, ws_scale = per_tensor_quantize(self.ws.data) + self.ws = nn.Parameter(qws, requires_grad=False) + self.ws_scale.data.copy_(ws_scale) + qw2s, w2s_scale = per_tensor_quantize(self.w2s.data) + self.w2s = nn.Parameter(qw2s, requires_grad=False) + self.w2s_scale.data.copy_(w2s_scale) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -428,19 +422,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ - # These are the weights for the experts # (param_name, weight_name, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] - ] + [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("ws_scale" if weight_name in ["w1", "w3"] else "w2s_scale", - f"scales.{expert_id}.{weight_name}", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) From 4b2c8f424d36ecdaae7708d9c20ca4f68aa8e383 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 17:46:07 -0700 Subject: [PATCH 21/52] update --- vllm/model_executor/model_loader/loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 3b1d125ef8a67..dddddfac35021 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -228,6 +228,10 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) + + for _, module in model.named_modules(): + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() return model.eval() From cc2a488c1bc5a005e8b42dfba4f35c8f9fa1d8ec Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 17:51:10 -0700 Subject: [PATCH 22/52] update --- vllm/model_executor/models/mixtral.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 77c9db25a03f5..1f7154ed99944 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -49,6 +49,31 @@ from vllm.sequence import SamplerOutput +# Temporary until https://github.com/vllm-project/vllm/pull/4118 is merged + +def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: + """Quantize a tensor using per-tensor static scaling factor. + Args: + tensor: The input tensor. + """ + finfo = torch.finfo(torch.float8_e4m3fn) + # Calculate the scale as dtype max divided by absmax. 
+ # Since .abs() creates a new tensor, we use aminmax to get + # the min and max first and then calculate the absmax. + min_val, max_val = tensor.aminmax() + amax = min_val.abs().max(max_val.abs()) + scale = finfo.max / amax.clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(torch.float8_e4m3fn) + scale = scale.float().reciprocal() + return qweight, scale + + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. From dc6add9c2b6e8b4f09ad860e4e840559c2f7f213 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 18:07:09 -0700 Subject: [PATCH 23/52] update --- vllm/model_executor/models/mixtral.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1f7154ed99944..b7ff843ebf048 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -153,12 +153,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def process_weights_after_loading(self): if self.use_fp8: - qws, ws_scale = per_tensor_quantize(self.ws.data) - self.ws = nn.Parameter(qws, requires_grad=False) - self.ws_scale.data.copy_(ws_scale) - qw2s, w2s_scale = per_tensor_quantize(self.w2s.data) - self.w2s = nn.Parameter(qw2s, requires_grad=False) - self.w2s_scale.data.copy_(w2s_scale) + ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) + w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) + for expert in range(self.num_total_experts): + ws[expert,:,:], self.ws_scale[expert] = per_tensor_quantize(self.ws.data[expert,:,:]) + w2s[expert,:,:], self.w2s_scale[expert] = per_tensor_quantize(self.w2s.data[expert,:,:]) + self.ws = nn.Parameter(ws, requires_grad=False) + self.w2s = nn.Parameter(w2s, requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape From 0af9edc0026847b330b345aab77788056a5229a8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 18:12:11 -0700 Subject: [PATCH 24/52] update --- vllm/model_executor/models/mixtral.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b7ff843ebf048..d2818da9f10d7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -126,9 +126,11 @@ def __init__( # Scaling factors for fp8 weights self.ws_scale = nn.Parameter( - torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32)) + torch.ones(self.num_total_experts,device="cuda", dtype=torch.float32), + requires_grad=False) self.w2s_scale = nn.Parameter( - torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32)) + torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), + requires_grad=False) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, From f2a934d745713c97fcc9a37236cfcf36e06c6062 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 18:31:22 -0700 Subject: [PATCH 25/52] update --- .../layers/fused_moe/fused_moe.py | 25 +++++++++++++------ 
vllm/model_executor/models/mixtral.py | 3 ++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0f98666442d37..6faed5c8af822 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -51,6 +51,7 @@ def fused_moe_kernel( MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, + use_fp8: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -113,8 +114,9 @@ def fused_moe_kernel( b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - w_scale = tl.load(w_scale_ptr + off_experts) - a_scale = tl.load(a_scale_ptr) + if use_fp8: + w_scale = tl.load(w_scale_ptr + off_experts) + a_scale = tl.load(a_scale_ptr) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -134,7 +136,10 @@ def fused_moe_kernel( mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) + else: + accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -145,7 +150,10 @@ def fused_moe_kernel( other=0) accumulator = accumulator * moe_weight[:, None] - accumulator = (accumulator * w_scale * a_scale).to(compute_type) + if use_fp8: + accumulator = (accumulator * w_scale * a_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -218,7 +226,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype) -> None: + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -249,6 +258,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, + use_fp8=use_fp8, **config, ) @@ -298,6 +308,7 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -419,7 +430,7 @@ def fused_moe( w1_scale, a_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], config, compute_type=tl.float16) + topk_ids.shape[1], config, compute_type=tl.float16, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -429,7 +440,7 @@ def fused_moe( w2_scale, a2_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, - config, compute_type=tl.float16) + config, compute_type=tl.float16, use_fp8=use_fp8) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d2818da9f10d7..0d5cf677c4ec4 100644 --- 
a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -176,7 +176,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, self.top_k, renormalize=True, - inplace=True) + inplace=True, + use_fp8=self.use_fp8) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( From ce663ec11cfa37db503af70f56d57b964c46ac75 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 19 Apr 2024 19:18:25 -0700 Subject: [PATCH 26/52] update --- .../layers/fused_moe/fused_moe.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6faed5c8af822..0c5888250ce0f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -402,42 +402,42 @@ def fused_moe( 'GROUP_SIZE_M': 1 } - intermediate_cache0 = torch.empty(hidden_states.shape, - device=hidden_states.device, - dtype=torch.float8_e4m3fn) intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache2_scaled = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=torch.float8_e4m3fn) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) - a_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) + a1_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) a2_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - ops.scaled_fp8_quant(intermediate_cache0, hidden_states, a_scale) + if use_fp8: + a1 = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn) + ops.scaled_fp8_quant(a1, hidden_states, a1_scale) + else: + a1 = hidden_states - invoke_fused_moe_kernel(intermediate_cache0, w1, intermediate_cache1, - w1_scale, a_scale, + invoke_fused_moe_kernel(a1, w1, intermediate_cache1, w1_scale, a1_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config, compute_type=tl.float16, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - ops.scaled_fp8_quant(intermediate_cache2_scaled, intermediate_cache2, a2_scale) + if use_fp8: + a2 = torch.empty_like(intermediate_cache2, dtype=torch.float8_e4m3fn) + ops.scaled_fp8_quant(a2, intermediate_cache2, a2_scale) + else: + a2 = intermediate_cache2 - invoke_fused_moe_kernel(intermediate_cache2_scaled, w2, intermediate_cache3, - w2_scale, a2_scale, + invoke_fused_moe_kernel(a2, w2, intermediate_cache3, w2_scale, a2_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config, compute_type=tl.float16, use_fp8=use_fp8) From 77bdc3e8df376381c2427649988b6ccd06baa1ff Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 20 Apr 2024 14:07:39 -0700 Subject: [PATCH 27/52] update --- vllm/model_executor/models/mixtral.py | 30 ++++----------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0d5cf677c4ec4..d85e00fe0baae 100644 --- 
a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,6 +39,8 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, per_tensor_quantize) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,31 +51,6 @@ from vllm.sequence import SamplerOutput -# Temporary until https://github.com/vllm-project/vllm/pull/4118 is merged - -def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: - """Quantize a tensor using per-tensor static scaling factor. - Args: - tensor: The input tensor. - """ - finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax. - # Since .abs() creates a new tensor, we use aminmax to get - # the min and max first and then calculate the absmax. - min_val, max_val = tensor.aminmax() - amax = min_val.abs().max(max_val.abs()) - scale = finfo.max / amax.clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(torch.float8_e4m3fn) - scale = scale.float().reciprocal() - return qweight, scale - - class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -286,7 +263,8 @@ def __init__( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + intermediate_size=config.intermediate_size, + use_fp8=isinstance(linear_method, Fp8LinearMethod)) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, From bb123dd2b7707813a3c7e6ff9db319c951258d3b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 20 Apr 2024 14:36:04 -0700 Subject: [PATCH 28/52] Use MoE for fp8 quant --- vllm/model_executor/model_loader/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index a0a3b2784614d..f7e0f56c1a46e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,6 +24,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
if (model_config.quantization is not None + and model_config.quantization != "fp8" and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] From d212d2dc836ff3bb0b271f0f1a0a1832c4ba26ef Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 20 Apr 2024 14:54:14 -0700 Subject: [PATCH 29/52] fix --- vllm/model_executor/models/mixtral.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d85e00fe0baae..0a7ba82bef5c1 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -196,6 +196,12 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window + if isinstance(linear_method, Fp8LinearMethod): + # If we are using FP8, we currently do not want to + # use quantize the attention layers until we improve + # the performance and make sure the accuracy is good. + linear_method = None + self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, From 88c02ea35415c7965655dab305d6716362be8234 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 20 Apr 2024 15:30:21 -0700 Subject: [PATCH 30/52] clean up --- vllm/_custom_ops.py | 9 +++-- .../layers/fused_moe/fused_moe.py | 36 +++++++------------ vllm/model_executor/models/mixtral.py | 2 +- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4eae031dac5d1..681cf144c2788 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch @@ -153,8 +153,11 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) # fp8 -def scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, scale: torch.Tensor): - return vllm_ops.scaled_fp8_quant(out, input, scale) +def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + vllm_ops.scaled_fp8_quant(output, input, scale) + return output, scale # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0c5888250ce0f..4eba23beca016 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,8 +21,8 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, - w_scale_ptr, a_scale_ptr, + b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, @@ -115,8 +115,8 @@ def fused_moe_kernel( offs_bn[None, :] * stride_bn) if use_fp8: - w_scale = tl.load(w_scale_ptr + off_experts) a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. 
@@ -151,7 +151,7 @@ def fused_moe_kernel( accumulator = accumulator * moe_weight[:, None] if use_fp8: - accumulator = (accumulator * w_scale * a_scale).to(compute_type) + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -220,7 +220,7 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - w_scale: torch.Tensor, a_scale: torch.Tensor, + B_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, @@ -231,6 +231,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8: + A, A_scale = ops.scaled_fp8_quant(A) + else: + A_scale = None + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) @@ -238,8 +243,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, B, C, - w_scale, - a_scale, + B_scale, + A_scale, topk_weights, sorted_token_ids, expert_ids, @@ -412,32 +417,17 @@ def fused_moe( device=hidden_states.device, dtype=hidden_states.dtype) - a1_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) - a2_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - if use_fp8: - a1 = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn) - ops.scaled_fp8_quant(a1, hidden_states, a1_scale) - else: - a1 = hidden_states - - invoke_fused_moe_kernel(a1, w1, intermediate_cache1, w1_scale, a1_scale, + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, w1_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config, compute_type=tl.float16, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - if use_fp8: - a2 = torch.empty_like(intermediate_cache2, dtype=torch.float8_e4m3fn) - ops.scaled_fp8_quant(a2, intermediate_cache2, a2_scale) - else: - a2 = intermediate_cache2 - - invoke_fused_moe_kernel(a2, w2, intermediate_cache3, w2_scale, a2_scale, + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, w2_scale, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config, compute_type=tl.float16, use_fp8=use_fp8) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0a7ba82bef5c1..6728cf43a2332 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -198,7 +198,7 @@ def __init__(self, if isinstance(linear_method, Fp8LinearMethod): # If we are using FP8, we currently do not want to - # use quantize the attention layers until we improve + # quantize the attention layers until we improve # the performance and make sure the accuracy is good. 
linear_method = None From 11e1f01e7ce2a97477a837de924e83d54b5d2ba0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 20 Apr 2024 15:34:47 -0700 Subject: [PATCH 31/52] fix --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4eba23beca016..9f7d69abe2283 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -243,8 +243,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, B, C, - B_scale, A_scale, + B_scale, topk_weights, sorted_token_ids, expert_ids, From 5fa1dcfd351b3afd7dd9755a91a30f774b35c426 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 21 Apr 2024 12:40:33 -0700 Subject: [PATCH 32/52] update --- ...vice_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json} | 0 vllm/model_executor/layers/fused_moe/fused_moe.py | 11 ++++++----- 2 files changed, 6 insertions(+), 5 deletions(-) rename vllm/model_executor/layers/fused_moe/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json => E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json} (100%) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json similarity index 100% rename from vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json rename to vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f7d69abe2283..70bd1c5dcb140 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -268,13 +268,14 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ) -def get_config_file_name(E: int, N: int) -> str: +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") - return f"E={E},N={N},device_name={device_name}.json" + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" @functools.lru_cache -def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. 
@@ -286,7 +287,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N) + json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) @@ -384,7 +385,7 @@ def fused_moe( config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2]) + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the From a0e40030b46ce03edb8a4e29f536af20f78b03f4 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 21 Apr 2024 12:45:55 -0700 Subject: [PATCH 33/52] update --- ...168,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..e341a67917d51 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + 
"num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} From c0bfdba6b1a7c33fd666011277df5b41b6d7ae53 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 21 Apr 2024 13:38:10 -0700 Subject: [PATCH 34/52] format --- vllm/_custom_ops.py | 2 + .../layers/fused_moe/fused_moe.py | 46 ++++++++++++++----- vllm/model_executor/models/mixtral.py | 24 ++++++---- 3 files changed, 50 insertions(+), 22 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 681cf144c2788..e4b16ed918d1a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -152,6 +152,7 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k) + # fp8 def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: scale = torch.zeros(1, device=input.device, dtype=torch.float32) @@ -159,6 +160,7 @@ def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: vllm_ops.scaled_fp8_quant(output, input, scale) return output, scale + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 70bd1c5dcb140..c33b56b915d6b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -220,8 +220,8 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - B_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + B_scale: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -275,7 +275,8 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. 
@@ -385,7 +386,8 @@ def fused_moe( config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + configs = get_moe_configs(E, w2.shape[2], + "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the @@ -421,17 +423,37 @@ def fused_moe( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, w1_scale, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], config, compute_type=tl.float16, use_fp8=use_fp8) + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=tl.float16, + use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, w2_scale, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, True, 1, - config, compute_type=tl.float16, use_fp8=use_fp8) + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=tl.float16, + use_fp8=use_fp8) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 6728cf43a2332..94f7dfb8d68d8 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,8 +39,8 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, per_tensor_quantize) +from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod, + per_tensor_quantize) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -102,12 +102,14 @@ def __init__( dtype=self.params_dtype)) # Scaling factors for fp8 weights - self.ws_scale = nn.Parameter( - torch.ones(self.num_total_experts,device="cuda", dtype=torch.float32), - requires_grad=False) - self.w2s_scale = nn.Parameter( - torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), - requires_grad=False) + self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, + device="cuda", + dtype=torch.float32), + requires_grad=False) + self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts, + device="cuda", + dtype=torch.float32), + requires_grad=False) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, @@ -135,8 +137,10 @@ def process_weights_after_loading(self): ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - ws[expert,:,:], self.ws_scale[expert] = per_tensor_quantize(self.ws.data[expert,:,:]) - w2s[expert,:,:], self.w2s_scale[expert] = per_tensor_quantize(self.w2s.data[expert,:,:]) + ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize( + self.ws.data[expert, :, :]) + w2s[expert, :, :], 
self.w2s_scale[ + expert] = per_tensor_quantize(self.w2s.data[expert, :, :]) self.ws = nn.Parameter(ws, requires_grad=False) self.w2s = nn.Parameter(w2s, requires_grad=False) From 7c4ee357c9e25bf5647245701f7ecbe7daccb0e7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 21 Apr 2024 13:40:06 -0700 Subject: [PATCH 35/52] spelling --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 7518e319fd702..055f69341456c 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -22,7 +22,7 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { // reduction tree and the memory in scale is atomically updated. // So to get the right answer, *scale needs to be initialized to // a value <= 0.0 and we need to wait for all thread blocks to -// finishe before consuming *scale. +// finish before consuming *scale. template __global__ void segmented_max_reduction( float* __restrict__ scale, From d4ea8b71da8bdc2a6f8b2087ca1b50971b24dbd2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 11:06:10 -0700 Subject: [PATCH 36/52] Update vllm/model_executor/layers/fused_moe/fused_moe.py Co-authored-by: Cody Yu --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c33b56b915d6b..41d9e109ea5c5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -270,7 +270,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") - dtype_selector = "" if not dtype else f",dtype={dtype}" + dtype_selector = "" if not dtype else f",{dtype=}" return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" From 57235c507bbd1585af8b76afd6796213b5702827 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 11:06:27 -0700 Subject: [PATCH 37/52] Update vllm/model_executor/models/mixtral.py Co-authored-by: Cody Yu --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 94f7dfb8d68d8..e48632a4e55e5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -101,7 +101,7 @@ def __init__( device="cuda", dtype=self.params_dtype)) - # Scaling factors for fp8 weights + # Scaling factors for fp8 weights. If fp8 is not used, these parameters are ... self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), From 188314d1be09928007f0c63cba0616821a800d8b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 11:35:07 -0700 Subject: [PATCH 38/52] update --- vllm/model_executor/models/mixtral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e48632a4e55e5..a91d9d3b550bf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -101,7 +101,8 @@ def __init__( device="cuda", dtype=self.params_dtype)) - # Scaling factors for fp8 weights. 
If fp8 is not used, these parameters are ... + # Scaling factors for fp8 weights. If fp8 is not used, these parameters + # are 1.0 so no rescaling will happen. self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), From d20e5e9c90e0581c0cff7edb77e453520a2af844 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 11:54:10 -0700 Subject: [PATCH 39/52] add fixme --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a91d9d3b550bf..f5be7a4ae7eb9 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -202,8 +202,8 @@ def __init__(self, self.sliding_window = sliding_window if isinstance(linear_method, Fp8LinearMethod): - # If we are using FP8, we currently do not want to - # quantize the attention layers until we improve + # FIXME(pcmoritz): If we are using FP8, we currently do + # not want to quantize the attention layers until we improve # the performance and make sure the accuracy is good. linear_method = None From aedd33d7e1d2910d8e5ad83b4199cc6a0a4f9029 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 12:02:28 -0700 Subject: [PATCH 40/52] update --- vllm/model_executor/models/mixtral.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f5be7a4ae7eb9..11237109b3cab 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -68,7 +68,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, - use_fp8: bool = True, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -76,7 +76,9 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size - self.use_fp8 = use_fp8 + # FIXME(pcmoritz): Make this more general to support different + # quantization schemes + self.use_fp8 = isinstance(linear_method, Fp8LinearMethod) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -275,7 +277,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - use_fp8=isinstance(linear_method, Fp8LinearMethod)) + linear_method=linear_method) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, From 4aa77c91b2bf4aef30bf48d692c86704684ee1e0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 12:11:16 -0700 Subject: [PATCH 41/52] keep fused_moe interface --- vllm/model_executor/layers/fused_moe/fused_moe.py | 10 ++++++++-- vllm/model_executor/models/mixtral.py | 6 +++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 41d9e109ea5c5..06a58331c64b0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,14 +308,14 @@ def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, inplace: bool = False, 
override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -333,6 +333,12 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 11237109b3cab..a8d5841ca7b10 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -155,13 +155,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = fused_moe(hidden_states, self.ws, self.w2s, - self.ws_scale, - self.w2s_scale, router_logits, self.top_k, renormalize=True, inplace=True, - use_fp8=self.use_fp8) + use_fp8=self.use_fp8, + w1_scale=self.ws_scale, + w2_scale=self.w2s_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( From 69ad2dcbff1c88584c849bb6a9c844690e674ac9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 12:12:18 -0700 Subject: [PATCH 42/52] typo --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 06a58331c64b0..f9a387a7c450c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -315,7 +315,7 @@ def fused_moe( override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, - w2_scale: optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of From bae81d31e4da86ac803691d75fcb823d2dff3daf Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 12:37:17 -0700 Subject: [PATCH 43/52] fixloading config file --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f9a387a7c450c..45b73e25a781f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -270,7 +270,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") - dtype_selector = "" if not dtype else f",{dtype=}" + dtype_selector = "" if not dtype else f",dtype={dtype}" return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" From b733cea0574a0b010f75a59dd4d347738b42e90c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 19:06:49 -0700 Subject: [PATCH 44/52] update --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py 
b/vllm/model_executor/layers/fused_moe/fused_moe.py index 45b73e25a781f..cc6e6cba66608 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -137,7 +137,7 @@ def fused_moe_kernel( other=0.0) # We accumulate along the K dimension. if use_fp8: - accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. From d53b1fc47859f061fb06ccc46e82d2325157abab Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Apr 2024 19:35:35 -0700 Subject: [PATCH 45/52] update --- vllm/model_executor/models/mixtral.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a8d5841ca7b10..83d5eea87e1f2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -49,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once class MixtralMoE(nn.Module): @@ -204,9 +205,9 @@ def __init__(self, self.sliding_window = sliding_window if isinstance(linear_method, Fp8LinearMethod): - # FIXME(pcmoritz): If we are using FP8, we currently do - # not want to quantize the attention layers until we improve - # the performance and make sure the accuracy is good. + print_warning_once( + "For Mixtral FP8 quantization, we currently do not quantize " + "the attention layers until their FP8 performance is improved.") linear_method = None self.qkv_proj = QKVParallelLinear( From 5ef2ee91e271b7057444b0f069d5effb86a2275f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:11:56 -0700 Subject: [PATCH 46/52] update --- vllm/model_executor/models/mixtral.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 83d5eea87e1f2..4e4693a9bf737 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -96,13 +96,13 @@ def __init__( 2 * self.intermediate_size, self.hidden_size, device="cuda", - dtype=self.params_dtype)) + dtype=self.params_dtype)) if self.use_fp8 else None self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, device="cuda", - dtype=self.params_dtype)) + dtype=self.params_dtype)) if self.use_fp8 else None # Scaling factors for fp8 weights. If fp8 is not used, these parameters # are 1.0 so no rescaling will happen. @@ -207,7 +207,8 @@ def __init__(self, if isinstance(linear_method, Fp8LinearMethod): print_warning_once( "For Mixtral FP8 quantization, we currently do not quantize " - "the attention layers until their FP8 performance is improved.") + "the attention layers until their FP8 performance is improved." 
+ ) linear_method = None self.qkv_proj = QKVParallelLinear( From 88073007af92a04e1d06ffa68f4d04da40530f3b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:18:09 -0700 Subject: [PATCH 47/52] fix --- vllm/model_executor/layers/fused_moe/fused_moe.py | 8 +++++--- vllm/model_executor/models/mixtral.py | 9 ++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cc6e6cba66608..ac7c30e2a9727 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -231,10 +231,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if use_fp8: - A, A_scale = ops.scaled_fp8_quant(A) - else: + if not use_fp8: A_scale = None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A) + assert B_scale is not None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 4e4693a9bf737..d4f5a030005f3 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -96,7 +96,7 @@ def __init__( 2 * self.intermediate_size, self.hidden_size, device="cuda", - dtype=self.params_dtype)) if self.use_fp8 else None + dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, @@ -104,16 +104,15 @@ def __init__( device="cuda", dtype=self.params_dtype)) if self.use_fp8 else None - # Scaling factors for fp8 weights. If fp8 is not used, these parameters - # are 1.0 so no rescaling will happen. + # Scaling factors for fp8 weights. self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), - requires_grad=False) + requires_grad=False) if self.use_fp8 else None self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts, device="cuda", dtype=torch.float32), - requires_grad=False) + requires_grad=False) if self.use_fp8 else None set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, From a15a7b5476c75887ae6d66e45acf39116d78ceaa Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:20:27 -0700 Subject: [PATCH 48/52] update --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d4f5a030005f3..11f4b6d1b2547 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -102,7 +102,7 @@ def __init__( self.hidden_size, self.intermediate_size, device="cuda", - dtype=self.params_dtype)) if self.use_fp8 else None + dtype=self.params_dtype)) # Scaling factors for fp8 weights. 
self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, From 8fd40c1dca4c4db9afc9646192327d40e4a2aa78 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:25:48 -0700 Subject: [PATCH 49/52] format --- vllm/model_executor/models/mixtral.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 11f4b6d1b2547..a33b795d7088e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -104,15 +104,15 @@ def __init__( device="cuda", dtype=self.params_dtype)) - # Scaling factors for fp8 weights. - self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts, - device="cuda", - dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None - self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts, - device="cuda", - dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None + # Scaling factors for FP8 weights + self.ws_scale = nn.Parameter( + torch.ones( + self.num_total_experts, device="cuda", dtype=torch.float32), + requires_grad=False) if self.use_fp8 else None + self.w2s_scale = nn.Parameter( + torch.ones( + self.num_total_experts, device="cuda", dtype=torch.float32), + requires_grad=False) if self.use_fp8 else None set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, From 0f93811ce2a910fc2d74fe7b23697c19e0a0e1ed Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:29:05 -0700 Subject: [PATCH 50/52] align --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 055f69341456c..c3337cede1282 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -77,7 +77,7 @@ __global__ void scaled_fp8_quant_kernel( void scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); From fbbfc61bce2e22b52ed24a56a3ba26fb1e7f5dd4 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 12:53:43 -0700 Subject: [PATCH 51/52] rerun ci From 725270e879202b0c3b1369a17b40cef97ea49834 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 15:36:42 -0700 Subject: [PATCH 52/52] rerun ci
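
As a closing reference for the series: the per-tensor dynamic FP8 scheme that scaled_fp8_quant and per_tensor_quantize implement above can be sketched in plain PyTorch roughly as follows. This is an illustrative sketch only, assuming a PyTorch build that provides torch.float8_e4m3fn; it mirrors the per_tensor_quantize helper (now imported from vllm.model_executor.layers.quantization.fp8), not the CUDA kernel itself, which computes the absmax with a block-level reduction and atomicMax on the GPU. The function name per_tensor_quantize_ref is made up for this example.

import torch


def per_tensor_quantize_ref(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference per-tensor FP8 (e4m3) quantization: absmax -> scale -> saturating cast.

    Illustrative sketch only; it mirrors the interface used by the fused MoE
    weight path in this series (quantized tensor plus inverse scale), not the
    CUDA kernel added in the earlier patches.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Map the tensor's absolute maximum onto the representable e4m3 range.
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    # Saturate before casting, since the default cast is unsaturated.
    qtensor = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Return the inverse scale so the matmul epilogue can rescale the fp8 product.
    return qtensor, scale.float().reciprocal()


if __name__ == "__main__":
    x = torch.randn(16, 32)
    q, inv_scale = per_tensor_quantize_ref(x)
    # Round-trip check: dequantized values should stay close to the input.
    print((q.to(torch.float32) * inv_scale - x).abs().max())

Returning the inverse scale rather than the scale itself follows the convention in the patches above: torch._scaled_mm and the Triton kernel both consume the factor that is multiplied back into the result, which is consistent with fused_moe_kernel scaling the accumulator by a_scale * b_scale when use_fp8 is set.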