diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 92184f43c9eb0..666d87eb92595 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -840,10 +902,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +926,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +950,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1104,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1088,9 +1159,9 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1166,28 +1237,70 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } } } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1340,8 @@ __device__ inline void MarlinMoESingle( } } -template 4) { + if (max_block > cfg_max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; } if (max_block == 1) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1457,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1423,6 +1543,11 @@ typedef struct { int num_threads; } thread_config_t; +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + thread_config_t small_batch_thread_configs[] = { // Ordered by priority @@ -1443,8 +1568,77 @@ thread_config_t large_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1472,64 +1666,88 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1537,26 +1755,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1590,11 +1824,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - int tot_m = prob_m; const int* topk_ids_ptr = (const int*)topk_ids; @@ -1611,10 +1840,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1636,19 +1868,22 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, A_ptr = a_tmp_ptr; } - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { // make it max possible value - int thread_m_blocks = 4; + int thread_m_blocks = exec_cfg.max_m_blocks; if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1905,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1974,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 43d264e0770d6..adee8399a4d6f 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8a0e625b43fa1..cd65a8ee92b94 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); + "g_idx, Tensor! perm, Tensor! workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2250cf1598b8b..8072cf09e5b65 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -148,6 +149,7 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, + num_bits: int, ): torch.manual_seed(7) @@ -161,13 +163,12 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -240,6 +241,7 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, + num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -254,7 +256,8 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -def test_marlin_moe_mmm( +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_single_marlin_moe_multiply( m: int, n: int, k: int, @@ -262,6 +265,7 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, + num_bits: int, ): if topk > e: return @@ -273,7 +277,8 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -308,7 +313,8 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False) + renormalize=False, + num_bits=num_bits) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index fe76705746766..2f5c6c5a117f3 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,3 +1,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh old mode 100644 new mode 100755 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ed08878f14875..74b3b69606c67 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 200a6148978aa..866b18d725a8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,18 +7,21 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, +) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert weights used in Marlin MoE, using weights w and top-k gating mechanism. @@ -36,6 +39,7 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -48,10 +52,11 @@ def single_marlin_moe( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // 2 + N = w.shape[2] // (num_bits // 2) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -76,10 +81,13 @@ def single_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, - False) + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -98,6 +106,7 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -122,6 +131,7 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -131,13 +141,14 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w1.shape[0] @@ -165,6 +176,9 @@ def fused_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -181,6 +195,7 @@ def fused_marlin_moe( g_idx1, perm1, workspace, + scalar_type, M, 2 * N, K, @@ -204,6 +219,7 @@ def fused_marlin_moe( g_idx2, perm2, workspace, + scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a0cb4337f9dee..3e01112eaa14d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids.to(torch.int32) + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 49c29c2775cb6..7dee2fca81153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -38,10 +40,11 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits == 4): + and self.num_bits in WNA16_SUPPORTED_BITS): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for 4 bits") + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -292,4 +295,5 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3617a32f80fc1..cc699f5b4554f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -611,4 +611,5 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc4..2bfe6ea09bd62 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,13 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] - # for gptq_marlin, only run fused MoE for int4 - if model_config.quantization == "gptq_marlin": - hf_quant_config = getattr(model_config.hf_config, - "quantization_config", None) - if hf_quant_config and hf_quant_config.get("bits") == 4: - mixtral_supported.append("gptq_marlin") + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported