From fff6cd29628a069b44430e2ccdaa0a4268ec8483 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 25 Apr 2024 19:03:56 +0000 Subject: [PATCH] [Core]refactor aqlm quant ops (#4351) --- benchmarks/kernels/benchmark_aqlm.py | 2 +- vllm/_custom_ops.py | 14 ++++++++++++++ vllm/model_executor/layers/quantization/aqlm.py | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 9602d20bcbc74..59392947b15c8 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.aqlm import ( dequantize_weight, generic_dequantize_gemm, get_int_dtype, optimized_dequantize_gemm) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e4b16ed918d1a..508d35656eb00 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor) -> torch.Tensor: + return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + + # fp8 def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: scale = torch.zeros(1, device=input.device, dtype=torch.float32) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 6115b1de679ad..b48c6e1702be4 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import (