AQLM CUDA support (#3287)
Co-authored-by: mgoin <[email protected]>
jaemzfleming and mgoin authored Apr 23, 2024
1 parent 62b5166 commit 2b7949c
Showing 14 changed files with 1,592 additions and 11 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -173,6 +173,7 @@ set(VLLM_EXT_SRC

if(VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_EXT_SRC
    "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/quantization/marlin/marlin_cuda_kernel.cu"
    "csrc/custom_all_reduce.cu")
302 changes: 302 additions & 0 deletions benchmarks/kernels/benchmark_aqlm.py
@@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
        input: torch.Tensor,  # [..., in_features]
        weights: torch.Tensor,
        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
        input: torch.Tensor,  # [..., in_features]
        codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
        codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
        output_partition_sizes: torch.IntTensor,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
        input: torch.Tensor,  # [..., in_features]
        codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
        codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
        output_partition_sizes: torch.IntTensor,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
        input: torch.Tensor,  # [..., in_features]
        codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
        codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
        output_partition_sizes: torch.IntTensor,
        bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
            count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

    sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        print(f' | {kernel_dur_us:.0f}', end='')

    print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
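For context, the core call pattern this benchmark times can be sketched as follows. This is not part of the commit: it is a minimal, illustrative snippet that assumes a CUDA build of vLLM exposing vllm._C.ops, and it reuses the tensor shapes constructed in run_timing above (codes of shape (n, k // 8, nbooks), codebooks of shape (parts * nbooks, 2**bits, 1, 8), per-output-channel scales). How closely the fused kernel tracks the decompress-then-matmul path is what dequant_test checks by inspection; the snippet just prints the largest deviation.

# Illustrative sketch only (not part of this commit). Assumes a CUDA build
# of vLLM that exposes vllm._C.ops; shapes mirror run_timing() above.
import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import get_int_dtype

device = torch.device('cuda:0')
m, k, nbooks, bits = 4, 4096, 1, 16  # 1x16 AQLM layout, in-group size 8
parts = torch.tensor((4096, ))  # a single output partition
n = parts.sum().item()

input = torch.randn((1, m, k), dtype=torch.float16, device=device)
codes = torch.randint(-(1 << bits) // 2, (1 << bits) // 2,
                      size=(n, k // 8, nbooks),
                      dtype=get_int_dtype(bits),
                      device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                        dtype=torch.float16,
                        device=device)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

# Fused CUDA kernel vs. decompress-then-matmul, both added by this commit.
fused = ops.aqlm_gemm(input, codes, codebooks, scales, parts, None)
weights = ops.aqlm_dequant(codes, codebooks, parts)
dequant = F.linear(input, weights * scales.view(-1, 1), None)
print("max abs difference:", (fused - dequant).abs().max().item())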
15 changes: 15 additions & 0 deletions csrc/ops.h
@@ -86,6 +86,21 @@ void gelu_fast(
  torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const torch::Tensor& codebook_partition_sizes,
  const std::optional<torch::Tensor>& bias
);

torch::Tensor aqlm_dequant(
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& codebook_partition_sizes
);

torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
2 changes: 2 additions & 0 deletions csrc/pybind.cpp
@@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

  // Quantization ops
#ifndef USE_ROCM
  ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
  ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
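Once the extension is rebuilt, the two bindings registered here sit next to the existing quantization ops on vllm._C.ops; a rough smoke test, assuming a CUDA (non-ROCm) build since the definitions are guarded by #ifndef USE_ROCM:

# Rough smoke test of the new bindings (assumes a CUDA, non-ROCm build).
from vllm._C import ops

# Bound above: aqlm_gemm(input, codes, codebooks, scales,
#                        codebook_partition_sizes, bias)
#          and aqlm_dequant(codes, codebooks, codebook_partition_sizes).
print(hasattr(ops, "aqlm_gemm"), hasattr(ops, "aqlm_dequant"))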