Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Add Exllama as a backend for compressed-tensors #9395

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []


def get_default_cache_root():
Expand Down Expand Up @@ -430,6 +431,14 @@ def get_default_config_root():
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
) == "1",

# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
}

# end-env-vars-definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self,
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
if c.zero_points:
assert w_zp_param_name is not None
if c.has_g_idx:
assert w_gidx_param_name is not None
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name

Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/quantization/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.exllama import (
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
Expand All @@ -13,6 +15,7 @@
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]


Expand Down Expand Up @@ -45,8 +48,7 @@ def choose_mp_linear_kernel(

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue
Expand Down
142 changes: 142 additions & 0 deletions vllm/model_executor/layers/quantization/kernels/exllama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
# currently untested so not added to the list

@classmethod
def get_min_capability(cls) -> int:
return 60

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\
"when the input features are partitioned across "\
"devices"

if c.act_type != torch.float16:
return False, "Exllama only supports float16 activations"

if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Quant type ({c.weight_type}) not supported by "\
"Exllama, supported types are: "\
f"{cls.SUPPORTED_QUANT_TYPES}"

if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"

return True, None

# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config

# For Exllama, we need to set a zero-point tensor if there is not one
if not c.zero_points:
self.w_zp_name = "qzeros"
device = getattr(layer, self.w_q_name).device
groups = c.full_weight_shape[0] // c.group_size
out_features = c.partition_weight_shape[1]

if c.weight_type.has_bias():
# if the type has a bias we have to create a zeros tensor that
# contains the bias values repeated for each group (-1 due to
# a bug in the original GPTQ checkpoint format leading to
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros = torch.full((groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device)
else:
raise NotImplementedError(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference")
print("zeros", zeros.shape)
zeros = pack_quantized_values_into_int32(zeros,
c.weight_type,
packed_dim=1)
print("zeros_packed", zeros.shape)
setattr(layer, self.w_zp_name,
torch.nn.Parameter(zeros, requires_grad=False))

if c.has_g_idx:

def transform_w_g_idx(x):
# Exllama wants the permutation array instead of the group
# incdices
LucasWilkinson marked this conversation as resolved.
Show resolved Hide resolved
return torch.argsort(x).to(torch.int)

self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int,
device=device),
requires_grad=False)
setattr(layer, self.w_gidx_name, empty_g_idx)

def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
assert self.w_gidx_name is not None
g_idx = getattr(layer, self.w_gidx_name)

permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x_cont = x.data.contiguous()
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
return x_cont

def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x.to(dtype=c.act_type)

# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config

x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

#print(w_q.shape, w_s.shape, w_zp.shape, w_g_idx.shape)
LucasWilkinson marked this conversation as resolved.
Show resolved Hide resolved

assert w_zp is not None, "Zero points are not supported by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
c.weight_type.size_bits)

if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/kernels/machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)

Expand Down Expand Up @@ -71,13 +71,13 @@ def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_weights_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_weights_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = pack_quantized_values_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
return x
Expand Down
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
}


def pack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
def pack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
Expand All @@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
def unpack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
Expand Down
2 changes: 2 additions & 0 deletions vllm/scalar_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class scalar_types:
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)

Expand Down