AQLM CUDA support #3287
Merged
robertgshaw2-neuralmagic merged 116 commits into vllm-project:main from neuralmagic:jf/aqlm on Apr 23, 2024.
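For context on what this PR enables, the sketch below shows how an AQLM-quantized checkpoint could be loaded through vLLM's offline-inference API once the support is merged. It is illustrative only and not part of this PR's diff; the model id is just one publicly available AQLM checkpoint and is an assumption, not something this page specifies.

```python
# Illustrative sketch only (not from this PR's diff). Assumes an AQLM-quantized
# checkpoint on the Hugging Face Hub; the model id below is an example and may
# need to be swapped for whichever AQLM checkpoint you actually use.
from vllm import LLM, SamplingParams

llm = LLM(model="ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf")
sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], sampling)
print(outputs[0].outputs[0].text)
```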
Changes from 109 commits
Commits (116)
079cba5  actual add kernel  (jaemzfleming)
23c3f77  getting serious  (jaemzfleming)
20a71fd  adding in mat mat, need to move the pytorch stuff, maybe add some aql…  (jaemzfleming)
d0cf25a  load the codebooks, codes, and scales.  (jaemzfleming)
40463e3  try to bind cpp aqlm entry point to python  (jaemzfleming)
0e03c23  add aqlm  (jaemzfleming)
26f8d83  fix print statements  (jaemzfleming)
dad66ce  add comment  (jaemzfleming)
77a8913  remove unused enum  (jaemzfleming)
2bb6871  add a bunch of prints, add bias  (jaemzfleming)
5f0c319  minor fix for scales  (jaemzfleming)
024b54c  change  (jaemzfleming)
84c2e2a  format  (jaemzfleming)
8ea4d9d  try reversing some formatting changes  (jaemzfleming)
b993971  restored  (jaemzfleming)
1766886  add aqlm_cuda  (jaemzfleming)
b673f47  restore formatting  (jaemzfleming)
4e7d398  restore format  (jaemzfleming)
4fc1426  more formatting  (jaemzfleming)
ac2ef81  format  (jaemzfleming)
30d2d42  restore formatting  (jaemzfleming)
3fcb944  restore formatting  (jaemzfleming)
4e7291a  formta  (jaemzfleming)
39abbc0  first working aqlm  (jaemzfleming)
8d7fa96  some improvements  (jaemzfleming)
9a3dbe1  restore format  (jaemzfleming)
e7c2601  make a central c++ aqlm entry point  (jaemzfleming)
6eba035  add support for 2x8, worked shockingly easily  (jaemzfleming)
604f66f  support more than one model  (jaemzfleming)
ce63937  formatting  (jaemzfleming)
6cbdff7  remove secondary aqlm loading  (jaemzfleming)
a58d369  restore trailing space  (jaemzfleming)
31f0ddc  remove some code  (jaemzfleming)
edc80c6  remove some comments  (jaemzfleming)
3253dc7  add some attributions  (jaemzfleming)
fefe1c8  support 2 tp  (jaemzfleming)
4b12ed6  better tp support  (jaemzfleming)
e5c2010  format  (jaemzfleming)
eef729f  comments  (jaemzfleming)
d31241b  comments  (jaemzfleming)
ba3c125  rename aqlm_test  (jaemzfleming)
703fa79  better comments  (jaemzfleming)
6e47ff6  better comment  (jaemzfleming)
556178f  first attempt  (jaemzfleming)
e23f1cd  got it working  (jaemzfleming)
6253807  remove prints  (jaemzfleming)
05ccd50  add arguments and options  (jaemzfleming)
7b67492  rename shard_dim to just bool is_metadata  (jaemzfleming)
0af6eb2  Merge branch 'jf/aqlm' into jf/aqlm-nosplit  (jaemzfleming)
3aafb3c  use TORCH_CHECK  (jaemzfleming)
ef608a6  cleanup aqlm_example  (jaemzfleming)
3bf6e7e  Merge branch 'jf/aqlm' into jf/aqlm-nosplit  (jaemzfleming)
5bacc9d  format  (jaemzfleming)
2def434  some stuff  (jaemzfleming)
821ee99  change 60 to 70 for min cap  (jaemzfleming)
35eb873  Merge branch 'jf/aqlm' into jf/aqlm-nosplit  (jaemzfleming)
d0816bf  format  (jaemzfleming)
6372c64  make aqlm not rocm supported  (jaemzfleming)
9f4d75f  Merge branch 'jf/aqlm-nosplit' into jf/aqlm  (jaemzfleming)
83c2070  Add LICENSE file  (jaemzfleming)
267b339  add reference  (jaemzfleming)
0408789  add better license headers  (jaemzfleming)
48838b8  add support for 2x8 optimization  (jaemzfleming)
4822629  format  (jaemzfleming)
c255f44  add better example models, and replace output_partition_size with sizes  (jaemzfleming)
7acedee  Merge branch 'upstream-main' into jf/aqlm  (jaemzfleming)
15d7206  format  (jaemzfleming)
8df10d9  Add test_aqlm.py  (mgoin)
a3039dd  remove comments  (jaemzfleming)
84611e7  Merge branch 'jf/aqlm' of https://github.com/neuralmagic/nm-vllm into…  (jaemzfleming)
2ecce81  put aqlm inside rocm block  (jaemzfleming)
5864a00  add model to example  (jaemzfleming)
58dbb01  remove comment  (jaemzfleming)
7dc5f83  format  (jaemzfleming)
8069375  fix test  (jaemzfleming)
9891e22  Add dequantization kernel  (jaemzfleming)
a51192f  Update csrc/quantization/aqlm/aqlm_cuda_entry.cpp  (mgoin)
992d584  Update csrc/quantization/aqlm/aqlm_cuda_entry.cpp  (mgoin)
9143b45  set gpu_memory_utilization  (jaemzfleming)
5d24991  add benchmark and refactor a bit.  (jaemzfleming)
5985acb  Merge branch 'upstream-main' into jf/aqlm  (jaemzfleming)
c319d2a  Merge branch 'upstream-main' into jf/aqlm  (jaemzfleming)
d9152e2  add aqlm  (jaemzfleming)
0574dff  Add dequant methods  (jaemzfleming)
39ca4a0  fix format  (jaemzfleming)
522f990  formatA  (jaemzfleming)
d2ac6b2  some format fixes  (jaemzfleming)
bb66e3c  formatting  (jaemzfleming)
11c7950  format  (jaemzfleming)
fb78b95  remove dead space  (jaemzfleming)
d73a92b  niceties for aqlm benchmark  (jaemzfleming)
4406555  update the test file  (jaemzfleming)
3622342  remove gpu_memory_utilization reduction  (jaemzfleming)
3cf2a1b  Merge branch 'upstream-main' into jf/aqlm  (jaemzfleming)
e2b3529  port over better dequant kernels from aqlm  (jaemzfleming)
3d65a48  better threshold for aqlm  (jaemzfleming)
421249c  Merge branch 'upstream-main' into jf/aqlm  (jaemzfleming)
d033c85  format  (jaemzfleming)
f950178  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
7c604fe  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
92206de  Update test point  (mgoin)
811e2cc  Poke test again  (mgoin)
a97353b  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
6ca51d4  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
22f7fae  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
2282157  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
d0e8d0c  Resolve create_weights updates  (mgoin)
6bb89c0  Better test debug output (manually tested TP)  (mgoin)
09d4a24  Merge branch 'vllm-project:main' into jf/aqlm  (mgoin)
4d46f18  Delete csrc/quantization/aqlm/LICENSE  (mgoin)
a29008d  Address comments  (mgoin)
dacdb52  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
3852115  Update test  (mgoin)
d367895  Cleanup namespaces  (mgoin)
d34f23d  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
7283c23  Merge remote-tracking branch 'upstream/main' into jf/aqlm  (mgoin)
The diff shown here adds one new file (302 lines), a standalone AQLM kernel benchmark and dequantization tester:

```python
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
    input: torch.Tensor,  # [..., in_features]
    weights: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # Dequantize, run the GEMM unscaled, then apply the scales to the output.
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # Dequantize, scale the dequantized weights, then run the GEMM.
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # Dequantize and run the GEMM without applying scales at all.
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 CUDA decompression/dequant kernels against
# the generic PyTorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
            count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        print(f' | {kernel_dur_us:.0f}', end='')

    print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
```
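As a follow-up, here is a small post-processing sketch (not part of the PR) for reading the " | "-separated results file the benchmark writes. It assumes a run with the default --nbooks 1 --bits 16 settings, so the file name aqlm_benchmark_1x16.csv follows from the f-string above, and the column layout follows the header that main() prints.

```python
# Hypothetical post-processing helper (not from the PR): summarizes the
# " | "-separated rows written by the benchmark above and reports the fastest
# method for each (m, k, n) shape. The file name assumes the default
# --nbooks 1 --bits 16 settings.
with open("./aqlm_benchmark_1x16.csv") as f:
    lines = [line.rstrip("\n") for line in f if line.strip()]

header = [col.strip() for col in lines[0].split("|")]
method_names = header[4:]  # columns after "m", "k", "n", "n parts"

for line in lines[1:]:
    cells = [c.strip() for c in line.split("|")]
    times_us = [float(t) for t in cells[4:]]
    best_time, best_name = min(zip(times_us, method_names))
    print(f"m={cells[0]} k={cells[1]} n={cells[2]} parts={cells[3]}: "
          f"fastest is {best_name} at {best_time:.0f} µs")
```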
Conversations

Review comment (on the num_warmup_trials / num_trials values in run_grid):
should this be increased a bit so we have multiple measurements and proper warmup?

Reply:
I didn't notice much difference, but I left a comment.
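Following up on that exchange, below is a minimal sketch of what "multiple measurements and proper warmup" could look like. It is not part of the PR; the helper name and trial counts are arbitrary choices, and it simply wraps the same CUDA-event pattern that run_timing already uses, reporting the median over several trials instead of a single measurement.

```python
# Hypothetical timing helper (not from the PR): warms up, then reports the
# median per-call latency in microseconds over several independent trials.
import statistics
from typing import Callable, List

import torch


def timed_median_us(fn: Callable[[], None],
                    num_warmup: int = 3,
                    num_trials: int = 10,
                    num_calls: int = 100) -> float:
    # Warmup runs so one-time setup and caching effects do not skew the timing.
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()

    samples: List[float] = []
    for _ in range(num_trials):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_calls):
            fn()
        end.record()
        end.synchronize()
        # elapsed_time() returns milliseconds; convert to per-call microseconds.
        samples.append(start.elapsed_time(end) * 1000 / num_calls)
    return statistics.median(samples)


if __name__ == "__main__":
    a = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    b = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    print(f"matmul: {timed_median_us(lambda: a @ b):.1f} us/call")
```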