diff --git a/CMakeLists.txt b/CMakeLists.txt index 1845151181284..b2d0cf3e568b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,7 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC + "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/custom_all_reduce.cu") diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py new file mode 100644 index 0000000000000..9602d20bcbc74 --- /dev/null +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -0,0 +1,302 @@ +import argparse +import os +import sys +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm._C import ops +from vllm.model_executor.layers.quantization.aqlm import ( + dequantize_weight, generic_dequantize_gemm, get_int_dtype, + optimized_dequantize_gemm) + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def torch_mult( + input: torch.Tensor, # [..., in_features] + weights: torch.Tensor, + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] +) -> torch.Tensor: + output = F.linear(input, weights) + return output + + +def dequant_out_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return flattened_output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_weight_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_no_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + return F.linear(input, weights, bias) + + +# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against +# the generic pytorch version. +# Just visual comparison. 
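# The visual check below boils down to the following comparison (a sketch
# only; shapes mirror the tensors built in dequant_test, and agreement is
# judged by eye rather than asserted):
#
#   codes     : [out_features, in_features // 8, nbooks]
#   codebooks : [num_partitions * nbooks, 1 << bits, 1, 8]
#
#   ref = dequantize_weight(codes, codebooks, None)   # generic PyTorch path
#   opt = ops.aqlm_dequant(codes, codebooks, parts)   # fused CUDA kernel
#   # the two should agree element-wise, e.g. torch.allclose(ref, opt)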
+def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + count = 0 + for index in range(16): + for i in range(8): + for book in range(nbooks): + codebooks[book, index, 0, i] = count * (10**book) + count += 1 + + print("codes shape", codes.shape) + + for i in range(16): + for book in range(nbooks): + codes[0, i, book] = i + codes[0, -i, book] = i + + weights = dequantize_weight(codes, codebooks, None) + weights2 = ops.aqlm_dequant(codes, codebooks, parts) + + print("weights shape:", weights.shape) + print("weights2 shape:", weights2.shape) + + print("weights are:", weights) + print("weights2 are:", weights2) + + print("first 128 weights are", weights[0, 0:128].to(torch.int32)) + print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32)) + + print("last 128 weights are", weights[0, -128:]) + print("last 128 weights2 are:", weights2[0, -128:]) + + +def main(): + + parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") + + # Add arguments + parser.add_argument("--nbooks", + type=int, + default=1, + help="Number of codebooks (default: 1)") + parser.add_argument("--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)") + parser.add_argument( + "--test", + type=bool, + default=False, + help="Run the decompression/dequant tester rather than benchmarking " + "(default: False)") + + # Parse the arguments + args = parser.parse_args() + + # Extract values + nbooks = args.nbooks + bits = args.bits + + if args.test: + dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + return + + # Otherwise, benchmark. + methods = [ + ops.aqlm_gemm, + dequant_out_scale, + generic_dequantize_gemm, + optimized_dequantize_gemm, + dequant_weight_scale, + torch_mult, + dequant_no_scale, + ] + + filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv" + print(f"writing benchmarks to file {filename}") + with open(filename, "w") as f: + sys.stdout = f + + print('m | k | n | n parts', end='') + for method in methods: + print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') + print('') + + # These are reasonable prefill sizes. + ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), + (4096, (11008, 11008)), (11008, (4096, ))) + + # reasonable ranges for m. + for m in [ + 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, + 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + ]: + print(f'{m}', file=sys.__stdout__) + for ksp in ksandpartions: + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, + methods) + + sys.stdout = sys.__stdout__ + + +def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, + methods): + + # I didn't see visible improvements from increasing these, but feel free :) + num_warmup_trials = 1 + num_trials = 1 + + num_calls = 100 + + # warmup. 
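    # Each figure reported below comes from run_timing, which brackets
    # num_calls launches between two CUDA events and divides by num_calls,
    # so results are average per-call latencies. This warmup pass absorbs
    # one-time costs (allocator growth, kernel caching) before the measured
    # trials.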
+ for method in methods: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + n = parts.sum().item() + print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + + for method in methods: + best_time_us = 1e20 + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + + if kernel_dur_us < best_time_us: + best_time_us = kernel_dur_us + + print(f' | {kernel_dur_us:.0f}', end='') + + print('') + + +def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, + nbooks: int, bits: int, method) -> float: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + input = torch.randn((1, m, k), dtype=torch.float16, device=device) + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) + + # for comparison to just a pytorch mult. + weights = torch.randn((n, k), dtype=torch.float16, device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + if method is torch_mult: + for i in range(num_calls): + torch_mult(input, weights, scales) + else: + for i in range(num_calls): + method(input, codes, codebooks, scales, parts, None) + + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/csrc/ops.h b/csrc/ops.h index 41ecc1e89371b..a379c910d9cf3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -86,6 +86,21 @@ void gelu_fast( torch::Tensor& input); #ifndef USE_ROCM +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +); + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +); + torch::Tensor awq_gemm( torch::Tensor _in_feats, torch::Tensor _kernel, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc162113..42e92e5382e8e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops #ifndef USE_ROCM + ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); + ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu new file mode 100644 index 0000000000000..4415316e1e8cd --- /dev/null +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -0,0 +1,712 @@ +/* + * Modified by Neural Magic + * Adapted from https://github.com/Vahe1994/AQLM + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace vllm { +namespace aqlm { + +__global__ void Code1x16MatVec( + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + __shared__ int4 sh_b[32 * 9]; + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[4]; + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; + #pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code2x8MatVec( + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. + +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. 
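    // Example with assumed values: if codebook_a_sizes = {4096, 8192, 12288,
    // 122880} (cumulative output rows per partition) and this warp owns row
    // a_gl_rd = 9000, the loop below steps past two codebooks, i.e. advances
    // the pointer by 2 * codebook_stride int4 elements, before decoding.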
+ auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + extern __shared__ int4 sh[]; + int4* sh_b = sh; + int4* sh_code = sh_b + 32 * 9; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; + #pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + + +__global__ void Code1x16Dequant( + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + auto dec = reinterpret_cast(&chunk); + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. 
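        // (ld.cg marks the load as "cache global": the 16-byte codebook entry
        // is kept in L2 only and skips L1, matching the access pattern used in
        // Code1x16MatVec above.)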
+ asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); + + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + + +__global__ void Code2x8Dequant( + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + extern __shared__ int4 sh[]; + int4* sh_code = sh; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +const int THREAD_M = 16; + +void code1x16_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +void code2x8_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + 
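  // Rough accounting (bytes): 2 codebooks x 256 entries x 8 replicated int4
  // copies, plus a padded 32 x 9 int4 activation tile, at 16 bytes per int4;
  // roughly 68.5 KB in total, which is why the kernel opts in to a larger
  // dynamic shared-memory limit below.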
cudaFuncSetAttribute( + Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code2x8MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +void code1x16_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + codebook_stride // as int4. + ); +} + +// Dequantizes the code and codebook into weights. +void code2x8_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + cudaFuncSetAttribute( + Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); + Code2x8Dequant<<>>( + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +int codebook_stride(const torch::Tensor& codebooks) +{ + return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); +} + +void code1x16_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. 
+) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + + code1x16_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride(codebook) + ); +} + +torch::Tensor code1x16_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code1x16_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); + } + flat_output *= scales.flatten().unsqueeze(0); + + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +void code2x8_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + code2x8_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + 2 * codebook_stride(codebook) + ); +} + +torch::Tensor code2x8_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias +) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code2x8_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); + } + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +// Accumulate the partition sizes. +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) +{ + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; + int i = 0; + int last = 0; + assert(codebook_partition_sizes.size(0) <= 4); + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) + { + *cumulative_size = codebook_partition_sizes[i].item() + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. 
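  // (Any value larger than the final cumulative size works for the padding
  // below; the kernels only compare row indices against these entries, so
  // filling with last * 10 keeps the per-thread codebook walk from running
  // past the real partitions.)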
+ for (; i < 4; ++i, ++cumulative_size) + { + *cumulative_size = last*10; + } + return cumulative_sizes; +} + +} // namespace aqlm +} // namespace vllm + + +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + if (nbooks == 1 && entries == (1 << 16)) + { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + } + if (nbooks == 2 && entries == (1 << 8)) + { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + return {}; +} + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); + int rows = codes.size(1); + int cols = codes.size(0); + + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + + assert(out_features = codebook_partition_sizes.sum().item()); + + auto weights = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device()) + ); + + if (nbooks == 1 && entries == (1 << 16)) + { + vllm::aqlm::code1x16_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) 
+ // weights *= scales.index({"...", 0, 0}); + + return weights; + } + + if (nbooks == 2 && entries == (1 << 8)) + { + vllm::aqlm::code2x8_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) + // weights *= scales.index({"...", 0, 0}); + + return weights; + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + return {}; +} diff --git a/examples/aqlm_example.py b/examples/aqlm_example.py new file mode 100644 index 0000000000000..e7c17fa0362ae --- /dev/null +++ b/examples/aqlm_example.py @@ -0,0 +1,46 @@ +import argparse + +from vllm import LLM, SamplingParams + + +def main(): + + parser = argparse.ArgumentParser(description='AQLM examples') + + parser.add_argument('--model', + '-m', + type=str, + default=None, + help='model path, as for HF') + parser.add_argument('--choice', + '-c', + type=int, + default=0, + help='known good models by index, [0-4]') + parser.add_argument('--tensor_parallel_size', + '-t', + type=int, + default=1, + help='tensor parallel size') + + args = parser.parse_args() + + models = [ + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf", + "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf", + "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf", + ] + + model = LLM(args.model if args.model is not None else models[args.choice], + tensor_parallel_size=args.tensor_parallel_size) + + sampling_params = SamplingParams(max_tokens=100, temperature=0) + outputs = model.generate("Hello my name is", + sampling_params=sampling_params) + print(outputs[0].outputs[0].text) + + +if __name__ == '__main__': + main() diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py new file mode 100644 index 0000000000000..a7abc011f57d7 --- /dev/null +++ b/tests/models/test_aqlm.py @@ -0,0 +1,95 @@ +"""Compare the outputs of a AQLM model between vLLM and HF Transformers + +Run `pytest tests/models/test_aqlm.py`. 
+""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +aqlm_not_supported = (capability < + QUANTIZATION_METHODS["aqlm"].get_min_capability()) + +# In this test we hardcode prompts and generations for the model so we don't +# need to require the AQLM package as a dependency +example_prompts = [ + 'vLLM is a high-throughput and memory-efficient inference and serving ' + 'engine for LLMs.\n', + 'Briefly describe the major milestones in the development of artificial ' + 'intelligence from 1950 to 2020.\n', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information.\n', + 'Describe the basic components of a neural network and how it can be ' + 'trained.\n', + 'Write a short story about a robot that dreams for the first time.\n', + 'Analyze the impact of the COVID-19 pandemic on global economic structures ' + 'and future business models.\n', + 'Explain the cultural significance of the Mona Lisa painting, and how its ' + 'perception might vary in Western versus Eastern societies.\n', + "Translate the following English sentence into Japanese, French, and " + "Swahili: 'The early bird catches the worm.'\n" +] + +# These ground truth generations were generated using `transformers==4.38.1 +# aqlm==1.1.0 torch==2.2.0` +# and the below code: +# ```python +# from transformers import AutoTokenizer, AutoModelForCausalLM +# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" +# quantized_model = AutoModelForCausalLM.from_pretrained(model_id, +# torch_dtype="auto", device_map="cuda").cuda() +# tokenizer = AutoTokenizer.from_pretrained(model_id) +# outputs = [] +# for prompt in example_prompts: +# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") +# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32) +# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:])) +# print(outputs) +# ``` +ground_truth_generations = [ + '\n### Features\n\n- **High-throughput**: v', + 'The major milestones in the development of artificial intelligence from ' + '195', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information. The', + 'Explain the difference between supervised and unsupervised learning.' + '\nExplain', + 'Write a short story about a robot that dreams for the first time. 
The', + 'Analyze the impact of the COVID-19 pandemic on global economic', + 'The Mona Lisa is a painting by Leonardo da Vinci, and it', + 'The early bird catches the worm.\nThe early bird catches the' +] + + +@pytest.mark.skipif(aqlm_not_supported, + reason="AQLM is not supported on this GPU type.") +@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("num_logprobs", [1]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + + # loop through the prompts to compare against the ground truth generations + for prompt_idx in range(len(example_prompts)): + vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[ + prompt_idx] + + print("Prompt: ", repr(example_prompts[prompt_idx])) + print("Reference output:", repr(ground_truth_generations[prompt_idx])) + print("Output output: ", repr(vllm_output_str)) + assert vllm_output_str == ground_truth_generations[prompt_idx] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d466d8807fc64..e56af9075e2fd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -31,7 +31,7 @@ class LinearMethodBase(ABC): @abstractmethod def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """Create weights for a linear layer. @@ -70,9 +70,10 @@ def __init__(self, separate_bias_add: bool = False): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), @@ -127,7 +128,7 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size, - self.output_size, self.input_size, + [self.output_size], self.input_size, self.output_size, self.params_dtype) if bias: self.bias = Parameter( @@ -161,6 +162,8 @@ class ColumnParallelLinear(torch.nn.Module): skip adding bias but instead return it. params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. 
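            For example, a hypothetical fused QKV projection with 4096-wide
            query, key and value outputs would pass
            output_sizes=[4096, 4096, 4096] (illustrative values only).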
""" def __init__( @@ -172,6 +175,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + output_sizes: Optional[List[int]] = None, ): super().__init__() @@ -188,10 +192,12 @@ def __init__( self.params_dtype = params_dtype if linear_method is None: linear_method = UnquantizedLinearMethod() + if output_sizes is None: + output_sizes = [output_size] self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size, - self.output_size_per_partition, + [x // tp_size for x in output_sizes], self.input_size, self.output_size, self.params_dtype, @@ -268,14 +274,17 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + skip_bias_add, params_dtype, linear_method, + self.output_sizes) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + param_data = param.data output_dim = getattr(param, "output_dim", None) + is_metadata = getattr(param, "is_metadata", False) if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -328,6 +337,11 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -393,8 +407,14 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + output_sizes = [ + self.num_heads * tp_size * self.head_size, + self.num_kv_heads * tp_size * self.head_size, + self.num_kv_heads * tp_size * self.head_size + ] + super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + params_dtype, linear_method, output_sizes) def weight_loader(self, param: Parameter, @@ -402,6 +422,7 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + is_metadata = getattr(param, "is_metadata", False) if loaded_shard_id is None: # Loaded weight is already packed. 
@@ -469,6 +490,12 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -536,7 +563,7 @@ def __init__( self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size_per_partition, - self.output_size, + [self.output_size], self.input_size, self.output_size, self.params_dtype, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0344d6e4e3e45..a525add458499 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,5 +1,6 @@ from typing import Type +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -9,6 +10,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS = { + "aqlm": AQLMConfig, "awq": AWQConfig, "fp8": FP8Config, "gptq": GPTQConfig, diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000000000..6115b1de679ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,373 @@ +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. 
+ :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. 
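    # Concretely (shapes assumed for illustration): with
    # output_partition_sizes = [4096, 4096, 4096] and one codebook per
    # partition, codes and scales are sliced row-wise into three 4096-row
    # chunks, codebooks is sliced along dim 0 one codebook at a time, and
    # each chunk is dequantized and multiplied independently before being
    # copied into its slice of the output.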
+ num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. + assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> str: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. 
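    # A typical HF-style quantization_config this maps from (values are those
    # of the common 1x16 checkpoints and are shown only as an example):
    #
    #   "quantization_config": {
    #       "quant_method": "aqlm",
    #       "in_group_size": 8,
    #       "nbits_per_codebook": 16,
    #       "num_codebooks": 1,
    #       "out_group_size": 1
    #   }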
+ + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_linear_method(self) -> "AQLMLinearMethod": + return AQLMLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. 
+ output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": + True, + "output_partition_sizes": + torch.tensor(output_partition_sizes, device='cpu'), + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + None) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 98651aed8be0e..4f75134ee1889 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -81,7 +81,7 @@ def __init__(self, quant_config: AWQConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: @@ -89,6 +89,8 @@ def create_weights(self, layer: torch.nn.Module, "The input size is not aligned with the quantized " "weight shape. 
This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f370b94a210ee..92a5cdb9af928 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -91,7 +91,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -103,6 +103,7 @@ def create_weights( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) if (output_size_per_partition % self.quant_config.pack_factor.numerator != 0): raise ValueError( diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index bf0500f1155a1..00c3c404c2d7a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -93,7 +93,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -106,6 +106,7 @@ def create_weights( f"The params dtype must be float16, but got {params_dtype}") # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.min_n_threads != 0: raise ValueError( f"Weight output_size_per_partition = " diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 661ff9c55d0d1..cc44447d347b8 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -70,7 +70,7 @@ def __init__(self, quant_config: SqueezeLLMConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: @@ -78,6 +78,8 @@ def create_weights(self, layer: torch.nn.Module, "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( torch.empty( input_size_per_partition // self.quant_config.pack_factor,