GPTQ & AWQ Fused MOE #2761

Open — wants to merge 29 commits into base: main

Commits (29)
2d970c5  Add kernel (chu-tianxiang, Feb 4, 2024)
281354a  Add group gemm kernel for gptq (chu-tianxiang, Feb 5, 2024)
2a1c106  Add dequant kernel (chu-tianxiang, Feb 5, 2024)
e9846d0  Add awq support (chu-tianxiang, Feb 6, 2024)
a9d65a9  Add test (chu-tianxiang, Feb 6, 2024)
7dea006  format (chu-tianxiang, Feb 6, 2024)
dfbb034  Merge main and fix problem in kernel (chu-tianxiang, Feb 7, 2024)
46d15fb  format (chu-tianxiang, Feb 7, 2024)
78e6e70  Merge main and fix conflicts (chu-tianxiang, Feb 16, 2024)
6b3e23e  Fix unit test (chu-tianxiang, Feb 23, 2024)
238f544  Merge branch 'main' into moe_exp (chu-tianxiang, Feb 24, 2024)
d43445e  Add guard for awq unit test (chu-tianxiang, Feb 24, 2024)
2c68478  Fix format (chu-tianxiang, Feb 24, 2024)
2c27dcc  test (chu-tianxiang, Feb 24, 2024)
c1b98ef  merge main (chu-tianxiang, Feb 27, 2024)
68d34af  Fix import (chu-tianxiang, Feb 27, 2024)
6e69101  Merge branch 'main' into moe_exp (chu-tianxiang, Feb 27, 2024)
d956844  fix format (chu-tianxiang, Feb 27, 2024)
f19ddfb  Merge main and fix conflicts (chu-tianxiang, Feb 29, 2024)
7a4ba90  Adapt gptq dequant to 3/8-bit (chu-tianxiang, Mar 1, 2024)
2fe491d  Merge main branch (chu-tianxiang, Mar 3, 2024)
4ef69d5  Fix marlin (chu-tianxiang, Mar 3, 2024)
7a11506  Merge main branch and fix conflicts (chu-tianxiang, Mar 12, 2024)
9d6f7d1  Fix format check (chu-tianxiang, Mar 12, 2024)
d08c4fa  Merge main (chu-tianxiang, Mar 29, 2024)
4faebc3  Fix isort (chu-tianxiang, Mar 29, 2024)
e8b2127  Fix format (chu-tianxiang, Mar 29, 2024)
1922e83  Replace expert parallel with tensor parallel (chu-tianxiang, Apr 8, 2024)
8bc089f  Fix typo (chu-tianxiang, Apr 8, 2024)
43 changes: 39 additions & 4 deletions csrc/ops.h
@@ -99,13 +99,25 @@ torch::Tensor awq_dequantize(
int thx,
int thy);

torch::Tensor awq_group_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
torch::Tensor _topk_weights,
torch::Tensor _sorted_token_ids_ptr,
torch::Tensor _expert_ids_ptr,
torch::Tensor _num_tokens_post_padded,
bool mul_weights,
int split_k_iters);

torch::Tensor marlin_gemm(
torch::Tensor& a,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
#endif

@@ -129,6 +141,29 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

torch::Tensor group_gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
torch::Tensor topk_weights,
torch::Tensor sorted_token_ids_ptr,
torch::Tensor expert_ids_ptr,
torch::Tensor num_tokens_post_padded,
bool mul_weights,
bool use_exllama
);

torch::Tensor dequant_gptq(
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
int bits,
bool use_exllama
);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
  ...
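For orientation, here is a minimal sketch (not code from this PR) of how the new declarations compose into a fused GPTQ MoE forward pass from Python: moe_align_block_size sorts the routed tokens into expert-contiguous, block-padded order, and group_gptq_gemm then runs a single grouped GEMM over all experts using that routing metadata. The vllm._C.ops binding path, the tensor shapes, and the fused_gptq_moe helper name are assumptions, not part of the diff.

import torch
from vllm._C import ops

def fused_gptq_moe(hidden, qweight, qzeros, scales, g_idx,
                   topk_weights, topk_ids, num_experts, block_size=16):
    # Hypothetical helper: route tokens, then one grouped GPTQ GEMM.
    device = hidden.device
    # Worst case, every expert's token block needs (block_size - 1) pad slots.
    max_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device=device)
    expert_ids = torch.empty((max_padded + block_size - 1) // block_size,
                             dtype=torch.int32, device=device)
    num_tokens_post_padded = torch.empty(1, dtype=torch.int32, device=device)
    ops.moe_align_block_size(topk_ids, num_experts, block_size,
                             sorted_token_ids, expert_ids,
                             num_tokens_post_padded)
    # mul_weights=True applies the top-k routing weights inside the kernel;
    # use_exllama=True selects the exllama-layout code path.
    return ops.group_gptq_gemm(hidden, qweight, qzeros, scales, g_idx,
                               topk_weights, sorted_token_ids, expert_ids,
                               num_tokens_post_padded, True, True)

The AWQ path is symmetric: per the header above, awq_group_gemm consumes the same routing metadata plus AWQ's _scaling_factors and _zeros, with an additional split_k_iters parameter.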
5 changes: 4 additions & 1 deletion csrc/pybind.cpp
@@ -64,12 +64,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_group_gemm", &awq_group_gemm, "Grouped Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif

ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("group_gptq_gemm", &group_gptq_gemm, "Grouped Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("dequant_gptq", &dequant_gptq, "Dequantize gptq weight to half");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
  ...
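The new dequant_gptq binding also gives a convenient way to sanity-check the fused path: dequantize an expert's packed weight back to half precision and compare against a plain PyTorch matmul. A hedged sketch, again assuming the vllm._C.ops binding path and that the weight has already been shuffled into the exllama layout via gptq_shuffle:

import torch
from vllm._C import ops

def gptq_reference_matmul(x, qweight, qzeros, scales, g_idx, bits=4):
    # Materialize the packed GPTQ weight as an fp16 matrix (assumed shape
    # [in_features, out_features]), then multiply as a correctness reference.
    w = ops.dequant_gptq(qweight, qzeros, scales, g_idx, bits, True)
    return torch.matmul(x.half(), w)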