From 2d970c53822a0d5840fa96bc290b6f37998ce6cf Mon Sep 17 00:00:00 2001 From: chutianxiang Date: Sun, 4 Feb 2024 21:45:56 +0800 Subject: [PATCH 01/20] Add kernel --- csrc/ops.h | 13 + csrc/pybind.cpp | 1 + csrc/quantization/gptq/q_gemm.cu | 457 +++++++++++++++++++++++++++++++ 3 files changed, 471 insertions(+) diff --git a/csrc/ops.h b/csrc/ops.h index 2bcd0c2efc5c6..fc591fb2e93d3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -100,6 +100,19 @@ void gptq_shuffle( torch::Tensor q_weight, torch::Tensor q_perm); +torch::Tensor group_gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + torch::Tensor topk_weights, + torch::Tensor sorted_token_ids_ptr, + torch::Tensor expert_ids_ptr, + bool mul_weights, + bool use_exllama +); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8a8235691ab8e..d712dbfdafe66 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -54,6 +54,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("group_gptq_gemm", &group_gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def( diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index a5d2345f1e7fd..18701af2975d9 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -822,6 +822,421 @@ void shuffle_exllama_weight shuffle_kernel<<>>(q_weight, height, width); } + +template +__global__ void group_gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm, + const half* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int num_valid_tokens, + const int top_k +) +{ + int expert_id = expert_ids_ptr[blockIdx.y]; + b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id; + b_gptq_qzeros = b_gptq_qzeros + groups * size_n / 8 * expert_id; + b_gptq_scales = b_gptq_scales + groups * size_n * expert_id; + + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + int token_a[m_count]; + + int valid_count = m_count; + for (int m = 0; m < m_count; ++m) { + int token_id = sorted_token_ids_ptr[offset_m + m]; + if (token_id >= num_valid_tokens) { + valid_count = m; + break; + } + token_a[m] = token_id; + } + + if (offset_k + t < end_k) + { + for (int m = 0; m < valid_count; ++m) + { + const half* a_ptr = 
a_.item_ptr(token_a[m] / top_k, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + if (m >= valid_count) break; + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < valid_count; m++) + { + half2 *out = (half2*) c_.item_ptr(token_a[m], n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + if (topk_weights) { + half2 topk_weight = __half2half2(topk_weights[token_a[m]]); + result01 = __hmul2(result01, topk_weight); + result23 = __hmul2(result23, topk_weight); + } + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +void group_gemm_half_q_half_cuda +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* c, + const half* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 
1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_gemm_half_q_half_gptq_kernel<<>> + ( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + c, + size_m, + size_n, + size_k, + groups, + b_q_perm, + topk_weights, + sorted_token_ids_ptr, + expert_ids_ptr, + num_valid_tokens, + top_k + ); +} + +__global__ void group_gemm_half_q_half_alt_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width, + int groups, + const half* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int num_valid_tokens, + const int top_k +) +{ + int expert_id = expert_ids_ptr[blockIdx.y]; + mat = mat + height * width * expert_id; + scales = scales + groups * width * expert_id; + zeros = zeros + groups * width / 8 * expert_id; + g_idx = g_idx + height * 8 * expert_id; + + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + int token_a[BLOCK_M_SIZE_MAX]; + for (int m = 0; m < b_end; ++m) { + int token_id = sorted_token_ids_ptr[b + m]; + if (token_id >= num_valid_tokens) { + b_end = m; + break; + } + token_a[m] = token_id; + } + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[token_a[m] / top_k * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), 
blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + if (topk_weights) { + res[m] = __hmul(res[m], topk_weights[token_a[m]]); + } + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[token_a[m] * width + w], res[m]); + } +} + + +void group_gemm_half_q_half_alt +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + const half* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(pad_size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_gemm_half_q_half_alt_kernel<<>> + ( + (const half2*) a, + b_q_weight, + c, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + size_m, + size_k / 8, + size_n, + groups, + topk_weights, + sorted_token_ids_ptr, + expert_ids_ptr, + num_valid_tokens, + top_k + ); +} + +void group_gemm_half_q_half_cuda +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + const half* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* __restrict__ expert_ids_ptr, + const int num_valid_tokens, + const int top_k, + int size_m, + int size_n, + int size_k, + int pad_size_m, + int groups, + bool use_exllama +) { + if (use_exllama) { + group_gemm_half_q_half_cuda( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_valid_tokens, + top_k, size_m, size_n, size_k, pad_size_m, groups + ); + } else { + group_gemm_half_q_half_alt( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_valid_tokens, + top_k, size_m, size_n, size_k, pad_size_m, groups + ); + } +} + } // namespace gptq } // namespace vllm @@ -873,3 +1288,45 @@ void gptq_shuffle q_weight.size(1) ); } + +torch::Tensor group_gptq_gemm +( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + torch::Tensor topk_weights, + torch::Tensor sorted_token_ids_ptr, + torch::Tensor expert_ids_ptr, + bool mul_weights, + bool use_exllama +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options); + + vllm::gptq::group_gemm_half_q_half_cuda + ( + (const half*) a.data_ptr(), + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) c.data_ptr(), + mul_weights ? 
(const half*) topk_weights.data_ptr() : NULL, + (const int*) sorted_token_ids_ptr.data_ptr(), + (const int*) expert_ids_ptr.data_ptr(), + topk_weights.numel(), // num tokens + topk_weights.size(1) / a.size(1), // top_k + a.size(0) * a.size(1), // m + c.size(1), // n + a.size(2), // k + b_gptq_qzeros.size(1), // group number + sorted_token_ids_ptr.size(0), + use_exllama + ); + return c; +} \ No newline at end of file From 281354acd6e2d9ec0ad67e88d5f67956e36c503b Mon Sep 17 00:00:00 2001 From: chutianxiang Date: Mon, 5 Feb 2024 10:59:46 +0800 Subject: [PATCH 02/20] Add group gemm kernel for gptq --- csrc/ops.h | 1 + csrc/quantization/gptq/q_gemm.cu | 44 ++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index fc591fb2e93d3..ec9ddc3c95206 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -109,6 +109,7 @@ torch::Tensor group_gptq_gemm( torch::Tensor topk_weights, torch::Tensor sorted_token_ids_ptr, torch::Tensor expert_ids_ptr, + torch::Tensor num_tokens_post_padded, bool mul_weights, bool use_exllama ); diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 18701af2975d9..95f3cbb62cc34 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -839,14 +839,20 @@ __global__ void group_gemm_half_q_half_gptq_kernel const half* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, const int top_k ) { + int num_tokens = *num_tokens_post_padded; + int offset_m = blockIdx.y * m_count; + if (offset_m >= num_tokens) return + int expert_id = expert_ids_ptr[blockIdx.y]; b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id; b_gptq_qzeros = b_gptq_qzeros + groups * size_n / 8 * expert_id; b_gptq_scales = b_gptq_scales + groups * size_n * expert_id; + b_q_perm = b_q_perm + size_k * expert_id; MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); @@ -857,7 +863,6 @@ __global__ void group_gemm_half_q_half_gptq_kernel // Block int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; int offset_k = blockIdx.z * BLOCK_KN_SIZE; int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); @@ -954,10 +959,8 @@ __global__ void group_gemm_half_q_half_gptq_kernel dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - #pragma unroll - for (int m = 0; m < m_count; m++) + for (int m = 0; m < valid_count; m++) { - if (m >= valid_count) break; block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); @@ -997,6 +1000,7 @@ void group_gemm_half_q_half_cuda const half* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, const int top_k, int size_m, @@ -1030,6 +1034,7 @@ void group_gemm_half_q_half_cuda topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, top_k ); @@ -1049,10 +1054,15 @@ __global__ void group_gemm_half_q_half_alt_kernel( const half* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* 
__restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, const int top_k ) { + int num_tokens = *num_tokens_post_padded; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + if (b >= num_tokens) return; + int expert_id = expert_ids_ptr[blockIdx.y]; mat = mat + height * width * expert_id; scales = scales + groups * width * expert_id; @@ -1062,8 +1072,7 @@ __global__ void group_gemm_half_q_half_alt_kernel( int zero_width = width / 8; int vec_height = height * 4; const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int b_end = BLOCK_M_SIZE_MAX; int h = BLOCK_KN_SIZE * blockIdx.z / 8; int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; @@ -1140,14 +1149,14 @@ __global__ void group_gemm_half_q_half_alt_kernel( #else res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - if (topk_weights) { - res[m] = __hmul(res[m], topk_weights[token_a[m]]); - } } i += width; k += 4; } for (int m = 0; m < b_end; m++) { + if (topk_weights) { + res[m] = __hmul(res[m], topk_weights[token_a[m]]); + } atomicAdd(&mul[token_a[m] * width + w], res[m]); } } @@ -1164,6 +1173,7 @@ void group_gemm_half_q_half_alt const half* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, const int top_k, int size_m, @@ -1197,6 +1207,7 @@ void group_gemm_half_q_half_alt topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, top_k ); @@ -1213,6 +1224,7 @@ void group_gemm_half_q_half_cuda const half* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens, const int top_k, int size_m, @@ -1225,13 +1237,15 @@ void group_gemm_half_q_half_cuda if (use_exllama) { group_gemm_half_q_half_cuda( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_valid_tokens, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, top_k, size_m, size_n, size_k, pad_size_m, groups ); } else { group_gemm_half_q_half_alt( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_valid_tokens, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_post_padded, num_valid_tokens, top_k, size_m, size_n, size_k, pad_size_m, groups ); } @@ -1299,6 +1313,7 @@ torch::Tensor group_gptq_gemm torch::Tensor topk_weights, torch::Tensor sorted_token_ids_ptr, torch::Tensor expert_ids_ptr, + torch::Tensor num_tokens_post_padded, bool mul_weights, bool use_exllama ) @@ -1306,7 +1321,7 @@ torch::Tensor group_gptq_gemm const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::empty({a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options); + at::Tensor c = torch::zeros({a.size(0), topk_weights.size(1), b_q_weight.size(2)}, options); vllm::gptq::group_gemm_half_q_half_cuda ( @@ -1319,13 +1334,14 @@ torch::Tensor group_gptq_gemm mul_weights ? 
(const half*) topk_weights.data_ptr() : NULL, (const int*) sorted_token_ids_ptr.data_ptr(), (const int*) expert_ids_ptr.data_ptr(), + (const int*) num_tokens_post_padded.data_ptr(), topk_weights.numel(), // num tokens topk_weights.size(1) / a.size(1), // top_k a.size(0) * a.size(1), // m - c.size(1), // n + c.size(2), // n a.size(2), // k - b_gptq_qzeros.size(1), // group number sorted_token_ids_ptr.size(0), + b_gptq_qzeros.size(1), // group number use_exllama ); return c; From 2a1c106119b89a2afca6179751af13118f73763a Mon Sep 17 00:00:00 2001 From: chutianxiang Date: Mon, 5 Feb 2024 21:49:18 +0800 Subject: [PATCH 03/20] Add dequant kernel --- csrc/ops.h | 8 + csrc/pybind.cpp | 3 +- csrc/quantization/gptq/q_gemm.cu | 121 ++++- vllm/model_executor/layers/fused_moe.py | 12 +- vllm/model_executor/layers/linear.py | 73 +++- .../layers/quantization/gptq.py | 56 +++ vllm/model_executor/model_loader.py | 5 - vllm/model_executor/models/__init__.py | 1 - vllm/model_executor/models/mixtral.py | 179 +++++--- vllm/model_executor/models/mixtral_quant.py | 412 ------------------ 10 files changed, 366 insertions(+), 504 deletions(-) delete mode 100644 vllm/model_executor/models/mixtral_quant.py diff --git a/csrc/ops.h b/csrc/ops.h index ec9ddc3c95206..676ade865a73e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -114,6 +114,14 @@ torch::Tensor group_gptq_gemm( bool use_exllama ); +torch::Tensor dequant_gptq( + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index d712dbfdafe66..79d3b07fde4af 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -54,8 +54,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("group_gptq_gemm", &group_gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("group_gptq_gemm", &group_gptq_gemm, "Grouped Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); + ops.def("dequant_gptq", &dequant_gptq, "Dequantize gptq weight to half"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def( "moe_align_block_size", diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 95f3cbb62cc34..08aa6a7f43682 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -316,6 +316,14 @@ __global__ void reconstruct_exllama_kernel half* __restrict__ b ) { + if (blockIdx.z > 0){ + b_q_weight = b_q_weight + blockIdx.z * size_k * size_n / 8; + b_gptq_scales = b_gptq_scales + blockIdx.z * groups * size_n; + b_gptq_qzeros = b_gptq_qzeros + blockIdx.z * groups * size_n / 8; + b_q_perm = b_q_perm + blockIdx.z * size_k; + b = b + blockIdx.z * size_k * size_n; + } + MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -426,7 +434,8 @@ void reconstruct_exllama half* out, int height, int width, - int groups + int groups, + int num_experts ) { dim3 blockDim, gridDim; @@ -434,6 +443,7 @@ void reconstruct_exllama blockDim.y = 1; gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + gridDim.z = num_experts; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); reconstruct_exllama_kernel<<>> @@ 
-597,6 +607,13 @@ __global__ void reconstruct_gptq_kernel half* __restrict__ out ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * height * width / 8; + w_scales = w_scales + blockIdx.z * group * width; + w_zeros = w_zeros + blockIdx.z * group * width / 8; + g_idx = g_idx + blockIdx.z * height; + out = out + blockIdx.z * height * width; + } // Start of block int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; @@ -634,7 +651,8 @@ void reconstruct_gptq half* out, int height, int width, - int groups + int groups, + int num_experts ) { dim3 blockDim, gridDim; @@ -642,6 +660,7 @@ void reconstruct_gptq blockDim.y = 1; gridDim.y = DIVIDE(height, 8); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + gridDim.z = num_experts; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); reconstruct_gptq_kernel<<>> ( @@ -678,12 +697,12 @@ void gemm_half_q_half_cuda // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups); + size_k, size_n, groups, 1); } else { reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups); + temp_dq, size_k, size_n, groups, 1); } const half alpha = __float2half(1.0f); @@ -836,7 +855,7 @@ __global__ void group_gemm_half_q_half_gptq_kernel const int size_k, const int groups, const int* __restrict__ b_q_perm, - const half* __restrict__ topk_weights, + const float* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, const int* __restrict__ num_tokens_post_padded, @@ -846,7 +865,7 @@ __global__ void group_gemm_half_q_half_gptq_kernel { int num_tokens = *num_tokens_post_padded; int offset_m = blockIdx.y * m_count; - if (offset_m >= num_tokens) return + if (offset_m >= num_tokens) return; int expert_id = expert_ids_ptr[blockIdx.y]; b_q_weight = b_q_weight + size_k * size_n / 8 * expert_id; @@ -976,14 +995,15 @@ __global__ void group_gemm_half_q_half_gptq_kernel for (int m = 0; m < valid_count; m++) { + if (topk_weights) { + #pragma unroll + for (int j = 0; j < 4; ++j) { + block_c[m][j] = block_c[m][j] * topk_weights[token_a[m]]; + } + } half2 *out = (half2*) c_.item_ptr(token_a[m], n); half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - if (topk_weights) { - half2 topk_weight = __half2half2(topk_weights[token_a[m]]); - result01 = __hmul2(result01, topk_weight); - result23 = __hmul2(result23, topk_weight); - } atomicAdd(out , result01); atomicAdd(out + 1, result23); } @@ -997,7 +1017,7 @@ void group_gemm_half_q_half_cuda const half* b_gptq_scales, const int* b_q_perm, half* c, - const half* __restrict__ topk_weights, + const float* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, const int* __restrict__ num_tokens_post_padded, @@ -1051,7 +1071,7 @@ __global__ void group_gemm_half_q_half_alt_kernel( int height, int width, int groups, - const half* __restrict__ topk_weights, + const float* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, const int* __restrict__ num_tokens_post_padded, @@ -1155,7 +1175,7 @@ __global__ void group_gemm_half_q_half_alt_kernel( } for (int m = 0; m < b_end; m++) { if (topk_weights) { - res[m] = __hmul(res[m], topk_weights[token_a[m]]); + res[m] = __float2half(__half2float(res[m]) 
* topk_weights[token_a[m]]); } atomicAdd(&mul[token_a[m] * width + w], res[m]); } @@ -1170,7 +1190,7 @@ void group_gemm_half_q_half_alt const half* b_gptq_scales, const int* b_g_idx, half* c, - const half* __restrict__ topk_weights, + const float* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, const int* __restrict__ num_tokens_post_padded, @@ -1221,7 +1241,7 @@ void group_gemm_half_q_half_cuda const half* b_gptq_scales, const int* b_g_idx, half* c, - const half* __restrict__ topk_weights, + const float* __restrict__ topk_weights, const int* __restrict__ sorted_token_ids_ptr, const int* __restrict__ expert_ids_ptr, const int* __restrict__ num_tokens_post_padded, @@ -1251,6 +1271,31 @@ void group_gemm_half_q_half_cuda } } +void dequant_gptq_cuda +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* temp_dq, + int size_k, + int size_n, + int groups, + int num_experts, + bool use_exllama +) +{ + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups, num_experts); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, num_experts); + } +} + } // namespace gptq } // namespace vllm @@ -1331,7 +1376,7 @@ torch::Tensor group_gptq_gemm (const half*) b_gptq_scales.data_ptr(), b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), (half*) c.data_ptr(), - mul_weights ? (const half*) topk_weights.data_ptr() : NULL, + mul_weights ? (const float*) topk_weights.data_ptr() : NULL, (const int*) sorted_token_ids_ptr.data_ptr(), (const int*) expert_ids_ptr.data_ptr(), (const int*) num_tokens_post_padded.data_ptr(), @@ -1345,4 +1390,46 @@ torch::Tensor group_gptq_gemm use_exllama ); return c; +} + +torch::Tensor dequant_gptq +( + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales)); + auto options = torch::TensorOptions().dtype(b_gptq_scales.dtype()).device(b_gptq_scales.device()); + + at::Tensor temp_dq; + int num_experts; + int size_k; + int size_n; + int groups; + // moe + if (b_q_weight.dim() == 3) { + temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 8, b_q_weight.size(2)}, options); + num_experts = b_q_weight.size(0); + size_k = b_q_weight.size(1) * 8; + size_n = b_q_weight.size(2); + groups = b_gptq_scales.size(1); + } else + { + temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + num_experts = 1; + size_k = b_q_weight.size(0) * 8; + size_n = b_q_weight.size(1); + groups = b_gptq_scales.size(0); + } + vllm::gptq::dequant_gptq_cuda( + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) temp_dq.data_ptr(), + size_k, size_n, groups, + num_experts, use_exllama); + return temp_dq; } \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index eed2e83bed7f8..627b1f4d143f1 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -142,7 +142,7 @@ def moe_align_block_size( - expert_ids: A tensor indicating the assigned expert index for each block. 
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. - This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. Example: @@ -151,7 +151,7 @@ def moe_align_block_size( - As block_size is 4, we pad 1 token for each expert. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - Then append padding tokens [12, 12, 12, 12] for each block. - - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ @@ -218,7 +218,7 @@ def fused_moe(hidden_states: torch.Tensor, inplace=False): """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. - + Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. @@ -226,15 +226,15 @@ def fused_moe(hidden_states: torch.Tensor, - topk_weights (torch.Tensor): The weights for the top-k selected experts. - topk_ids (torch.Tensor): The indices of the top-k selected experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. - + Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" + # assert w1.is_contiguous(), "Expert weights1 must be contiguous" + # assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 55d38b763b2b5..dc815cadebfc7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.communication_op import ( @@ -36,6 +37,24 @@ def apply_weights(self, """Apply the weights to the input tensor.""" raise NotImplementedError + def create_moe_weights(self, num_experts: int, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Creating moe weights""" + linear_weights = self.create_weights(input_size_per_partition, + output_size_per_partition, + input_size, output_size, + params_dtype) + for name, param in tuple(linear_weights.items()): + if isinstance(param, Parameter): + repeat_size = (num_experts,) + (1,) * param.dim() + new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size), + requires_grad=False) + set_weight_attrs(new_param, param.__dict__) + linear_weights[name] = new_param + return linear_weights + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. @@ -70,6 +89,14 @@ def apply_weights(self, return F.linear(x, weight) return F.linear(x, weight, bias) + def apply_moe_weights(self, + w1: Dict[str, torch.Tensor], + w2: Dict[str, torch.Tensor], + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + return fused_moe(x, w1["weight"], w2["weight"], topk_weights, topk_ids) + class ReplicatedLinear(torch.nn.Module): """Replicated linear layer. 
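For reference, the create_moe_weights helper introduced in this hunk stacks every per-expert parameter along a new leading num_experts dimension, which is the layout the grouped GPTQ kernels above index by expert_id. Below is a minimal standalone sketch of that replication pattern in plain PyTorch; the helper name replicate_for_experts and the small shapes are illustrative only, not vLLM APIs.

    import torch
    from torch.nn.parameter import Parameter

    def replicate_for_experts(weights: dict, num_experts: int) -> dict:
        # Mirror of the pattern in create_moe_weights: add a leading experts
        # axis and repeat, so weights[name] goes from [out, in] to
        # [num_experts, out, in].
        out = {}
        for name, param in weights.items():
            if isinstance(param, Parameter):
                repeat = (num_experts,) + (1,) * param.dim()
                out[name] = Parameter(param.unsqueeze(0).repeat(*repeat),
                                      requires_grad=False)
            else:
                out[name] = param
        return out

    single = {"weight": Parameter(torch.empty(8, 16), requires_grad=False)}
    moe = replicate_for_experts(single, num_experts=4)
    print(moe["weight"].shape)  # torch.Size([4, 8, 16])

The patch additionally copies the original weight attributes onto the new parameter via set_weight_attrs, which this sketch omits.
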
@@ -153,6 +180,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): super().__init__() @@ -170,9 +198,16 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) + self.num_experts = num_experts + if num_experts > 1: + self.linear_weights = self.linear_method.create_moe_weights( + num_experts, self.input_size, self.output_size_per_partition, + self.input_size, self.output_size, self.params_dtype + ) + else: + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size_per_partition, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -188,10 +223,13 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + expert_id: int = 0): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data + if self.num_experts > 1: + param_data = param_data[expert_id] if output_dim is not None: shard_size = param_data.shape[output_dim] start_idx = tp_rank * shard_size @@ -245,18 +283,22 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + skip_bias_add, params_dtype, linear_method, num_experts) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + loaded_shard_id: Optional[int] = None, + expert_id: int = 0): param_data = param.data + if self.num_experts > 1: + param_data = param_data[expert_id] output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: # Loaded weight is already packed. 
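The weight_loader changes in these hunks follow the same two-step indexing in every parallel linear class: first select the expert slice through the new expert_id argument, then narrow to the current tensor-parallel shard exactly as before. A small self-contained sketch of that indexing follows; the function load_expert_shard is illustrative only and assumes a leading expert dimension with output-dim sharding.

    import torch

    def load_expert_shard(param: torch.Tensor, loaded: torch.Tensor,
                          expert_id: int, tp_rank: int,
                          output_dim: int = 0) -> None:
        # Step 1: pick this expert's slice out of the stacked parameter.
        dest = param[expert_id]
        # Step 2: narrow the full checkpoint weight to this rank's shard,
        # as the pre-existing (non-MoE) loader already does.
        shard_size = dest.shape[output_dim]
        start = tp_rank * shard_size
        shard = loaded.narrow(output_dim, start, shard_size)
        assert dest.shape == shard.shape
        dest.copy_(shard)

    param = torch.zeros(8, 4, 16)    # [num_experts, out_per_rank, in]
    loaded = torch.randn(8, 16)      # full output dim (2 ranks x 4) of one expert
    load_expert_shard(param, loaded, expert_id=3, tp_rank=1)

In MergedColumnParallelLinear the same expert selection happens before the per-shard w1/w3 offsets are applied, which is why expert_id is threaded through every loader signature in this patch.
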
@@ -473,6 +515,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, linear_method: Optional[LinearMethodBase] = None, + num_experts: int = 1, ): super().__init__() # Keep input parameters @@ -491,9 +534,16 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) + self.num_experts = num_experts + if num_experts > 1: + self.linear_weights = self.linear_method.create_moe_weights( + num_experts, self.input_size_per_partition, self.output_size, + self.input_size, self.output_size, self.params_dtype + ) + else: + self.linear_weights = self.linear_method.create_weights( + self.input_size_per_partition, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -513,10 +563,13 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + expert_id: int = 0): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data + if self.num_experts > 1: + param_data = param_data[expert_id] if input_dim is not None: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 7218760fbe55d..818c55fc3694b 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,6 +6,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops +from vllm.model_executor.layers.fused_moe import ( + moe_align_block_size, fused_moe) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -209,3 +211,57 @@ def apply_weights(self, if bias is not None: output = output + bias return output.reshape(out_shape) + + def apply_moe_weights(self, + w1: Dict[str, torch.Tensor], + w2: Dict[str, torch.Tensor], + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + # shuffle weights for exllama + for w in [w1, w2]: + if w["exllama_state"] == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + w["g_idx"] = torch.argsort(w["g_idx"], dim=-1).to( + torch.int) + else: + w["g_idx"] = torch.empty((1, 1), device="meta") + w["exllama_state"] = ExllamaState.READY + # todo: implement single pass shuffle + for i in range(w["qweight"].shape[0]): + ops.gptq_shuffle( + w["qweight"][i], + w["g_idx"][i] if w["g_idx"].device != torch.device( + "meta") else w["g_idx"], + ) + + if x.shape[0] >= 100: + dequant_w1 = ops.dequant_gptq( + w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], + w1["exllama_state"] == ExllamaState.READY + ).permute(0, 2, 1) + dequant_w2 = ops.dequant_gptq( + w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], + w2["exllama_state"] == ExllamaState.READY + ).permute(0, 2, 1) + return fused_moe(x, dequant_w1, dequant_w2, topk_weights, topk_ids) + + (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( + topk_ids, 8, w1["qweight"].shape[0]) + + x = x.view(x.shape[0], 1, 
*x.shape[1:]) + gate_up = ops.group_gptq_gemm( + x, w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, + False, w1["exllama_state"] == ExllamaState.READY) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, gate_up) + + out = ops.group_gptq_gemm( + out, w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, + True, w2["exllama_state"] == ExllamaState.READY) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 4b1e13d9e9e0a..6f6666d512dec 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -22,11 +22,6 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. - if (model_config.quantization is not None - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] for arch in architectures: model_cls = ModelRegistry.load_model_cls(arch) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index fb519b3c0cf92..91d8f860baccc 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -31,7 +31,6 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("mistral", "MistralForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a8e470395b904..edeafd31f99c8 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,6 +23,7 @@ """Inference-only Mixtral model.""" from typing import List, Optional, Tuple +import numpy as np import torch import torch.nn.functional as F @@ -33,10 +34,12 @@ from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + LinearMethodBase, QKVParallelLinear, ReplicatedLinear, - RowParallelLinear) + RowParallelLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -54,6 +57,45 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + 
linear_method=linear_method) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -71,13 +113,18 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() + self.rank = get_tensor_model_parallel_rank() self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.num_total_experts = num_experts self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.linear_method = linear_method + if self.linear_method is None: + self.linear_method = UnquantizedLinearMethod() if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -89,39 +136,38 @@ def __init__( params_dtype=self.params_dtype, linear_method=None) - self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] + if not hasattr(self.linear_method, "apply_moe_weights"): + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + hidden_size, + intermediate_size, + linear_method=linear_method) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + else: + self.ws = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=num_experts) + self.w2s = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + num_experts=num_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, 
sequence_length, hidden_size = hidden_states.shape @@ -135,12 +181,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - routing_weights, - selected_experts, - inplace=True) + if not hasattr(self.linear_method, "apply_moe_weights"): + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + else: + final_hidden_states = self.linear_method.apply_moe_weights( + self.ws.linear_weights, + self.w2s.linear_weights, + hidden_states, + routing_weights, + selected_experts, + ) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -251,7 +313,8 @@ def __init__( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + intermediate_size=config.intermediate_size, + linear_method=linear_method) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -372,12 +435,17 @@ def load_weights(self, ] expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + # (param_name, weight_name, shard_id, expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}", + shard_id, + expert_id + ) for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] + for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)] + ] if self.linear_method is None or hasattr( + self.linear_method, "apply_moe_weights") else [] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( @@ -401,16 +469,23 @@ def load_weights(self, weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for param_name, weight_name, shard_id, expert_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + if shard_id is None: + weight_loader(param, + loaded_weight, + expert_id=expert_id) + else: + weight_loader(param, + loaded_weight, + shard_id, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py deleted file mode 100644 index a8dadce24aa1d..0000000000000 --- a/vllm/model_executor/models/mixtral_quant.py +++ /dev/null @@ -1,412 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
-# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Mixtral model.""" -from typing import List, Optional, Tuple - -import numpy as np - -import torch -import torch.nn.functional as F - -from torch import nn -from transformers import MixtralConfig - -from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - ReplicatedLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) -from vllm.sequence import SamplerOutput - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - linear_method=linear_method) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - 
f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.num_total_experts), self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - linear_method=linear_method) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - linear_method=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - batch_size, sequence_length, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.sliding_window = sliding_window - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - linear_method=linear_method, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - linear_method=linear_method, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - self.attn = PagedAttention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class MixtralDecoderLayer(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - sliding_window=config.sliding_window, - linear_method=linear_method) - self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - input_metadata=input_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class MixtralModel(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) - for _ in range(config.num_hidden_layers) - ]) - 
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], input_metadata, - residual) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class MixtralForCausalLM(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) - return hidden_states - - def sample( - self, - hidden_states: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) - return next_tokens - - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) From e9846d046941dc11da8d123bc986a05c74488ec4 Mon Sep 17 00:00:00 2001 From: chutianxiang Date: Tue, 6 Feb 2024 21:27:12 +0800 Subject: [PATCH 04/20] Add awq supprt --- csrc/ops.h | 12 + csrc/pybind.cpp | 1 + csrc/quantization/awq/gemm_kernels.cu | 399 ++++++++++++++++-- vllm/model_executor/layers/linear.py | 1 + .../model_executor/layers/quantization/awq.py | 41 ++ .../layers/quantization/gptq.py | 3 +- .../layers/quantization/squeezellm.py | 1 + vllm/model_executor/models/deepseek.py | 196 ++++++--- vllm/model_executor/models/mixtral.py | 11 +- 9 files changed, 570 insertions(+), 95 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 676ade865a73e..d18eabcbd9e3a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -80,6 +80,18 @@ torch::Tensor awq_dequantize( int split_k_iters, int thx, int thy); + +torch::Tensor awq_group_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, + int split_k_iters); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 79d3b07fde4af..256cad6435217 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifndef USE_ROCM // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("awq_group_gemm", &awq_group_gemm, "Grouped Quantized GEMM for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 376c8ebfb9b7a..4a40febe3748b 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -27,7 +27,7 @@ __pack_half2(const half x, const half y) { return (v1 << 16) | v0; } -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); @@ -36,7 +36,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i float C_warp[32]; __shared__ half A_shared[16 * (32 + 8)]; __shared__ half B_shared[32 * (128 + 8)]; - + __shared__ half scaling_factors_shared[128]; __shared__ half zeros_shared[128]; @@ -60,19 +60,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A + half* A_ptr = A + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)threadIdx.x) % (32 / 8)) * 8; - + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * 2 + 
(((int)threadIdx.x) / (128 / 8)) * (OC / 8) + (((int)blockIdx_y) % j_factors1) * (128 / 8) + (((int)threadIdx.x) % (128 / 8)) * 1; // Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) % (32 / 8) ) * 8; @@ -80,16 +80,16 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + (((int)threadIdx.x) % (128 / 8)) * 8; - + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (128 / 8) + ((int)threadIdx.x) % (128 / 8); - + half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (128) + + (((int)blockIdx_y) % j_factors1) * (128) + (((int)threadIdx.x) % (128 / 8)) * 8; - half* C_ptr = C + half* C_ptr = C + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 128 + ((int)threadIdx.y) * 64 @@ -129,7 +129,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i // each warp: 32 x 4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); @@ -259,7 +259,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i } -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); @@ -268,7 +268,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in float C_warp[32]; __shared__ half A_shared[16 * (32 + 8)]; __shared__ half B_shared[32 * (64 + 8)]; - + __shared__ half scaling_factors_shared[64]; __shared__ half zeros_shared[64]; @@ -293,19 +293,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A + half* A_ptr = A + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)threadIdx.x) % (32 / 8)) * 8; - + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * 4 + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) + 
(((int)blockIdx_y) % j_factors1) * (64 / 8) + (((int)threadIdx.x) % (64 / 8)) * 1; // Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) % (32 / 8) ) * 8; @@ -313,16 +313,16 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) + (((int)threadIdx.x) % (64 / 8)) * 8; - + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (64 / 8) + ((int)threadIdx.x) % (64 / 8); - + half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (64) + + (((int)blockIdx_y) % j_factors1) * (64) + (((int)threadIdx.x) % (64 / 8)) * 8; - half* C_ptr = C + half* C_ptr = C + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 64 + ((int)threadIdx.y) * 32 @@ -362,7 +362,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in // each warp: 32 x 4 // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); @@ -389,7 +389,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in } __syncthreads(); - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { { unsigned int addr; @@ -405,9 +405,9 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in : "r"(addr) ); } - - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { { unsigned int addr; @@ -424,8 +424,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in ); } } - - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) + + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { @@ -498,9 +498,17 @@ __global__ void __launch_bounds__(64) dequantize_weights( half* __restrict__ scaling_factors, int* __restrict__ zeros, half* __restrict__ C, - int G + int G, + int in_c, + int out_c ) { + if (blockIdx.z > 0) { + B = B + blockIdx.z * in_c * out_c / 8; + scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; + zeros = zeros + blockIdx.z * in_c * out_c / G / 8; + C = C + blockIdx.z * in_c * out_c; + } int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -550,6 +558,251 @@ int j=0; } } +template +__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( + int G, + int split_k_iters, + half* __restrict__ A, + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + const float* __restrict__ topk_weights, + const int* __restrict__ sorted_token_ids_ptr, + const int* 
__restrict__ expert_ids_ptr, + const int* __restrict__ num_tokens_post_padded, + const int num_valid_tokens, + const int top_k, + const int expert_num, + int pad_M, + int M, + int IC, + int OC, + half* __restrict__ C) +{ + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + int num_tokens = *num_tokens_post_padded; + int j_factors1 = ((OC + N - 1) / N); + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1); + int block = blockIdx_y / j_factors1; + if (block * 16 >= num_tokens) return; + + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (N + 8)]; + + __shared__ half scaling_factors_shared[N]; + __shared__ half zeros_shared[N]; + + half A_shared_warp[8]; + half B_shared_warp[N / 4]; + for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / N; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + + int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); + int token_id = sorted_token_ids_ptr[row]; + bool ld_A_flag = (token_id < num_valid_tokens); + half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8; + + int expert_id = expert_ids_ptr[block]; + B = B + OC * IC / 8 * expert_id; + scaling_factors = scaling_factors + OC * IC / G * expert_id; + zeros = zeros + OC * IC / G / 8 * expert_id; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. 
and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { + + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : 
"r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#else + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), 
"f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + +#endif + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + int token_id = sorted_token_ids_ptr[row_offset]; + if (token_id < num_valid_tokens) + { + float value = C_warp[(ax1_0_1 * 8) + local_id]; + if (topk_weights) { + value = value * topk_weights[token_id]; + } + *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value); + } + } + } +#endif +} + } // namespace awq } // namespace vllm @@ -561,10 +814,11 @@ torch::Tensor awq_dequantize( int thx, int thy) { - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); + int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); + int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); + int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); + int G = in_c / (_kernel.dim() == 2 ? 
_scaling_factors.size(0) : _scaling_factors.size(1)); int x_thread = thx; int y_thread = thy; @@ -587,19 +841,24 @@ torch::Tensor awq_dequantize( const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + at::Tensor _de_kernel; + if (num_experts == 1) { + _de_kernel = torch::empty({in_c, out_c}, options); + } else { + _de_kernel = torch::empty({num_experts, in_c, out_c}, options); + } auto kernel = reinterpret_cast(_kernel.data_ptr()); auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); + dim3 num_blocks(x_blocks, y_blocks, num_experts); dim3 threads_per_block(x_thread, y_thread); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); return _de_kernel; } @@ -657,7 +916,7 @@ torch::Tensor awq_gemm( { int j_factors1 = num_out_channels / 64 / 1; dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - + // threadIdx.x: 32 // threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2); @@ -666,3 +925,69 @@ torch::Tensor awq_gemm( } return _out_feats.sum(0); } + +torch::Tensor awq_group_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + torch::Tensor _topk_weights, + torch::Tensor _sorted_token_ids_ptr, + torch::Tensor _expert_ids_ptr, + torch::Tensor _num_tokens_post_padded, + bool mul_weights, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int pad_num_in_feats = _sorted_token_ids_ptr.size(0); + int num_in_channels = _in_feats.size(2); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + int num_experts = _topk_weights.size(1); + int top_k = num_experts / _in_feats.size(1); + int group_size = num_in_channels / _scaling_factors.size(1); + + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto topk_weights = mul_weights ? 
reinterpret_cast(_topk_weights.data_ptr()) : nullptr; + auto sorted_token_ids_ptr = reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); + auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); + auto num_tokens_post_padded = reinterpret_cast(_num_tokens_post_padded.data_ptr()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, + _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); +} diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index dc815cadebfc7..a2b4b3ee96517 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -66,6 +66,7 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add + self.support_fused_moe = True def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 681f95821eabb..7820fe4e8b4b8 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,6 +4,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops +from vllm.model_executor.layers.fused_moe import ( + moe_align_block_size, fused_moe) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -76,6 +78,7 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config + self.support_fused_moe = True def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, @@ -163,3 +166,41 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) + + def apply_moe_weights(self, + w1: Dict[str, torch.Tensor], + w2: Dict[str, torch.Tensor], + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + dequant_w1 = ops.awq_dequantize( + w1["qweight"], w1["scales"], w1["qzeros"], 0, 0, 0 + ).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize( + w2["qweight"], 
w2["scales"], w2["qzeros"], 0, 0, 0 + ).permute(0, 2, 1) + return fused_moe(x, dequant_w1, dequant_w2, topk_weights, topk_ids) + + (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( + topk_ids, 16, w1["qweight"].shape[0]) + + x = x.view(x.shape[0], 1, *x.shape[1:]) + pack_factor = self.quant_config.pack_factor + + gate_up = ops.awq_group_gemm( + x, w1["qweight"], w1["scales"], w1["qzeros"], + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, + False, pack_factor) + + out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), + dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, gate_up) + + out = ops.awq_group_gemm( + out, w2["qweight"], w2["scales"], w2["qzeros"], + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, + True, pack_factor) + + return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 818c55fc3694b..5e4cdebb346d0 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -88,6 +88,7 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config + self.support_fused_moe = True def create_weights( self, @@ -235,7 +236,7 @@ def apply_moe_weights(self, "meta") else w["g_idx"], ) - if x.shape[0] >= 100: + if x.shape[0] >= 128: dequant_w1 = ops.dequant_gptq( w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], w1["exllama_state"] == ExllamaState.READY diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 9244e88552756..dbbf72f804d13 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -66,6 +66,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config + self.support_fused_moe = False def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index fc727b8e661b3..09619dfe5d1ff 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -23,6 +23,7 @@ """Inference-only Deepseek model.""" from typing import Any, Dict, List, Optional, Tuple +import numpy as np import torch from torch import nn import torch.nn.functional as F @@ -33,7 +34,8 @@ from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, +from vllm.model_executor.layers.linear import (UnquantizedLinearMethod, + LinearMethodBase, MergedColumnParallelLinear, ReplicatedLinear, QKVParallelLinear, @@ -86,6 +88,39 @@ def forward(self, x): return x +class DeepseekExpertMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_proj = ReplicatedLinear(hidden_size, + intermediate_size, + bias=False, + linear_method=linear_method) + self.up_proj = ReplicatedLinear(hidden_size, + intermediate_size, + bias=False, + linear_method=linear_method) + self.down_proj = ReplicatedLinear(intermediate_size, + hidden_size, + 
bias=False, + linear_method=linear_method) + self.act_fn = nn.SiLU() + + def forward(self, hidden_states): + gate_out, _ = self.gate_proj(hidden_states) + gate_out = self.act_fn(gate_out) + up_out, _ = self.up_proj(hidden_states) + current_hidden_states = gate_out * up_out + current_hidden_states, _ = self.down_proj(current_hidden_states) + return current_hidden_states + + class DeepseekMoE(nn.Module): def __init__( @@ -99,20 +134,44 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.n_routed_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok - if self.tp_size > self.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") - - self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) - self.pack_params() + self.linear_method = linear_method + if self.linear_method is None: + self.linear_method = UnquantizedLinearMethod() + + if not self.linear_method.support_fused_moe: + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.n_routed_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + DeepseekExpertMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + if idx in self.expert_indicies else None + for idx in range(self.n_routed_experts) + ]) + else: + self.w1 = MergedColumnParallelLinear( + config.hidden_size, [config.moe_intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear( + config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, @@ -129,25 +188,6 @@ def __init__( reduce_results=False, ) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -164,12 +204,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.config.norm_topk_prob: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - routing_weights, - selected_experts, - inplace=True) + if not self.linear_method.support_fused_moe: + 
final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + else: + final_hidden_states = self.linear_method.apply_moe_weights( + self.w1.linear_weights, + self.w2.linear_weights, + hidden_states, + routing_weights, + selected_experts, + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -411,10 +467,25 @@ def load_weights(self, ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), + ("mlp.gate_up_proj", "mlp.gate_proj", 0), + ("mlp.gate_up_proj", "mlp.up_proj", 1), + ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0), + ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1), ] + expert_params_mapping = [ + # (param_name, weight_name, shard_id, expert_id) + ( + "w1" if weight_name in ["gate_proj", "up_proj"] else "w2", + f"experts.{expert_id}.{weight_name}", + shard_id, + expert_id + ) + for expert_id in range(self.config.n_routed_experts) + for weight_name, shard_id in [ + ("gate_proj", 0), ("up_proj", 1), ("down_proj", None)] + ] if self.linear_method is None or self.linear_method.support_fused_moe else [] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, @@ -440,14 +511,33 @@ def load_weights(self, weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, shard_id, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + if shard_id is None: + weight_loader(param, + loaded_weight, + expert_id=expert_id) + else: + weight_loader(param, + loaded_weight, + shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." 
in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index edeafd31f99c8..fcedaacf285fc 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -136,7 +136,7 @@ def __init__( params_dtype=self.params_dtype, linear_method=None) - if not hasattr(self.linear_method, "apply_moe_weights"): + if not self.linear_method.support_fused_moe: if self.tp_size > self.num_total_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -181,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - if not hasattr(self.linear_method, "apply_moe_weights"): + if not self.linear_method.support_fused_moe: final_hidden_states = None for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] @@ -444,8 +444,7 @@ def load_weights(self, ) for expert_id in range(self.config.num_local_experts) for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)] - ] if self.linear_method is None or hasattr( - self.linear_method, "apply_moe_weights") else [] + ] if self.linear_method is None or self.linear_method.support_fused_moe else [] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( @@ -491,6 +490,10 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From a9d65a9c10787cc49175c0228912e447240cd5bd Mon Sep 17 00:00:00 2001 From: chutianxiang Date: Tue, 6 Feb 2024 23:39:26 +0800 Subject: [PATCH 05/20] Add test --- tests/kernels/test_moe.py | 137 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 227ddfc3661b3..e21ee7121cff8 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -9,7 +9,15 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from vllm._C import ops + from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.awq import ( + AWQConfig, AWQLinearMethod +) +from vllm.model_executor.layers.quantization.gptq import ( + ExllamaState, GPTQConfig, GPTQLinearMethod +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.models.mixtral import MixtralMoE @@ -102,3 +110,132 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, + w2_gidx, w2_scale, w2_zero, topk_weight, topk_ids): + (B, D) = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], w2.shape[2], dtype=a.dtype, device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + dw1 = ops.dequant_gptq(w1[i], w1_zero[i], w1_scale[i], w1_gidx[i], False) + dw2 = 
ops.dequant_gptq(w2[i], w2_zero[i], w2_scale[i], w2_gidx[i], False) + r1 = SiluAndMul()(torch.matmul(a[mask], dw1)) + out[mask] = torch.matmul(r1,dw2) + return (out.view(B, -1, w2.shape[2]) * topk_weight.view(B, -1, 1)).sum(dim=1).half() + + +@pytest.mark.parametrize("m", [1, 16, 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("exstate", + [ExllamaState.UNINITIALIZED, ExllamaState.UNUSED]) +@pytest.mark.parametrize("groupsize", [-1, 128]) +def test_fused_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + exstate: ExllamaState, + groupsize: int +): + RANGE = 1000000000 + a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 + qw1 = torch.randint(-RANGE, RANGE, (e, (k // 32) * 4, n * 2), + dtype=torch.int, device='cuda') + qw2 = torch.randint(-RANGE, RANGE, (e, (n // 32) * 4, k), + dtype=torch.int, device='cuda') + + groupsize1 = groupsize if groupsize != -1 else k + groupsize2 = groupsize if groupsize != -1 else n + gidx1 = torch.tensor([i // groupsize1 for i in range(k)], dtype=torch.int32, + device='cuda').unsqueeze(0).expand(e, k).contiguous() + gidx2 = torch.tensor([i // groupsize2 for i in range(n)], dtype=torch.int32, + device='cuda').unsqueeze(0).expand(e, n).contiguous() + + scale1 = torch.randn((e, k // groupsize1, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn((e, n // groupsize2, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, RANGE, (e, k // groupsize1, (n * 2// 32) * 4), + dtype=torch.int32, device='cuda') + zero2 = torch.randint(-RANGE, RANGE, (e, n // groupsize2, (k // 32) * 4), + dtype=torch.int32, device='cuda') + w1 = {"qweight": qw1, "g_idx": gidx1, "scales": scale1, "qzeros": zero1, + "exllama_state": exstate} + w2 = {"qweight": qw2, "g_idx": gidx2, "scales": scale2, "qzeros": zero2, + "exllama_state": exstate} + + score = torch.randn((m, e), device='cuda', dtype=torch.half) + score = torch.softmax(score, dim=-1).float() + topk_weight, topk_ids = torch.topk(score, topk) + + gptq_method = GPTQLinearMethod(GPTQConfig(4, groupsize, False)) + torch_output = torch_moe_gptq(a, qw1, gidx1, scale1, zero1, qw2, gidx2, scale2, zero2, + topk_weight, topk_ids) + cuda_output = gptq_method.apply_moe_weights(w1, w2, a, topk_weight, topk_ids) + # gptq kernels have large variance in output + assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) + + +def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, topk_weight, topk_ids): + (B, D) = a.shape + a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) + out = torch.zeros(B * topk_ids.shape[1], w2.shape[2] * 8, dtype=a.dtype, device=a.device) + topk_ids = topk_ids.view(-1) + topk_weight = topk_weight.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0) + dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0) + r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1)) + out[mask] = torch.matmul(r1,dw2).to(out.dtype) + return (out.view(B, -1, w2.shape[2] * 8) * topk_weight.view(B, -1, 1)).sum(dim=1).half() + + +@pytest.mark.parametrize("m", [1, 16, 128, 1024]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +def test_fused_moe( + m: int, + n: 
int, + k: int, + e: int, + topk: int, +): + RANGE = 1000000000 + groupsize = 128 + a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 + qw1 = torch.randint(-RANGE, RANGE, (e, k, n * 2 // 8), + dtype=torch.int, device='cuda') + qw2 = torch.randint(-RANGE, RANGE, (e, n, k // 8), + dtype=torch.int, device='cuda') + + scale1 = torch.randn((e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn((e, n // groupsize, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, RANGE, (e, k // groupsize, (n * 2// 32) * 4), + dtype=torch.int32, device='cuda') + zero2 = torch.randint(-RANGE, RANGE, (e, n // groupsize, (k // 32) * 4), + dtype=torch.int32, device='cuda') + w1 = {"qweight": qw1, "scales": scale1, "qzeros": zero1} + w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2} + + score = torch.randn((m, e), device='cuda', dtype=torch.half) + score = torch.softmax(score, dim=-1).float() + topk_weight, topk_ids = torch.topk(score, topk) + + awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) + torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, + topk_weight, topk_ids) + cuda_output = awq_method.apply_moe_weights(w1, w2, a, topk_weight, topk_ids) + assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) \ No newline at end of file From 7dea006a08fcb2c05b9a744da2760fbb57f54624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 6 Feb 2024 23:56:23 +0800 Subject: [PATCH 06/20] format --- csrc/quantization/gptq/q_gemm.cu | 2 +- tests/kernels/test_moe.py | 161 +++++++++++------- vllm/model_executor/layers/linear.py | 30 ++-- .../model_executor/layers/quantization/awq.py | 41 ++--- .../layers/quantization/gptq.py | 36 ++-- vllm/model_executor/models/deepseek.py | 33 ++-- vllm/model_executor/models/mixtral.py | 41 ++--- 7 files changed, 190 insertions(+), 154 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 08aa6a7f43682..a2e3426cbb0ae 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1432,4 +1432,4 @@ torch::Tensor dequant_gptq size_k, size_n, groups, num_experts, use_exllama); return temp_dq; -} \ No newline at end of file +} diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index e21ee7121cff8..90814808dde46 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -12,12 +12,11 @@ from vllm._C import ops from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.quantization.awq import ( - AWQConfig, AWQLinearMethod -) -from vllm.model_executor.layers.quantization.gptq import ( - ExllamaState, GPTQConfig, GPTQLinearMethod -) +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + AWQLinearMethod) +from vllm.model_executor.layers.quantization.gptq import (ExllamaState, + GPTQConfig, + GPTQLinearMethod) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.models.mixtral import MixtralMoE @@ -112,21 +111,27 @@ def test_mixtral_moe(dtype: torch.dtype): atol=mixtral_moe_tol[dtype]) -def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, - w2_gidx, w2_scale, w2_zero, topk_weight, topk_ids): +def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, w2_gidx, w2_scale, + w2_zero, topk_weight, topk_ids): (B, D) = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) - out = torch.zeros(B * topk_ids.shape[1], w2.shape[2], dtype=a.dtype, device=a.device) 
+ out = torch.zeros(B * topk_ids.shape[1], + w2.shape[2], + dtype=a.dtype, + device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - dw1 = ops.dequant_gptq(w1[i], w1_zero[i], w1_scale[i], w1_gidx[i], False) - dw2 = ops.dequant_gptq(w2[i], w2_zero[i], w2_scale[i], w2_gidx[i], False) + dw1 = ops.dequant_gptq(w1[i], w1_zero[i], w1_scale[i], w1_gidx[i], + False) + dw2 = ops.dequant_gptq(w2[i], w2_zero[i], w2_scale[i], w2_gidx[i], + False) r1 = SiluAndMul()(torch.matmul(a[mask], dw1)) - out[mask] = torch.matmul(r1,dw2) - return (out.view(B, -1, w2.shape[2]) * topk_weight.view(B, -1, 1)).sum(dim=1).half() + out[mask] = torch.matmul(r1, dw2) + return (out.view(B, -1, w2.shape[2]) * + topk_weight.view(B, -1, 1)).sum(dim=1).half() @pytest.mark.parametrize("m", [1, 16, 128]) @@ -137,57 +142,77 @@ def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, @pytest.mark.parametrize("exstate", [ExllamaState.UNINITIALIZED, ExllamaState.UNUSED]) @pytest.mark.parametrize("groupsize", [-1, 128]) -def test_fused_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - exstate: ExllamaState, - groupsize: int -): +def test_fused_moe_gptq(m: int, n: int, k: int, e: int, topk: int, + exstate: ExllamaState, groupsize: int): RANGE = 1000000000 a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 - qw1 = torch.randint(-RANGE, RANGE, (e, (k // 32) * 4, n * 2), - dtype=torch.int, device='cuda') - qw2 = torch.randint(-RANGE, RANGE, (e, (n // 32) * 4, k), - dtype=torch.int, device='cuda') + qw1 = torch.randint(-RANGE, + RANGE, (e, (k // 32) * 4, n * 2), + dtype=torch.int, + device='cuda') + qw2 = torch.randint(-RANGE, + RANGE, (e, (n // 32) * 4, k), + dtype=torch.int, + device='cuda') groupsize1 = groupsize if groupsize != -1 else k groupsize2 = groupsize if groupsize != -1 else n - gidx1 = torch.tensor([i // groupsize1 for i in range(k)], dtype=torch.int32, + gidx1 = torch.tensor([i // groupsize1 for i in range(k)], + dtype=torch.int32, device='cuda').unsqueeze(0).expand(e, k).contiguous() - gidx2 = torch.tensor([i // groupsize2 for i in range(n)], dtype=torch.int32, + gidx2 = torch.tensor([i // groupsize2 for i in range(n)], + dtype=torch.int32, device='cuda').unsqueeze(0).expand(e, n).contiguous() - scale1 = torch.randn((e, k // groupsize1, n * 2), dtype=torch.half, device='cuda') / 50 - scale2 = torch.randn((e, n // groupsize2, k), dtype=torch.half, device='cuda') / 50 - - zero1 = torch.randint(-RANGE, RANGE, (e, k // groupsize1, (n * 2// 32) * 4), - dtype=torch.int32, device='cuda') - zero2 = torch.randint(-RANGE, RANGE, (e, n // groupsize2, (k // 32) * 4), - dtype=torch.int32, device='cuda') - w1 = {"qweight": qw1, "g_idx": gidx1, "scales": scale1, "qzeros": zero1, - "exllama_state": exstate} - w2 = {"qweight": qw2, "g_idx": gidx2, "scales": scale2, "qzeros": zero2, - "exllama_state": exstate} + scale1 = torch.randn( + (e, k // groupsize1, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn( + (e, n // groupsize2, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, + RANGE, (e, k // groupsize1, (n * 2 // 32) * 4), + dtype=torch.int32, + device='cuda') + zero2 = torch.randint(-RANGE, + RANGE, (e, n // groupsize2, (k // 32) * 4), + dtype=torch.int32, + device='cuda') + w1 = { + "qweight": qw1, + "g_idx": gidx1, + "scales": scale1, + "qzeros": zero1, + "exllama_state": exstate + } + w2 = { + "qweight": qw2, + "g_idx": gidx2, + "scales": scale2, + "qzeros": zero2, + 
"exllama_state": exstate + } score = torch.randn((m, e), device='cuda', dtype=torch.half) score = torch.softmax(score, dim=-1).float() topk_weight, topk_ids = torch.topk(score, topk) gptq_method = GPTQLinearMethod(GPTQConfig(4, groupsize, False)) - torch_output = torch_moe_gptq(a, qw1, gidx1, scale1, zero1, qw2, gidx2, scale2, zero2, - topk_weight, topk_ids) - cuda_output = gptq_method.apply_moe_weights(w1, w2, a, topk_weight, topk_ids) + torch_output = torch_moe_gptq(a, qw1, gidx1, scale1, zero1, qw2, gidx2, + scale2, zero2, topk_weight, topk_ids) + cuda_output = gptq_method.apply_moe_weights(w1, w2, a, topk_weight, + topk_ids) # gptq kernels have large variance in output assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) -def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, topk_weight, topk_ids): +def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, topk_weight, + topk_ids): (B, D) = a.shape a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) - out = torch.zeros(B * topk_ids.shape[1], w2.shape[2] * 8, dtype=a.dtype, device=a.device) + out = torch.zeros(B * topk_ids.shape[1], + w2.shape[2] * 8, + dtype=a.dtype, + device=a.device) topk_ids = topk_ids.view(-1) topk_weight = topk_weight.view(-1) for i in range(w1.shape[0]): @@ -196,8 +221,9 @@ def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, topk_weight, dw1 = ops.awq_dequantize(w1[i], w1_scale[i], w1_zero[i], 0, 0, 0) dw2 = ops.awq_dequantize(w2[i], w2_scale[i], w2_zero[i], 0, 0, 0) r1 = SiluAndMul()(torch.matmul(a[mask].half(), dw1)) - out[mask] = torch.matmul(r1,dw2).to(out.dtype) - return (out.view(B, -1, w2.shape[2] * 8) * topk_weight.view(B, -1, 1)).sum(dim=1).half() + out[mask] = torch.matmul(r1, dw2).to(out.dtype) + return (out.view(B, -1, w2.shape[2] * 8) * + topk_weight.view(B, -1, 1)).sum(dim=1).half() @pytest.mark.parametrize("m", [1, 16, 128, 1024]) @@ -205,7 +231,7 @@ def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, topk_weight, @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) -def test_fused_moe( +def test_fused_moe_awq( m: int, n: int, k: int, @@ -215,18 +241,28 @@ def test_fused_moe( RANGE = 1000000000 groupsize = 128 a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 - qw1 = torch.randint(-RANGE, RANGE, (e, k, n * 2 // 8), - dtype=torch.int, device='cuda') - qw2 = torch.randint(-RANGE, RANGE, (e, n, k // 8), - dtype=torch.int, device='cuda') - - scale1 = torch.randn((e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50 - scale2 = torch.randn((e, n // groupsize, k), dtype=torch.half, device='cuda') / 50 - - zero1 = torch.randint(-RANGE, RANGE, (e, k // groupsize, (n * 2// 32) * 4), - dtype=torch.int32, device='cuda') - zero2 = torch.randint(-RANGE, RANGE, (e, n // groupsize, (k // 32) * 4), - dtype=torch.int32, device='cuda') + qw1 = torch.randint(-RANGE, + RANGE, (e, k, n * 2 // 8), + dtype=torch.int, + device='cuda') + qw2 = torch.randint(-RANGE, + RANGE, (e, n, k // 8), + dtype=torch.int, + device='cuda') + + scale1 = torch.randn( + (e, k // groupsize, n * 2), dtype=torch.half, device='cuda') / 50 + scale2 = torch.randn( + (e, n // groupsize, k), dtype=torch.half, device='cuda') / 50 + + zero1 = torch.randint(-RANGE, + RANGE, (e, k // groupsize, (n * 2 // 32) * 4), + dtype=torch.int32, + device='cuda') + zero2 = torch.randint(-RANGE, + RANGE, (e, n // groupsize, (k // 32) * 4), + dtype=torch.int32, + device='cuda') w1 = 
{"qweight": qw1, "scales": scale1, "qzeros": zero1} w2 = {"qweight": qw2, "scales": scale2, "qzeros": zero2} @@ -236,6 +272,7 @@ def test_fused_moe( awq_method = AWQLinearMethod(AWQConfig(4, groupsize, False)) torch_output = torch_moe_awq(a, qw1, scale1, zero1, qw2, scale2, zero2, - topk_weight, topk_ids) - cuda_output = awq_method.apply_moe_weights(w1, w2, a, topk_weight, topk_ids) - assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) \ No newline at end of file + topk_weight, topk_ids) + cuda_output = awq_method.apply_moe_weights(w1, w2, a, topk_weight, + topk_ids) + assert torch.allclose(cuda_output, torch_output, atol=5e-2, rtol=0) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a2b4b3ee96517..a963687e8f3e3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -37,7 +37,8 @@ def apply_weights(self, """Apply the weights to the input tensor.""" raise NotImplementedError - def create_moe_weights(self, num_experts: int, input_size_per_partition: int, + def create_moe_weights(self, num_experts: int, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: @@ -48,7 +49,7 @@ def create_moe_weights(self, num_experts: int, input_size_per_partition: int, params_dtype) for name, param in tuple(linear_weights.items()): if isinstance(param, Parameter): - repeat_size = (num_experts,) + (1,) * param.dim() + repeat_size = (num_experts, ) + (1, ) * param.dim() new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size), requires_grad=False) set_weight_attrs(new_param, param.__dict__) @@ -203,12 +204,11 @@ def __init__( if num_experts > 1: self.linear_weights = self.linear_method.create_moe_weights( num_experts, self.input_size, self.output_size_per_partition, - self.input_size, self.output_size, self.params_dtype - ) + self.input_size, self.output_size, self.params_dtype) else: self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) + self.input_size, self.output_size_per_partition, + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -224,7 +224,9 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, expert_id: int = 0): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) @@ -290,7 +292,8 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method, num_experts) + skip_bias_add, params_dtype, linear_method, + num_experts) def weight_loader(self, param: Parameter, @@ -539,12 +542,11 @@ def __init__( if num_experts > 1: self.linear_weights = self.linear_method.create_moe_weights( num_experts, self.input_size_per_partition, self.output_size, - self.input_size, self.output_size, self.params_dtype - ) + self.input_size, self.output_size, self.params_dtype) else: self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - 
self.output_size, self.params_dtype) + self.input_size_per_partition, self.output_size, + self.input_size, self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) @@ -564,7 +566,9 @@ def __init__( else: self.register_parameter("bias", None) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, expert_id: int = 0): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 7820fe4e8b4b8..6686a334cfba1 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,8 +4,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops -from vllm.model_executor.layers.fused_moe import ( - moe_align_block_size, fused_moe) +from vllm.model_executor.layers.fused_moe import (moe_align_block_size, + fused_moe) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -175,32 +175,35 @@ def apply_moe_weights(self, topk_ids: torch.Tensor) -> torch.Tensor: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: - dequant_w1 = ops.awq_dequantize( - w1["qweight"], w1["scales"], w1["qzeros"], 0, 0, 0 - ).permute(0, 2, 1) - dequant_w2 = ops.awq_dequantize( - w2["qweight"], w2["scales"], w2["qzeros"], 0, 0, 0 - ).permute(0, 2, 1) + dequant_w1 = ops.awq_dequantize(w1["qweight"], w1["scales"], + w1["qzeros"], 0, 0, + 0).permute(0, 2, 1) + dequant_w2 = ops.awq_dequantize(w2["qweight"], w2["scales"], + w2["qzeros"], 0, 0, + 0).permute(0, 2, 1) return fused_moe(x, dequant_w1, dequant_w2, topk_weights, topk_ids) - (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( - topk_ids, 16, w1["qweight"].shape[0]) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size( + topk_ids, 16, w1["qweight"].shape[0]) x = x.view(x.shape[0], 1, *x.shape[1:]) pack_factor = self.quant_config.pack_factor - gate_up = ops.awq_group_gemm( - x, w1["qweight"], w1["scales"], w1["qzeros"], - topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, - False, pack_factor) + gate_up = ops.awq_group_gemm(x, w1["qweight"], w1["scales"], + w1["qzeros"], topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, False, + pack_factor) out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), - dtype=x.dtype, device=x.device) + dtype=x.dtype, + device=x.device) ops.silu_and_mul(out, gate_up) - out = ops.awq_group_gemm( - out, w2["qweight"], w2["scales"], w2["qzeros"], - topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, - True, pack_factor) + out = ops.awq_group_gemm(out, w2["qweight"], w2["scales"], + w2["qzeros"], topk_weights, sorted_token_ids, + expert_ids, num_tokens_post_padded, True, + pack_factor) return torch.sum(out, dim=1) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 5e4cdebb346d0..cd4b3b4c3e1bb 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,8 +6,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops -from 
vllm.model_executor.layers.fused_moe import ( - moe_align_block_size, fused_moe) +from vllm.model_executor.layers.fused_moe import (moe_align_block_size, + fused_moe) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -223,8 +223,8 @@ def apply_moe_weights(self, for w in [w1, w2]: if w["exllama_state"] == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - w["g_idx"] = torch.argsort(w["g_idx"], dim=-1).to( - torch.int) + w["g_idx"] = torch.argsort(w["g_idx"], + dim=-1).to(torch.int) else: w["g_idx"] = torch.empty((1, 1), device="meta") w["exllama_state"] = ExllamaState.READY @@ -232,23 +232,23 @@ def apply_moe_weights(self, for i in range(w["qweight"].shape[0]): ops.gptq_shuffle( w["qweight"][i], - w["g_idx"][i] if w["g_idx"].device != torch.device( - "meta") else w["g_idx"], + w["g_idx"][i] + if w["g_idx"].device != torch.device("meta") else + w["g_idx"], ) if x.shape[0] >= 128: dequant_w1 = ops.dequant_gptq( w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], - w1["exllama_state"] == ExllamaState.READY - ).permute(0, 2, 1) + w1["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) dequant_w2 = ops.dequant_gptq( w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], - w2["exllama_state"] == ExllamaState.READY - ).permute(0, 2, 1) + w2["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) return fused_moe(x, dequant_w1, dequant_w2, topk_weights, topk_ids) - (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( - topk_ids, 8, w1["qweight"].shape[0]) + (sorted_token_ids, expert_ids, + num_tokens_post_padded) = moe_align_block_size( + topk_ids, 8, w1["qweight"].shape[0]) x = x.view(x.shape[0], 1, *x.shape[1:]) gate_up = ops.group_gptq_gemm( @@ -257,12 +257,14 @@ def apply_moe_weights(self, False, w1["exllama_state"] == ExllamaState.READY) out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), - dtype=x.dtype, device=x.device) + dtype=x.dtype, + device=x.device) ops.silu_and_mul(out, gate_up) - out = ops.group_gptq_gemm( - out, w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], - topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, - True, w2["exllama_state"] == ExllamaState.READY) + out = ops.group_gptq_gemm(out, w2["qweight"], w2["qzeros"], + w2["scales"], w2["g_idx"], topk_weights, + sorted_token_ids, expert_ids, + num_tokens_post_padded, True, + w2["exllama_state"] == ExllamaState.READY) return torch.sum(out, dim=1) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 09619dfe5d1ff..ee83483bce4be 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -32,14 +32,10 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (UnquantizedLinearMethod, - LinearMethodBase, - MergedColumnParallelLinear, - ReplicatedLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + UnquantizedLinearMethod, LinearMethodBase, MergedColumnParallelLinear, + ReplicatedLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.rotary_embedding import get_rope from 
vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -156,8 +152,7 @@ def __init__( intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, linear_method=linear_method, - ) - if idx in self.expert_indicies else None + ) if idx in self.expert_indicies else None for idx in range(self.n_routed_experts) ]) else: @@ -166,12 +161,11 @@ def __init__( bias=False, linear_method=linear_method, num_experts=self.n_routed_experts) - self.w2 = RowParallelLinear( - config.moe_intermediate_size, - config.hidden_size, - bias=False, - linear_method=linear_method, - num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear(config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, @@ -209,8 +203,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) current_hidden_states = expert_layer(hidden_states).mul_( expert_weights) @@ -534,8 +528,9 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): + if (("mlp.experts." in name + or "mlp.shared_experts." in name) + and name not in params_dict): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fcedaacf285fc..8bbaf5046d6bf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -32,14 +32,10 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, LinearMethodBase, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,7 +45,6 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -142,8 +137,9 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.num_total_experts}.") # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.num_total_experts), self.tp_size)[self.rank].tolist() + self.expert_indicies = 
np.array_split( + range(self.num_total_experts), + self.tp_size)[self.rank].tolist() if not self.expert_indicies: raise ValueError( f"Rank {self.rank} has no experts assigned to it.") @@ -157,17 +153,16 @@ def __init__( for idx in range(self.num_total_experts) ]) else: - self.ws = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - linear_method=linear_method, - num_experts=num_experts) - self.w2s = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method, - num_experts=num_experts) + self.ws = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=num_experts) + self.w2s = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + num_experts=num_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_size = hidden_states.shape @@ -186,8 +181,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) current_hidden_states = expert_layer(hidden_states).mul_( expert_weights) From 46d15fbbb67444bb66692f2cf3f29a861f230205 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 7 Feb 2024 13:51:56 +0800 Subject: [PATCH 07/20] format --- vllm/model_executor/layers/linear.py | 24 ++++++++----------- .../model_executor/layers/quantization/awq.py | 12 ++++------ .../layers/quantization/gptq.py | 12 ++++------ .../layers/quantization/squeezellm.py | 12 ++++------ 4 files changed, 25 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 0c16dbb6f5d3a..fff7ad6481e05 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -57,13 +57,11 @@ def create_moe_weights(self, num_experts: int, return linear_weights @abstractmethod - def apply_moe_weights(self, - w1: Dict[str, torch.Tensor], - w2: Dict[str, torch.Tensor], - x: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool) -> torch.Tensor: + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: """Apply the weights to the input tensor.""" raise NotImplementedError @@ -102,13 +100,11 @@ def apply_weights(self, return F.linear(x, weight) return F.linear(x, weight, bias) - def apply_moe_weights(self, - w1: Dict[str, torch.Tensor], - w2: Dict[str, torch.Tensor], - x: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool) -> torch.Tensor: + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk, renormalize) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index d9220bba01e46..a123b0451ad8e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -167,13 +167,11 @@ def apply_weights(self, out = out + bias 
return out.reshape(out_shape) - def apply_moe_weights(self, - w1: Dict[str, torch.Tensor], - w2: Dict[str, torch.Tensor], - x: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool) -> torch.Tensor: + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: dequant_w1 = ops.awq_dequantize(w1["qweight"], w1["scales"], diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 96aa73b61ea1d..898ac2914f5e8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -213,13 +213,11 @@ def apply_weights(self, output = output + bias return output.reshape(out_shape) - def apply_moe_weights(self, - w1: Dict[str, torch.Tensor], - w2: Dict[str, torch.Tensor], - x: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool) -> torch.Tensor: + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: # shuffle weights for exllama for w in [w1, w2]: if w["exllama_state"] == ExllamaState.UNINITIALIZED: diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index b3c3b395abd59..d597acbbcdc15 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -129,11 +129,9 @@ def apply_weights(self, out = out + bias return out.reshape(out_shape) - def apply_moe_weights(self, - w1: Dict[str, torch.Tensor], - w2: Dict[str, torch.Tensor], - x: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool) -> torch.Tensor: + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: raise NotImplementedError From 6b3e23e481351338bf6857fd99244c5f770b119f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Fri, 23 Feb 2024 14:18:12 +0800 Subject: [PATCH 08/20] Fix unit test --- tests/kernels/test_moe.py | 31 +++++++++++++++++++++++---- vllm/model_executor/models/mixtral.py | 6 ------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 1170975284921..738bb17c91dd2 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +import tempfile + import pytest import torch from transformers import MixtralConfig @@ -17,6 +19,8 @@ GPTQLinearMethod) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) def torch_moe(a, w1, w2, score, topk): @@ -55,7 +59,13 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + inplace=False) torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) @@ -65,6 +75,17 @@ def test_fused_moe( @torch.inference_mode() def test_mixtral_moe(dtype: torch.dtype): "Make sure our Mixtral MoE implementation agrees with the one from huggingface." + # Initialize dist environment + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + initialize_model_parallel() + torch.set_default_dtype(dtype) # Instantiate our and huggingface's MoE blocks config = MixtralConfig() @@ -74,7 +95,6 @@ def test_mixtral_moe(dtype: torch.dtype): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - params_dtype=dtype, tp_size=1, ).cuda() @@ -83,8 +103,8 @@ def test_mixtral_moe(dtype: torch.dtype): for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.ws[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.ws.weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2s.weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") @@ -93,6 +113,9 @@ def test_mixtral_moe(dtype: torch.dtype): hf_states, _ = hf_moe.forward(inputs) vllm_states = vllm_moe.forward(inputs) + # destroy dist environment + destroy_model_parallel() + mixtral_moe_tol = { torch.float32: 1e-3, torch.float16: 1e-3, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2a2f3108960a8..c2fd2f69239f0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -106,7 +106,6 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, linear_method: Optional[LinearMethodBase] = None, ): @@ -121,14 +120,9 @@ def __init__( if self.linear_method is None: self.linear_method = UnquantizedLinearMethod() - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, - params_dtype=self.params_dtype, linear_method=None) if not self.linear_method.support_fused_moe: From d43445ebc804c238e78650670aff530092e3b104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sat, 24 Feb 2024 10:51:48 +0800 Subject: [PATCH 09/20] Add guard for awq unit test --- tests/kernels/test_moe.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 
deletion(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 738bb17c91dd2..8c2fbe53dd76f 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -154,7 +154,7 @@ def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, w2_gidx, w2_scale, @pytest.mark.parametrize("m", [1, 16, 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("n", [128, 512, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @@ -257,6 +257,14 @@ def test_fused_moe_awq( e: int, topk: int, ): + # awq requires minimum capablity 75 + if torch.version.hip is not None: + return + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < 75: + return + RANGE = 1000000000 groupsize = 128 a = torch.randn((m, k), device='cuda', dtype=torch.half) / 10 From 2c68478fe0f6bb1f0177a0aab41170ee57c43b79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sat, 24 Feb 2024 11:04:14 +0800 Subject: [PATCH 10/20] Fix format --- tests/kernels/test_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8c2fbe53dd76f..5260a32460e32 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -257,7 +257,7 @@ def test_fused_moe_awq( e: int, topk: int, ): - # awq requires minimum capablity 75 + # awq requires minimum capability 75 if torch.version.hip is not None: return capability = torch.cuda.get_device_capability() From 2c27dcc3b3dc84cb503c8dcfc61089d7f21c5f81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sat, 24 Feb 2024 15:53:17 +0800 Subject: [PATCH 11/20] test --- tests/kernels/test_moe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 5260a32460e32..38c3643ba2dca 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -153,10 +153,10 @@ def torch_moe_gptq(a, w1, w1_gidx, w1_scale, w1_zero, w2, w2_gidx, w2_scale, topk_weight.view(B, -1, 1)).sum(dim=1).half() -@pytest.mark.parametrize("m", [1, 16, 128]) -@pytest.mark.parametrize("n", [128, 512, 1024]) +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024]) -@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("e", [8]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("exstate", [ExllamaState.UNINITIALIZED, ExllamaState.UNUSED]) @@ -245,10 +245,10 @@ def torch_moe_awq(a, w1, w1_scale, w1_zero, w2, w2_scale, w2_zero, score, topk_weight.view(B, -1, 1)).sum(dim=1).half() -@pytest.mark.parametrize("m", [1, 16, 128, 1024]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("m", [1024, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024]) -@pytest.mark.parametrize("e", [8, 64]) +@pytest.mark.parametrize("e", [8]) @pytest.mark.parametrize("topk", [2, 6]) def test_fused_moe_awq( m: int, From 68d34af3a5d69f3c50388c2088f8b740f4ccb4ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 27 Feb 2024 12:41:00 +0800 Subject: [PATCH 12/20] Fix import --- vllm/model_executor/layers/fused_moe/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py 
b/vllm/model_executor/layers/fused_moe/__init__.py index 1391d43c8abeb..38efc61414505 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,8 @@ -from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_moe, moe_align_block_size, fused_topk) __all__ = [ "fused_moe", + "moe_align_block_size", + "fused_topk" ] From d956844419c9aed4c47b99cdd8a558118bca2441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 27 Feb 2024 13:49:57 +0800 Subject: [PATCH 13/20] fix format --- vllm/model_executor/layers/fused_moe/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 38efc61414505..549e04189d7a3 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,8 +1,4 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_moe, moe_align_block_size, fused_topk) -__all__ = [ - "fused_moe", - "moe_align_block_size", - "fused_topk" -] +__all__ = ["fused_moe", "moe_align_block_size", "fused_topk"] From 7a4ba90ad86837f44687029a8e3b76f83dc3cb4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Fri, 1 Mar 2024 16:01:00 +0800 Subject: [PATCH 14/20] Adapt gptq dequant to 3/8-bit --- csrc/ops.h | 1 + csrc/quantization/gptq/q_gemm.cu | 99 ++++++++++--------- .../layers/quantization/gptq.py | 2 + 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index acbcc972744e6..865d8708a43a4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -137,6 +137,7 @@ torch::Tensor dequant_gptq( torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + int bits, bool use_exllama ); diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index be5c67873f8ff..8970bdd3cd13c 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1649,6 +1649,33 @@ void reconstruct_gptq } +void dequant_gptq_cuda +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* temp_dq, + int size_k, + int size_n, + int groups, + int num_experts, + int bits, + bool use_exllama +) +{ + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups, num_experts, bits); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, num_experts, bits); + } +} + + void gemm_half_q_half_cuda ( cublasHandle_t cublas_handle, @@ -1676,15 +1703,8 @@ void gemm_half_q_half_cuda } if (use_reconstruct) { // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, 1, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, 1, bit); - } + dequant_gptq_cuda(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups, 1, bit, use_exllama); const half alpha = __float2half(1.0f); const half beta = __float2half(0.0f); @@ -1833,6 +1853,11 @@ __global__ void make_sequential_2bit_kernel const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new 
= w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 16; + } const uint64_t* w2 = (uint64_t*) w; uint64_t* w_new2 = (uint64_t*) w_new; int w2_stride = w_width >> 1; @@ -1870,6 +1895,11 @@ __global__ void make_sequential_3bit_kernel const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 32 / 3; + } int w_column = THREADS_X * blockIdx.x + threadIdx.x; if (w_column >= w_width) return; int w_new_row = blockIdx.y * 3; @@ -1957,6 +1987,11 @@ __global__ void make_sequential_8bit_kernel const int w_width ) { + if (blockIdx.z > 0){ + w = w + blockIdx.z * w_height * w_width; + w_new = w_new + blockIdx.z * w_height * w_width; + q_perm = q_perm + blockIdx.z * w_height * 4; + } const uint64_t* w2 = (uint64_t*) w; uint64_t* w_new2 = (uint64_t*) w_new; int w2_stride = w_width >> 1; @@ -1985,8 +2020,7 @@ __global__ void make_sequential_8bit_kernel w_new2[w_new2_row * w2_stride + w2_column] = dst; } -// Only 4-bit support MoE -// todo: extend support to other bits + void shuffle_exllama_weight ( uint32_t* q_weight, @@ -2218,7 +2252,7 @@ __global__ void group_gemm_half_q_half_gptq_kernel } } -void group_gemm_half_q_half_cuda +void group_gemm_half_q_half ( const half* a, const uint32_t* b_q_weight, @@ -2465,7 +2499,7 @@ void group_gemm_half_q_half_cuda bool use_exllama ) { if (use_exllama) { - group_gemm_half_q_half_cuda( + group_gemm_half_q_half( a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, num_valid_tokens, @@ -2481,31 +2515,6 @@ void group_gemm_half_q_half_cuda } } -void dequant_gptq_cuda -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* temp_dq, - int size_k, - int size_n, - int groups, - int num_experts, - bool use_exllama -) -{ - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, num_experts, 4); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, num_experts, 4); - } -} - } // namespace gptq } // namespace vllm @@ -2614,14 +2623,14 @@ torch::Tensor group_gptq_gemm return c; } -// Only support 4-bit -// todo: extend support to other bits + torch::Tensor dequant_gptq ( torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + int bits, bool use_exllama ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(b_gptq_scales)); @@ -2634,16 +2643,16 @@ torch::Tensor dequant_gptq int groups; // moe if (b_q_weight.dim() == 3) { - temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 8, b_q_weight.size(2)}, options); + temp_dq = torch::empty({b_q_weight.size(0), b_q_weight.size(1) * 32 / bits, b_q_weight.size(2)}, options); num_experts = b_q_weight.size(0); - size_k = b_q_weight.size(1) * 8; + size_k = b_q_weight.size(1) * 32 / bits; size_n = b_q_weight.size(2); groups = b_gptq_scales.size(1); } else { - temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + temp_dq = torch::empty({b_q_weight.size(0) * 32 / bits, b_q_weight.size(1)}, options); num_experts = 1; - size_k = b_q_weight.size(0) * 8; + size_k = b_q_weight.size(0) * 32 / bits; size_n = b_q_weight.size(1); groups = b_gptq_scales.size(0); } @@ -2654,6 +2663,6 @@ torch::Tensor dequant_gptq 
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), (half*) temp_dq.data_ptr(), size_k, size_n, groups, - num_experts, use_exllama); + num_experts, bits, use_exllama); return temp_dq; } diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 6267eae0d43c5..6839481cae57b 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -236,9 +236,11 @@ def apply_moe_weights(self, w1: Dict[str, if x.shape[0] >= 128: dequant_w1 = ops.dequant_gptq( w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], + self.quant_config.weight_bits, w1["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) dequant_w2 = ops.dequant_gptq( w2["qweight"], w2["qzeros"], w2["scales"], w2["g_idx"], + self.quant_config.weight_bits, w2["exllama_state"] == ExllamaState.READY).permute(0, 2, 1) return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, renormalize) From 4ef69d543bd83d70236830363e4ab2414d29d5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sun, 3 Mar 2024 13:20:22 +0800 Subject: [PATCH 15/20] Fix marlin --- vllm/model_executor/layers/quantization/marlin.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 6194777faf9c3..f704b907bf250 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -211,3 +211,10 @@ def apply_weights( output.add_(bias) # In-place add return output + + def apply_moe_weights(self, w1: Dict[str, + torch.Tensor], w2: Dict[str, + torch.Tensor], + x: torch.Tensor, gating_output: torch.Tensor, + topk: int, renormalize: bool) -> torch.Tensor: + raise NotImplementedError From 9d6f7d18d60498568b841a97f7a84617738dd432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 12 Mar 2024 21:35:03 +0800 Subject: [PATCH 16/20] Fix format check --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1a83805d10a88..2aa7d76db9fb6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -253,7 +253,8 @@ def fused_topk( """Compute top-k indice and weights from gating logits Args: - gating_output (torch.Tensor): The output of the gating operation (before softmax). + gating_output (torch.Tensor): The output of the gating operation + (before softmax). topk (int): The number of top-k experts to select. renormalize (bool): If True, renormalize the top-k weights to sum to 1. 
""" From 4faebc368530e357630dc3912cbc7212fd603664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Fri, 29 Mar 2024 20:37:52 +0800 Subject: [PATCH 17/20] Fix isort --- tests/kernels/test_moe.py | 3 +-- vllm/model_executor/layers/fused_moe/__init__.py | 2 +- vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/layers/quantization/awq.py | 4 ++-- vllm/model_executor/layers/quantization/gptq.py | 4 ++-- vllm/model_executor/models/deepseek.py | 9 ++++++--- vllm/model_executor/models/mixtral.py | 9 ++++++--- vllm/model_executor/models/qwen2_moe.py | 9 ++++++--- 8 files changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f394854b802b9..a9dfcf9a97fe3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,14 +10,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from vllm._C import ops - +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.quantization.awq import (AWQConfig, AWQLinearMethod) from vllm.model_executor.layers.quantization.gptq import (ExllamaState, GPTQConfig, GPTQLinearMethod) -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.parallel_utils.parallel_state import ( destroy_model_parallel, initialize_model_parallel) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3b229ec33ec94..c399facdcf23b 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,5 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, moe_align_block_size, fused_topk, get_config_file_name) + fused_moe, fused_topk, get_config_file_name, moe_align_block_size) __all__ = [ "fused_moe", "moe_align_block_size", "fused_topk", "get_config_file_name" diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b2d6dd31611c6..76391753d8ac1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 7aa2861f5cd7b..ee7c78990bad0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,8 +4,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops -from vllm.model_executor.layers.fused_moe import (moe_align_block_size, - fused_moe, fused_topk) +from vllm.model_executor.layers.fused_moe import (fused_moe, fused_topk, + moe_align_block_size) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 46475410ce564..d4f66513a0a71 
100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -7,8 +7,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops -from vllm.model_executor.layers.fused_moe import (moe_align_block_size, - fused_moe, fused_topk) +from vllm.model_executor.layers.fused_moe import (fused_moe, fused_topk, + moe_align_block_size) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 690852c6afb6e..8e70d1f0487b1 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -32,9 +32,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - UnquantizedLinearMethod, LinearMethodBase, MergedColumnParallelLinear, - ReplicatedLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b2caf11b12629..ae4cc47b97b49 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -32,9 +32,12 @@ from vllm.config import LoRAConfig from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, LinearMethodBase, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6d16b92f6dd6b..cac3c7d04f4a6 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -34,9 +34,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, LinearMethodBase, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler From e8b2127f18035256c52b99241a033855ffefdbe0 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Fri, 29 Mar 2024 20:40:29 +0800 Subject: [PATCH 18/20] Fix format --- vllm/model_executor/models/deepseek.py | 9 +++------ vllm/model_executor/models/mixtral.py | 9 +++------ vllm/model_executor/models/qwen2_moe.py | 9 +++------ 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 8e70d1f0487b1..f30e28782bea8 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -32,12 +32,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ae4cc47b97b49..2a1cc25ce0e1f 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -32,12 +32,9 @@ from vllm.config import LoRAConfig from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index cac3c7d04f4a6..73c0a4682f0a4 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -34,12 +34,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler From 1922e83566dae87c07cf8d41c1ca9e6ee0ae33cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 8 Apr 2024 21:25:33 +0800 Subject: [PATCH 19/20] Replace expert parallel with tensor parallel --- vllm/model_executor/layers/linear.py | 39 ++++- 
.../model_executor/layers/quantization/awq.py | 3 - .../layers/quantization/base_config.py | 5 - .../layers/quantization/gptq.py | 12 +- .../layers/quantization/marlin.py | 10 -- .../layers/quantization/squeezellm.py | 10 -- vllm/model_executor/models/deepseek.py | 139 ++++------------ vllm/model_executor/models/mixtral.py | 143 ++++------------ vllm/model_executor/models/qwen2_moe.py | 152 ++++-------------- 9 files changed, 132 insertions(+), 381 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 76391753d8ac1..11a0562576244 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,8 +5,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm._C import ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_moe, fused_topk from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -66,14 +67,46 @@ def create_moe_weights(self, num_experts: int, linear_weights[name] = new_param return linear_weights - @abstractmethod def apply_moe_weights(self, w1: Dict[str, torch.Tensor], w2: Dict[str, torch.Tensor], x: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool) -> torch.Tensor: """Apply the weights to the input tensor.""" - raise NotImplementedError + routing_weights, selected_experts = fused_topk(gating_output, + topk, + renormalize=renormalize) + final_hidden_states = None + num_experts = gating_output.shape[-1] + for expert_idx in range(num_experts): + w1_expert = { + key: + value[expert_idx] if isinstance(value, torch.Tensor) else value + for key, value in w1.items() + } + w2_expert = { + key: + value[expert_idx] if isinstance(value, torch.Tensor) else value + for key, value in w2.items() + } + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + hidden_states = self.apply_weights(w1_expert, x) + output_shape = (hidden_states.shape[:-1] + + (hidden_states.shape[-1] // 2, )) + out = torch.empty(output_shape, + dtype=hidden_states.dtype, + device=hidden_states.device) + ops.silu_and_mul(out, hidden_states) + current_hidden_states = self.apply_weights( + w2_expert, out).mul_(expert_weights) + + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states class UnquantizedLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ee7c78990bad0..b608c8645bcb1 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -70,9 +70,6 @@ def get_linear_method(self) -> "AWQLinearMethod": def get_scaled_act_names(self) -> List[str]: return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] - def support_fused_moe(self) -> bool: - return True - class AWQLinearMethod(LinearMethodBase): """Linear method for AWQ. 
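For orientation while reading this patch: the new LinearMethodBase.apply_moe_weights fallback added above routes tokens with fused_topk and then runs each expert's gate/up projection, SiLU-and-mul activation, and down projection through the method's own apply_weights, accumulating the expert-weighted results. A minimal unquantized sketch of that computation in plain PyTorch (tensor names and shapes here are illustrative assumptions, not part of the patch):

    import torch
    import torch.nn.functional as F

    def reference_moe(x, w1, w2, gating_output, topk, renormalize=True):
        # x: (tokens, hidden); w1: (experts, 2*intermediate, hidden);
        # w2: (experts, hidden, intermediate); gating_output: (tokens, experts)
        probs = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(x)
        for e in range(w1.shape[0]):
            # Per-token routing weight for expert e (zero if the token was not routed to e).
            scale = (topk_weights * (topk_ids == e)).sum(dim=-1, keepdim=True).to(x.dtype)
            gate_up = x @ w1[e].t()                            # fused gate and up projections
            d = gate_up.shape[-1] // 2
            act = F.silu(gate_up[..., :d]) * gate_up[..., d:]  # SiluAndMul
            out += (act @ w2[e].t()) * scale                   # expert-weighted down projection
        return out

The quantized fallback in the diff follows the same structure, but dispatches each matmul through apply_weights on per-expert slices of the quantized weight dict instead of dense tensors.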
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 663a8d2d8b5e6..6115e7c3be956 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -62,8 +62,3 @@ def get_scaled_act_names(self) -> List[str]: For now, this is only used by AWQ. """ raise NotImplementedError - - @abstractmethod - def support_fused_moe(self) -> bool: - """Whether fused moe kernel is implemented""" - raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index d4f66513a0a71..a987dd34eb26a 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -71,10 +71,6 @@ def get_linear_method(self) -> "GPTQLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] - def support_fused_moe(self) -> bool: - # Fused MoE only supports 4-bit so far. - return self.weight_bits == 4 - class ExllamaState(Enum): @@ -232,11 +228,17 @@ def apply_moe_weights(self, w1: Dict[str, w["g_idx"] = torch.argsort(w["g_idx"], dim=-1).to(torch.int) else: - w["g_idx"] = torch.empty((1, 1), device="meta") + w["g_idx"] = torch.empty((w["g_idx"].shape[0], 1), + device="meta") w["exllama_state"] = ExllamaState.READY ops.gptq_shuffle(w["qweight"], w["g_idx"], self.quant_config.weight_bits) + # Fused moe only supports 4-bit + if self.quant_config.weight_bits != 4: + return super().apply_moe_weights(w1, w2, x, gating_output, topk, + renormalize) + if x.shape[0] >= 128: dequant_w1 = ops.dequant_gptq( w1["qweight"], w1["qzeros"], w1["scales"], w1["g_idx"], diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 25ed3007ca6c0..784229878edf4 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -78,9 +78,6 @@ def get_linear_method(self) -> "MarlinLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] - def support_fused_moe(self) -> bool: - return False - class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. @@ -221,10 +218,3 @@ def apply_weights( output.add_(bias) # In-place add return output - - def apply_moe_weights(self, w1: Dict[str, - torch.Tensor], w2: Dict[str, - torch.Tensor], - x: torch.Tensor, gating_output: torch.Tensor, - topk: int, renormalize: bool) -> torch.Tensor: - raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 0ead9f9bbc668..ed25455e6ec1f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -57,9 +57,6 @@ def get_linear_method(self) -> "SqueezeLLMLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] - def support_fused_moe(self) -> bool: - return False - class SqueezeLLMLinearMethod(LinearMethodBase): """Linear method for SqueezeLLM. 
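A side note on the g_idx change in the GPTQ hunk above: a tensor on PyTorch's "meta" device is used as a storage-free sentinel meaning "no activation reordering", so later code can branch on the tensor's device rather than carry a separate flag. A small self-contained illustration (the expert count is an assumed example value, not taken from the patch):

    import torch

    # Meta tensors carry shape and dtype but allocate no storage,
    # which makes them a cheap placeholder value.
    g_idx = torch.empty((8, 1), device="meta")  # 8 experts, illustrative
    if g_idx.device == torch.device("meta"):
        print("desc_act disabled: pass a placeholder g_idx to the kernel")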
@@ -131,10 +128,3 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) - - def apply_moe_weights(self, w1: Dict[str, - torch.Tensor], w2: Dict[str, - torch.Tensor], - x: torch.Tensor, gating_output: torch.Tensor, - topk: int, renormalize: bool) -> torch.Tensor: - raise NotImplementedError diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f30e28782bea8..26a6306b7f216 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -23,18 +23,22 @@ """Inference-only Deepseek model.""" from typing import Any, Dict, List, Optional -import numpy as np import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -82,39 +86,6 @@ def forward(self, x): return x -class DeepseekExpertMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.gate_proj = ReplicatedLinear(hidden_size, - intermediate_size, - bias=False, - linear_method=linear_method) - self.up_proj = ReplicatedLinear(hidden_size, - intermediate_size, - bias=False, - linear_method=linear_method) - self.down_proj = ReplicatedLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) - self.act_fn = nn.SiLU() - - def forward(self, hidden_states): - gate_out, _ = self.gate_proj(hidden_states) - gate_out = self.act_fn(gate_out) - up_out, _ = self.up_proj(hidden_states) - current_hidden_states = gate_out * up_out - current_hidden_states, _ = self.down_proj(current_hidden_states) - return current_hidden_states - - class DeepseekMoE(nn.Module): def __init__( @@ -132,40 +103,16 @@ def __init__( if self.linear_method is None: self.linear_method = UnquantizedLinearMethod() - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - if self.tp_size > self.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") - # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.n_routed_experts), self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - DeepseekExpertMLP( - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - linear_method=linear_method, - ) if idx in self.expert_indicies else None - for idx in 
range(self.n_routed_experts) - ]) - else: - self.w1 = MergedColumnParallelLinear( - config.hidden_size, [config.moe_intermediate_size] * 2, - bias=False, - linear_method=linear_method, - num_experts=self.n_routed_experts) - self.w2 = RowParallelLinear(config.moe_intermediate_size, - config.hidden_size, - bias=False, - linear_method=linear_method, - num_experts=self.n_routed_experts) + self.w1 = MergedColumnParallelLinear( + config.hidden_size, [config.moe_intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear(config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, @@ -191,35 +138,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - routing_weights, selected_experts = fused_topk( - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob) - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - else: - final_hidden_states = self.linear_method.apply_moe_weights( - self.w1.linear_weights, - self.w2.linear_weights, - hidden_states, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - ) + final_hidden_states = self.linear_method.apply_moe_weights( + self.w1.linear_weights, + self.w2.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -468,8 +394,8 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ("mlp.gate_up_proj", "mlp.gate_proj", 0), ("mlp.gate_up_proj", "mlp.up_proj", 1), - ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0), - ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1), + ("shared_expert.gate_up_proj", "shared_expert.gate_proj", 0), + ("shared_expert.gate_up_proj", "shared_expert.up_proj", 1), ] expert_params_mapping = [ @@ -480,8 +406,7 @@ def load_weights(self, for weight_name, shard_id in [("gate_proj", 0), ("up_proj", 1), ("down_proj", None)] - ] if self.linear_method is None or ( - self.linear_method.quant_config.support_fused_moe()) else [] + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( @@ -499,10 +424,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." 
in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2a1cc25ce0e1f..60f465f02c6ee 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,18 +23,22 @@ """Inference-only Mixtral model.""" from typing import List, Optional -import numpy as np import torch from torch import nn from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig -from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -43,52 +47,13 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - linear_method=linear_method) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. 
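
The MixtralMoE hunks below keep the tensor_model_parallel_all_reduce even though every rank now holds all experts: the stacked weights from MergedColumnParallelLinear / RowParallelLinear are still sharded along intermediate_size, so each rank produces only a partial sum. A small single-expert sketch (made-up sizes, plain torch in place of the parallel linear layers) of why the sharded partial outputs add up to the full result:

import torch
import torch.nn.functional as F

H, I, TP = 16, 32, 4                 # hidden size, intermediate size, tp ranks
x = torch.randn(5, H)
w_gate, w_up = torch.randn(I, H), torch.randn(I, H)
w_down = torch.randn(H, I)

full = (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()

# Shard the intermediate dimension: column-parallel for gate/up, row-parallel
# for the down projection. Summing the per-shard outputs plays the role of the
# all-reduce in MixtralMoE.forward.
partial = torch.zeros_like(full)
for g, u, d in zip(w_gate.chunk(TP, 0), w_up.chunk(TP, 0), w_down.chunk(TP, 1)):
    partial += (F.silu(x @ g.t()) * (x @ u.t())) @ d.t()

print(torch.allclose(full, partial, atol=1e-4))   # True, up to floating-point error
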
@@ -108,7 +73,6 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.rank = get_tensor_model_parallel_rank() self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.num_total_experts = num_experts self.top_k = top_k @@ -123,40 +87,16 @@ def __init__( bias=False, linear_method=None) - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), - self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - hidden_size, - intermediate_size, - linear_method=linear_method) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - else: - self.ws = MergedColumnParallelLinear(hidden_size, - [intermediate_size] * 2, - bias=False, - linear_method=linear_method, - num_experts=num_experts) - self.w2s = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method, - num_experts=num_experts) + self.ws = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=num_experts) + self.w2s = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + num_experts=num_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -164,34 +104,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - routing_weights, selected_experts = fused_topk(router_logits, - self.top_k, - renormalize=True) - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - else: - final_hidden_states = self.linear_method.apply_moe_weights( - self.ws.linear_weights, - self.w2s.linear_weights, - hidden_states, - router_logits, - self.top_k, - renormalize=True, - ) + final_hidden_states = self.linear_method.apply_moe_weights( + self.ws.linear_weights, + self.w2s.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=True, + ) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -473,8 +393,7 @@ def load_weights(self, f"experts.{expert_id}.{weight_name}", shard_id, expert_id) for expert_id in range(self.config.num_local_experts) for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)] - ] if self.linear_method is None or ( - self.linear_method.quant_config.support_fused_moe()) else [] + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in 
hf_model_weights_iterator( @@ -521,10 +440,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 73c0a4682f0a4..d9e06bf6ce1ae 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -24,7 +24,6 @@ """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" from typing import Any, Dict, List, Optional -import numpy as np import torch import torch.nn.functional as F from torch import nn @@ -32,11 +31,16 @@ from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler @@ -84,50 +88,7 @@ def forward(self, x): return x -class Qwen2MoeExpertMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.gate_proj = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - self.down_proj = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - linear_method=linear_method) - self.up_proj = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - gate_out, _ = self.gate_proj(hidden_states) - gate_out = self.act_fn(gate_out) - up_out, _ = self.up_proj(hidden_states) - current_hidden_states = gate_out * up_out - current_hidden_states, _ = self.down_proj(current_hidden_states) - return current_hidden_states - - class Qwen2MoeSparseMoeBlock(nn.Module): - """A tensor-parallel MoE implementation for Qwen2Moe that shards each expert - across all ranks. - - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. 
- """ def __init__( self, @@ -135,8 +96,6 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() self.config = config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -169,34 +128,16 @@ def __init__( 1, bias=False) - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.n_routed_experts), self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - Qwen2MoeExpertMLP(config.hidden_size, - config.moe_intermediate_size, - linear_method=linear_method) - if idx in self.expert_indicies else None - for idx in range(self.n_routed_experts) - ]) - else: - self.w1 = MergedColumnParallelLinear( - config.hidden_size, [config.moe_intermediate_size] * 2, - bias=False, - linear_method=linear_method, - num_experts=self.n_routed_experts) - self.w2 = RowParallelLinear(config.moe_intermediate_size, - config.hidden_size, - bias=False, - linear_method=linear_method, - num_experts=self.n_routed_experts) + self.w1 = MergedColumnParallelLinear( + config.hidden_size, [config.moe_intermediate_size] * 2, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) + self.w2 = RowParallelLinear(config.moe_intermediate_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + num_experts=self.n_routed_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -211,42 +152,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if not isinstance( - self.linear_method, UnquantizedLinearMethod - ) and not self.linear_method.quant_config.support_fused_moe(): - routing_weights, selected_experts = fused_topk( - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob) - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - else: - final_hidden_states = self.linear_method.apply_moe_weights( - self.w1.linear_weights, - self.w2.linear_weights, - hidden_states, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - ) + final_hidden_states = self.linear_method.apply_moe_weights( + self.w1.linear_weights, + self.w2.linear_weights, + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -502,8 +420,7 @@ def load_weights(self, for weight_name, shard_id in [("gate_proj", 0), 
("up_proj", 1), ("down_proj", None)] - ] if self.linear_method is None or ( - self.linear_method.quant_config.support_fused_moe()) else [] + ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( @@ -521,10 +438,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_expert." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -553,11 +466,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name - or "mlp.shared_expert." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 8bc089fcc56cee67b8f7e9217b78f41d8d700c2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 8 Apr 2024 22:01:12 +0800 Subject: [PATCH 20/20] Fix typo --- vllm/model_executor/models/deepseek.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 26a6306b7f216..cd0d7dc650ae5 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -394,8 +394,8 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ("mlp.gate_up_proj", "mlp.gate_proj", 0), ("mlp.gate_up_proj", "mlp.up_proj", 1), - ("shared_expert.gate_up_proj", "shared_expert.gate_proj", 0), - ("shared_expert.gate_up_proj", "shared_expert.up_proj", 1), + ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0), + ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1), ] expert_params_mapping = [ @@ -452,11 +452,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name - or "mlp.shared_experts." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader)