From f5f4ef7d8c57ffe9d844672fca17761a912c5304 Mon Sep 17 00:00:00 2001 From: alexm Date: Tue, 14 May 2024 11:26:51 -0400 Subject: [PATCH 1/5] Add gptq marlin unit tests --- tests/kernels/marlin/marlin_utils.py | 172 +++++++++++++++++++++++ tests/kernels/marlin/quant_utils.py | 145 +++++++++++++++++++ tests/kernels/marlin/test_marlin_gemm.py | 157 +++++++++++++++++++++ 3 files changed, 474 insertions(+) create mode 100644 tests/kernels/marlin/marlin_utils.py create mode 100644 tests/kernels/marlin/quant_utils.py create mode 100644 tests/kernels/marlin/test_marlin_gemm.py diff --git a/tests/kernels/marlin/marlin_utils.py b/tests/kernels/marlin/marlin_utils.py new file mode 100644 index 0000000000000..6ecdae325f2de --- /dev/null +++ b/tests/kernels/marlin/marlin_utils.py @@ -0,0 +1,172 @@ +import numpy +import torch +from quant_utils import get_pack_factor, quantize_weights, sort_weights + +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) + +__cuda_arch = torch.cuda.get_device_capability() + + +def is_marlin_supported(): + return __cuda_arch[0] >= 8 + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def _get_perms(num_bits): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm = {} +_scale_perm = {} +_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = _get_perms(num_bits) + _perm[num_bits] = perm + _scale_perm[num_bits] = scale_perm + _scale_perm_single[num_bits] = scale_perm_single + + +def marlin_permute_weights(q_w, + size_k, + size_n, + num_bits, + tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape( + (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, 
size_k, size_n, num_bits): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(_scale_perm[num_bits])))[:, + _scale_perm[num_bits]] + else: + s = s.reshape( + (-1, + len(_scale_perm_single[num_bits])))[:, + _scale_perm_single[num_bits]] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +class MarlinWorkspace: + + def __init__(self, out_features): + assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( + "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" + .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) + + max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * + GPTQ_MARLIN_MAX_PARALLEL) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") diff --git a/tests/kernels/marlin/quant_utils.py b/tests/kernels/marlin/quant_utils.py new file mode 100644 index 0000000000000..28888b1a0691c --- /dev/null +++ b/tests/kernels/marlin/quant_utils.py @@ -0,0 +1,145 @@ +import numpy +import torch + +SUPPORTED_NUM_BITS = [4, 8] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def get_pack_factor(num_bits): + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + 
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res diff --git a/tests/kernels/marlin/test_marlin_gemm.py b/tests/kernels/marlin/test_marlin_gemm.py new file mode 100644 index 0000000000000..3288e497ee292 --- /dev/null +++ b/tests/kernels/marlin/test_marlin_gemm.py @@ -0,0 +1,157 @@ +"""Tests for the marlin kernel. + +Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. 
+""" +import pytest +import torch +from marlin_utils import (MarlinWorkspace, is_marlin_supported, + marlin_quantize, marlin_weights) +from quant_utils import gptq_pack, quantize_weights, sort_weights + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + +K_CHUNKS = [128, 256] +N_CHUNKS = [64, 128, 256] + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (1, 7 * 4, 5 * 1), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] + + +def rand_data(shape): + data = torch.rand(shape).to(torch.half).cuda() + return data + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, + mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Create input + b_weight = rand_data((size_k, size_n)) + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, + group_size, act_order) + + # Pack to GPTQ format + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Pack to Marlin format + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + + # Run Marlin repack GPU kernel + marlin_q_w_2 = ops.gptq_marlin_repack( + q_w_gptq, + sort_indices, + size_k, + size_n, + num_bits, + ) + torch.cuda.synchronize() + + assert torch.allclose(marlin_q_w_1, marlin_q_w_2) + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) +def test_marlin_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, + act_order, + is_k_full, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + 
b_weight, num_bits, group_size, act_order) + + workspace = MarlinWorkspace(size_n) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_q_w, + marlin_s, + g_idx, + sort_indices, + workspace.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + ) + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + assert torch.allclose(output, output_ref, rtol=1e-2) From febeb2838134df04ddcd7fa603de965fb5ad7787 Mon Sep 17 00:00:00 2001 From: alexm Date: Tue, 14 May 2024 14:28:54 -0400 Subject: [PATCH 2/5] Add marlin benchmark and fix unit tests --- benchmarks/kernels/benchmark_marlin.py | 189 ++++++++++++++++++ benchmarks/kernels/benchmark_shapes.py | 75 +++++++ .../kernels/{marlin => }/test_marlin_gemm.py | 7 +- .../layers/quantization/utils/__init__.py | 0 .../quantization/utils}/marlin_utils.py | 3 +- .../layers/quantization/utils}/quant_utils.py | 0 6 files changed, 270 insertions(+), 4 deletions(-) create mode 100644 benchmarks/kernels/benchmark_marlin.py create mode 100644 benchmarks/kernels/benchmark_shapes.py rename tests/kernels/{marlin => }/test_marlin_gemm.py (94%) create mode 100644 vllm/model_executor/layers/quantization/utils/__init__.py rename {tests/kernels/marlin => vllm/model_executor/layers/quantization/utils}/marlin_utils.py (97%) rename {tests/kernels/marlin => vllm/model_executor/layers/quantization/utils}/quant_utils.py (100%) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000..f6297ba55237e --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,189 @@ +import argparse +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_quantize, MarlinWorkspace) + +from benchmark_shapes import WEIGHT_SHAPES + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results = [] + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + for is_k_full in K_FULL_OPTS: + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + if len( + args.limit_group_size + ) > 0 and group_size not in args.limit_group_size: + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k + or group_size == -1): + continue + + for size_m in args.batch_sizes: + label = "Quant Matmul" + sub_label = ( + "{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format( + model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, + size_k).to(torch.half).cuda() + b = torch.rand(size_k, + size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, 
size_k).to( + torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, + act_order) + + # GPTQ quant + w_ref, q_w, s, g_idx, rand_perm = quantize_weights( + b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, + size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty( + 0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, + repack_sort_indices) = sort_weights( + q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n) + + globals = { + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "is_k_full": is_k_full, + "a": a, + "a_tmp": a_tmp, + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_repack": + ops.gptq_marlin_repack, + "marlin_workspace": marlin_workspace, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange( + min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange( + min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange( + min_run_time=min_run_time)) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000..4eeeca35a37cc --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,75 @@ +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 
14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} diff --git a/tests/kernels/marlin/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py similarity index 94% rename from tests/kernels/marlin/test_marlin_gemm.py rename to tests/kernels/test_marlin_gemm.py index 3288e497ee292..b0ad85c25c572 100644 --- a/tests/kernels/marlin/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -4,13 +4,14 @@ """ import pytest import torch -from marlin_utils import (MarlinWorkspace, is_marlin_supported, - marlin_quantize, marlin_weights) -from quant_utils import gptq_pack, quantize_weights, sort_weights from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/kernels/marlin/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py similarity index 97% rename from tests/kernels/marlin/marlin_utils.py rename to vllm/model_executor/layers/quantization/utils/marlin_utils.py index 6ecdae325f2de..802fd26de8d65 100644 --- a/tests/kernels/marlin/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,9 +1,10 @@ import numpy import torch -from quant_utils import get_pack_factor, quantize_weights, sort_weights from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_pack_factor, quantize_weights, sort_weights) __cuda_arch = torch.cuda.get_device_capability() diff --git a/tests/kernels/marlin/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py similarity index 100% rename from tests/kernels/marlin/quant_utils.py rename to vllm/model_executor/layers/quantization/utils/quant_utils.py From 09ae113aecb068ac0a10b30ff67bf0b41fbffdf7 Mon Sep 17 00:00:00 2001 From: alexm Date: Tue, 14 May 2024 14:53:48 -0400 Subject: 
[PATCH 3/5] fixes --- benchmarks/kernels/benchmark_marlin.py | 216 +++++++++--------- .../layers/quantization/utils/marlin_utils.py | 6 +- 2 files changed, 107 insertions(+), 115 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index f6297ba55237e..0b1486840cff8 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -1,17 +1,16 @@ import argparse + import torch import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) - +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_quantize, MarlinWorkspace) - -from benchmark_shapes import WEIGHT_SHAPES DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] @@ -20,6 +19,102 @@ K_FULL_OPTS = [False, True] +def bench_run(results, model, act_order, is_k_full, num_bits, group_size, + size_m, size_k, size_n): + label = "Quant Matmul" + sub_label = ("{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, act_order) + + # GPTQ quant + (w_ref, q_w, s, g_idx, + rand_perm) = quantize_weights(b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n) + + globals = { + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "is_k_full": is_k_full, + "a": a, + "a_tmp": a_tmp, + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + "marlin_workspace": marlin_workspace, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + 
sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time)) + + def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): @@ -53,118 +148,15 @@ def main(args): continue for size_m in args.batch_sizes: - label = "Quant Matmul" - sub_label = ( - "{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format( - model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) - - print(f"Testing: {sub_label}") - - a = torch.randn(size_m, - size_k).to(torch.half).cuda() - b = torch.rand(size_k, - size_n).to(torch.half).cuda() - - a_tmp = (torch.zeros(size_m, size_k).to( - torch.half).cuda()) - - # Marlin quant - ( - marlin_w_ref, - marlin_q_w, - marlin_s, - marlin_g_idx, - marlin_sort_indices, - marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, - act_order) - - # GPTQ quant - w_ref, q_w, s, g_idx, rand_perm = quantize_weights( - b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, - size_n) - - # For act_order, sort the "weights" and "g_idx" - # so that group ids are increasing - repack_sort_indices = torch.empty( - 0, dtype=torch.int, device=b.device) - if act_order: - (q_w, g_idx, - repack_sort_indices) = sort_weights( - q_w, g_idx) - - # Prepare - marlin_workspace = MarlinWorkspace(size_n) - - globals = { - "marlin_w_ref": marlin_w_ref, - "marlin_q_w": marlin_q_w, - "marlin_s": marlin_s, - "marlin_g_idx": marlin_g_idx, - "marlin_sort_indices": marlin_sort_indices, - "marlin_rand_perm": marlin_rand_perm, - "q_w_gptq": q_w_gptq, - "repack_sort_indices": repack_sort_indices, - "num_bits": num_bits, - "group_size": group_size, - "size_m": size_m, - "size_n": size_n, - "size_k": size_k, - "is_k_full": is_k_full, - "a": a, - "a_tmp": a_tmp, - "gptq_marlin_gemm": ops.gptq_marlin_gemm, - "gptq_marlin_repack": - ops.gptq_marlin_repack, - "marlin_workspace": marlin_workspace, - } - - min_run_time = 1 - - # Warmup pytorch - for i in range(5): - torch.matmul(a, marlin_w_ref) - - results.append( - benchmark.Timer( - stmt="torch.matmul(a, marlin_w_ref)", - globals=globals, - label=label, - sub_label=sub_label, - description="pytorch_gemm", - ).blocked_autorange( - min_run_time=min_run_time)) - - results.append( - benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_gemm", - ).blocked_autorange( - min_run_time=min_run_time)) - - results.append( - benchmark.Timer( - stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_repack", - ).blocked_autorange( - min_run_time=min_run_time)) + bench_run(results, model, act_order, is_k_full, + num_bits, group_size, size_m, size_k, + size_n) compare = benchmark.Compare(results) compare.print() -# For quick benchmarking use: +# For quick benchmarking use: # python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 # if __name__ == 
"__main__": diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 802fd26de8d65..e6dd336f4ba16 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -22,7 +22,7 @@ def is_marlin_supported(): # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501 def _get_perms(num_bits): - perm = [] + perm_list = [] for i in range(32): perm1 = [] col = i // 4 @@ -35,9 +35,9 @@ def _get_perms(num_bits): ]: perm1.append(16 * row + col + 8 * block) for j in range(4): - perm.extend([p + 256 * j for p in perm1]) + perm_list.extend([p + 256 * j for p in perm1]) - perm = numpy.array(perm) + perm = numpy.array(perm_list) if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) From 665c6ebfc8cfbb334a43fc95e2a040774ddb6253 Mon Sep 17 00:00:00 2001 From: alexm Date: Wed, 15 May 2024 12:01:54 -0400 Subject: [PATCH 4/5] sync with Rob changes: --- benchmarks/kernels/benchmark_marlin.py | 1 + vllm/model_executor/layers/quantization/utils/marlin_utils.py | 1 + vllm/model_executor/layers/quantization/utils/quant_utils.py | 1 + 3 files changed, 3 insertions(+) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 0b1486840cff8..313d11ec0bf65 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -22,6 +22,7 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, size_m, size_k, size_n): label = "Quant Matmul" + sub_label = ("{}, act={} k_full={}, b={}, g={}, " "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, group_size, size_m, size_k, size_n)) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index e6dd336f4ba16..33b3169983475 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,3 +1,4 @@ +"""This file is used for /tests and /benchmarks""" import numpy import torch diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 28888b1a0691c..177cb23f63cf4 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,3 +1,4 @@ +"""This file is used for /tests and /benchmarks""" import numpy import torch From 7872d4b534b2a2a5d32e7f1dc442541245ede728 Mon Sep 17 00:00:00 2001 From: alexm Date: Wed, 15 May 2024 14:43:22 -0400 Subject: [PATCH 5/5] sync --- benchmarks/kernels/benchmark_marlin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 313d11ec0bf65..5dcffc284f3d4 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -122,6 +122,7 @@ def main(args): print(f"[{i}] {model}") results = [] + for model in args.models: for layer in WEIGHT_SHAPES[model]: size_k = layer[0]