From 376df3e2a4f55c880b44bfb004071acbeaec6da5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sun, 29 Sep 2024 03:19:40 +0200 Subject: [PATCH] [Bugfix] Fix Marlin MoE act order when is_k_full == False (#8741) Co-authored-by: Tyler Michael Smith Signed-off-by: Amit Garg --- csrc/core/exception.hpp | 3 ++ csrc/moe/marlin_moe_ops.cu | 12 +++---- tests/kernels/test_moe.py | 32 +++++++++++++------ .../layers/fused_moe/fused_marlin_moe.py | 8 +++-- 4 files changed, 37 insertions(+), 18 deletions(-) create mode 100644 csrc/core/exception.hpp diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 0000000000000..f3b2ffaef6cce --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index dfe0437414013..c97b5dbd2a54e 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,7 @@ #include +#include "core/exception.hpp" #include "core/scalar_type.hpp" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" @@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int load_groups = tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; + return load_groups * tb_n * 4; } else { int tb_scales = tb_groups * tb_n * 2; @@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = - (const int4*)s + - (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * - prob_n / 8) * - expert_idx; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe( " is not size_n = ", size_n); num_groups = b_scales.size(1); + TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), + "if is_k_full is false, has_act_order must be true"); + if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index c6ddcc8ce79f5..cbbb5c9b79c42 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_fused_marlin_moe( m: int, n: int, @@ -154,6 +155,7 @@ def test_fused_marlin_moe( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): seed_everything(7) @@ -166,6 +168,9 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -246,6 +251,7 @@ def test_fused_marlin_moe( w1_scale=scales1, w2_scale=scales2, num_bits=num_bits, + is_k_full=is_k_full, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -290,6 +296,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_single_marlin_moe_multiply( m: int, n: int, @@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): if topk > e: return @@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply( return if group_size == k: return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_marlin_moe(a, - qweight, - scales, - score, - g_idx, - sort_indices, - topk, - renormalize=False, - num_bits=num_bits) + marlin_output = single_marlin_moe( + a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits, + is_k_full=is_k_full, + ) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 866b18d725a8c..8177e846127ee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -21,6 +21,7 @@ def single_marlin_moe( renormalize: bool, override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, + is_k_full: bool = True, ) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert @@ -86,7 +87,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -107,6 +108,7 @@ def fused_marlin_moe( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, num_bits: int = 8, + is_k_full: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -199,7 +201,7 @@ def fused_marlin_moe( M, 2 * N, K, - True, + is_k_full, E, topk, block_size_m, @@ -223,7 +225,7 @@ def fused_marlin_moe( M, K, N, - True, + is_k_full, E, topk, block_size_m,