
AQLM CUDA support #3287

Merged on Apr 23, 2024
Commits (116)
079cba5
actual add kernel
jaemzfleming Feb 26, 2024
23c3f77
getting serious
jaemzfleming Feb 26, 2024
20a71fd
adding in mat mat, need to move the pytorch stuff, maybe add some aql…
jaemzfleming Feb 26, 2024
d0cf25a
load the codebooks, codes, and scales.
jaemzfleming Feb 27, 2024
40463e3
try to bind cpp aqlm entry point to python
jaemzfleming Feb 27, 2024
0e03c23
add aqlm
jaemzfleming Feb 27, 2024
26f8d83
fix print statements
jaemzfleming Feb 27, 2024
dad66ce
add comment
jaemzfleming Feb 28, 2024
77a8913
remove unused enum
jaemzfleming Feb 28, 2024
2bb6871
add a bunch of prints, add bias
jaemzfleming Feb 28, 2024
5f0c319
minor fix for scales
jaemzfleming Feb 28, 2024
024b54c
change
jaemzfleming Feb 28, 2024
84c2e2a
format
jaemzfleming Feb 29, 2024
8ea4d9d
try reversing some formatting changes
jaemzfleming Feb 29, 2024
b993971
restored
jaemzfleming Feb 29, 2024
1766886
add aqlm_cuda
jaemzfleming Feb 29, 2024
b673f47
restore formatting
jaemzfleming Feb 29, 2024
4e7d398
restore format
jaemzfleming Feb 29, 2024
4fc1426
more formatting
jaemzfleming Feb 29, 2024
ac2ef81
format
jaemzfleming Feb 29, 2024
30d2d42
restore formatting
jaemzfleming Feb 29, 2024
3fcb944
restore formatting
jaemzfleming Feb 29, 2024
4e7291a
formta
jaemzfleming Feb 29, 2024
39abbc0
first working aqlm
jaemzfleming Feb 29, 2024
8d7fa96
some improvements
jaemzfleming Feb 29, 2024
9a3dbe1
restore format
jaemzfleming Feb 29, 2024
e7c2601
make a central c++ aqlm entry point
jaemzfleming Feb 29, 2024
6eba035
add support for 2x8, worked shockingly easily
jaemzfleming Feb 29, 2024
604f66f
support more than one model
jaemzfleming Mar 1, 2024
ce63937
formatting
jaemzfleming Mar 1, 2024
6cbdff7
remove secondary aqlm loading
jaemzfleming Mar 1, 2024
a58d369
restore trailing space
jaemzfleming Mar 1, 2024
31f0ddc
remove some code
jaemzfleming Mar 1, 2024
edc80c6
remove some comments
jaemzfleming Mar 1, 2024
3253dc7
add some attributions
jaemzfleming Mar 1, 2024
fefe1c8
support 2 tp
jaemzfleming Mar 1, 2024
4b12ed6
better tp support
jaemzfleming Mar 1, 2024
e5c2010
format
jaemzfleming Mar 1, 2024
eef729f
comments
jaemzfleming Mar 1, 2024
d31241b
comments
jaemzfleming Mar 1, 2024
ba3c125
rename aqlm_test
jaemzfleming Mar 1, 2024
703fa79
better comments
jaemzfleming Mar 1, 2024
6e47ff6
better comment
jaemzfleming Mar 1, 2024
556178f
first attempt
jaemzfleming Mar 4, 2024
e23f1cd
got it working
jaemzfleming Mar 5, 2024
6253807
remove prints
jaemzfleming Mar 5, 2024
05ccd50
add arguments and options
jaemzfleming Mar 5, 2024
7b67492
rename shard_dim to just bool is_metadata
jaemzfleming Mar 5, 2024
0af6eb2
Merge branch 'jf/aqlm' into jf/aqlm-nosplit
jaemzfleming Mar 5, 2024
3aafb3c
use TORCH_CHECK
jaemzfleming Mar 5, 2024
ef608a6
cleanup aqlm_example
jaemzfleming Mar 5, 2024
3bf6e7e
Merge branch 'jf/aqlm' into jf/aqlm-nosplit
jaemzfleming Mar 5, 2024
5bacc9d
format
jaemzfleming Mar 5, 2024
2def434
some stuff
jaemzfleming Mar 5, 2024
821ee99
change 60 to 70 for min cap
jaemzfleming Mar 5, 2024
35eb873
Merge branch 'jf/aqlm' into jf/aqlm-nosplit
jaemzfleming Mar 5, 2024
d0816bf
format
jaemzfleming Mar 5, 2024
6372c64
make aqlm not rocm supported
jaemzfleming Mar 5, 2024
9f4d75f
Merge branch 'jf/aqlm-nosplit' into jf/aqlm
jaemzfleming Mar 5, 2024
83c2070
Add LICENSE file
jaemzfleming Mar 5, 2024
267b339
add reference
jaemzfleming Mar 5, 2024
0408789
add better license headers
jaemzfleming Mar 5, 2024
48838b8
add support for 2x8 optimization
jaemzfleming Mar 7, 2024
4822629
format
jaemzfleming Mar 7, 2024
c255f44
add better example models, and replace output_partition_size with sizes
jaemzfleming Mar 7, 2024
7acedee
Merge branch 'upstream-main' into jf/aqlm
jaemzfleming Mar 7, 2024
15d7206
format
jaemzfleming Mar 7, 2024
8df10d9
Add test_aqlm.py
mgoin Mar 7, 2024
a3039dd
remove comments
jaemzfleming Mar 7, 2024
84611e7
Merge branch 'jf/aqlm' of https://github.com/neuralmagic/nm-vllm into…
jaemzfleming Mar 7, 2024
2ecce81
put aqlm inside rocm block
jaemzfleming Mar 8, 2024
5864a00
add model to example
jaemzfleming Mar 8, 2024
58dbb01
remove comment
jaemzfleming Mar 8, 2024
7dc5f83
format
jaemzfleming Mar 8, 2024
8069375
fix test
jaemzfleming Mar 8, 2024
9891e22
Add dequantization kernel
jaemzfleming Mar 12, 2024
a51192f
Update csrc/quantization/aqlm/aqlm_cuda_entry.cpp
mgoin Mar 12, 2024
992d584
Update csrc/quantization/aqlm/aqlm_cuda_entry.cpp
mgoin Mar 12, 2024
9143b45
set gpu_memory_utilization
jaemzfleming Mar 12, 2024
5d24991
add benchmark and refactor a bit.
jaemzfleming Mar 14, 2024
5985acb
Merge branch 'upstream-main' into jf/aqlm
jaemzfleming Mar 15, 2024
c319d2a
Merge branch 'upstream-main' into jf/aqlm
jaemzfleming Mar 21, 2024
d9152e2
add aqlm
jaemzfleming Mar 21, 2024
0574dff
Add dequant methods
jaemzfleming Mar 21, 2024
39ca4a0
fix format
jaemzfleming Mar 21, 2024
522f990
formatA
jaemzfleming Mar 21, 2024
d2ac6b2
some format fixes
jaemzfleming Mar 21, 2024
bb66e3c
formatting
jaemzfleming Mar 21, 2024
11c7950
format
jaemzfleming Mar 21, 2024
fb78b95
remove dead space
jaemzfleming Mar 21, 2024
d73a92b
niceties for aqlm benchmark
jaemzfleming Mar 21, 2024
4406555
update the test file
jaemzfleming Mar 22, 2024
3622342
remove gpu_memory_utilization reduction
jaemzfleming Mar 22, 2024
3cf2a1b
Merge branch 'upstream-main' into jf/aqlm
jaemzfleming Mar 26, 2024
e2b3529
port over better dequant kernels from aqlm
jaemzfleming Mar 26, 2024
3d65a48
better threshold for aqlm
jaemzfleming Mar 26, 2024
421249c
Merge branch 'upstream-main' into jf/aqlm
jaemzfleming Mar 26, 2024
d033c85
format
jaemzfleming Mar 26, 2024
f950178
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 8, 2024
7c604fe
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 8, 2024
92206de
Update test point
mgoin Apr 9, 2024
811e2cc
Poke test again
mgoin Apr 9, 2024
a97353b
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 9, 2024
6ca51d4
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 10, 2024
22f7fae
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 11, 2024
2282157
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 15, 2024
d0e8d0c
Resolve create_weights updates
mgoin Apr 15, 2024
6bb89c0
Better test debug output (manually tested TP)
mgoin Apr 16, 2024
09d4a24
Merge branch 'vllm-project:main' into jf/aqlm
mgoin Apr 17, 2024
4d46f18
Delete csrc/quantization/aqlm/LICENSE
mgoin Apr 18, 2024
a29008d
Address comments
mgoin Apr 18, 2024
dacdb52
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 18, 2024
3852115
Update test
mgoin Apr 18, 2024
d367895
Cleanup namespaces
mgoin Apr 18, 2024
d34f23d
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 22, 2024
7283c23
Merge remote-tracking branch 'upstream/main' into jf/aqlm
mgoin Apr 22, 2024
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -173,6 +173,8 @@ set(VLLM_EXT_SRC

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/aqlm_cuda_entry.cpp"
"csrc/quantization/aqlm/aqlm_cuda_kernel.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu")
302 changes: 302 additions & 0 deletions benchmarks/kernels/benchmark_aqlm.py
@@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
input: torch.Tensor, # [..., in_features]
weights: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
output = F.linear(input, weights)
return output


def dequant_out_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:

weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

if bias is None:
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)


def dequant_weight_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:

weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)


def dequant_no_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:

weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:

n = parts.sum().item()

device = torch.device('cuda:0')

code_range = (1 << bits) // 2
ingroups = 8

codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)

codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)

count = 0
for index in range(16):
for i in range(8):
for book in range(nbooks):
codebooks[book, index, 0, i] = count * (10**book)
count += 1

print("codes shape", codes.shape)

for i in range(16):
for book in range(nbooks):
codes[0, i, book] = i
codes[0, -i, book] = i

weights = dequantize_weight(codes, codebooks, None)
weights2 = ops.aqlm_dequant(codes, codebooks, parts)

print("weights shape:", weights.shape)
print("weights2 shape:", weights2.shape)

print("weights are:", weights)
print("weights2 are:", weights2)

print("first 128 weights are", weights[0, 0:128].to(torch.int32))
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

print("last 128 weights are", weights[0, -128:])
print("last 128 weights2 are:", weights2[0, -128:])


def main():

parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

# Add arguments
parser.add_argument("--nbooks",
type=int,
default=1,
help="Number of codebooks (default: 1)")
parser.add_argument("--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)")
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)")

# Parse the arguments
args = parser.parse_args()

# Extract values
nbooks = args.nbooks
bits = args.bits

if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
return

# Otherwise, benchmark.
methods = [
ops.aqlm_gemm,
dequant_out_scale,
generic_dequantize_gemm,
optimized_dequantize_gemm,
dequant_weight_scale,
torch_mult,
dequant_no_scale,
]

filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
print(f"writing benchmarks to file {filename}")
with open(filename, "w") as f:
sys.stdout = f

print('m | k | n | n parts', end='')
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
print('')

# These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
(4096, (11008, 11008)), (11008, (4096, )))

# reasonable ranges for m.
for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
128, 256, 512, 1024, 1536, 2048, 3072, 4096
]:
print(f'{m}', file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
methods)

sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
methods):

# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
Comment on lines +211 to +212

Member: should this be increased a bit so we have multiple measurements and proper warmup?

Contributor Author: I didn't notice much difference, but I left a comment.

num_calls = 100

# warmup.
for method in methods:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)

n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

for method in methods:
best_time_us = 1e20
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)

kernel_dur_us = 1000 * kernel_dur_ms

if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us

print(f' | {kernel_dur_us:.0f}', end='')

print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
nbooks: int, bits: int, method) -> float:

n = parts.sum().item()

device = torch.device('cuda:0')

input = torch.randn((1, m, k), dtype=torch.float16, device=device)

code_range = (1 << bits) // 2
ingroups = 8

codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)

codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)

scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

# for comparison to just a pytorch mult.
weights = torch.randn((n, k), dtype=torch.float16, device=device)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()

if method is torch_mult:
for i in range(num_calls):
torch_mult(input, weights, scales)
else:
for i in range(num_calls):
method(input, codes, codebooks, scales, parts, None)

end_event.record()
end_event.synchronize()

dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms


if __name__ == "__main__":
sys.exit(main())
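
Regarding the review question above about warmup and repeated measurements, one possible tightening is to wrap run_timing() in a few discarded warmup trials and report the median of several measured trials. A minimal sketch (hypothetical helper, not part of this PR; it assumes benchmark_aqlm.py is importable as benchmark_aqlm and that the AQLM kernels are compiled in):

import statistics

import torch

from benchmark_aqlm import run_timing
from vllm._C import ops


def median_us_per_call(m=16, k=4096, parts=torch.tensor((4096, )),
                       nbooks=1, bits=16, method=ops.aqlm_gemm,
                       warmup_trials=3, trials=5, num_calls=100):
    # Discard a few full trials so one-time CUDA setup and clock ramp-up
    # don't skew the first measurement.
    for _ in range(warmup_trials):
        run_timing(num_calls, m, k, parts, nbooks, bits, method)
    # run_timing() already averages over num_calls launches and returns
    # ms per call; take the median across trials and convert to microseconds.
    samples = [run_timing(num_calls, m, k, parts, nbooks, bits, method)
               for _ in range(trials)]
    return 1000 * statistics.median(samples)
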
15 changes: 15 additions & 0 deletions csrc/ops.h
@@ -86,6 +86,21 @@ void gelu_fast(
torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
);

torch::Tensor aqlm_dequant(
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes
);

torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
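
For orientation, a minimal sketch of how these two entry points are exercised from Python, mirroring the call pattern and tensor shapes used in benchmarks/kernels/benchmark_aqlm.py above for the 1x16 scheme (it assumes a CUDA build of vLLM with the AQLM kernels compiled in):

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import get_int_dtype

device = torch.device("cuda:0")
m, k, nbooks, bits = 4, 4096, 1, 16   # 1x16: one codebook of 2**16 entries
parts = torch.tensor((4096, ))        # output partition sizes, sum == n
n = parts.sum().item()

input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
codes = torch.randint(-code_range, code_range,
                      size=(n, k // 8, nbooks),
                      dtype=get_int_dtype(bits), device=device)
codebooks = torch.randn((parts.shape[0] * nbooks, 1 << bits, 1, 8),
                        dtype=torch.float16, device=device)
scales = torch.randn((n, 1, 1, 1), dtype=torch.float16, device=device)

# Fused path: dequantizes the codes on the fly and runs the GEMM.
out = ops.aqlm_gemm(input, codes, codebooks, scales, parts, None)

# Decompression path: materializes the fp16 weight matrix of shape [n, k];
# scales are applied separately (see dequant_out_scale in the benchmark above).
weights = ops.aqlm_dequant(codes, codebooks, parts)
out_ref = F.linear(input, weights)
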
2 changes: 2 additions & 0 deletions csrc/pybind.cpp
@@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

// Quantization ops
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");