AWQ Triton kernels. Make autoawq-kernels optional. (#608)
casper-hansen authored Sep 12, 2024
1 parent 8d903b2 commit ae77736
Showing 15 changed files with 484 additions and 232 deletions.
40 changes: 9 additions & 31 deletions README.md
@@ -46,41 +46,19 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
- Your NVIDIA GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- AMD:
- Your ROCm version must be ROCm 5.6 or later.
- Your ROCm version must be compatible with Triton.

### Install from PyPi

To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed.
There are a few ways to install AutoAWQ:

```
pip install autoawq
```

### Build from source

For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest):

```
pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl
```

Or from the main branch directly:

```
pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git
```

Or by cloning the repository and installing from source:

```
git clone https://github.com/casper-hansen/AutoAWQ
cd AutoAWQ
pip install -e .
```

All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases).

If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
1. Default:
- `pip install autoawq`
- NOTE: The default installation includes no external kernels and relies on Triton for inference.

2. From main branch with kernels:
- `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git`
- NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels

## Usage

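For orientation, here is a minimal load-and-generate sketch against the `from_quantized` path touched in the diff below. The checkpoint id and generation settings are placeholders, not part of this commit: with a plain `pip install autoawq`, inference runs on the Triton kernels, while `INSTALL_KERNELS=1` installs and prefers the `awq_ext` CUDA kernels.

```
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Placeholder checkpoint; any AWQ-quantized model id or local path works here.
model_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# fuse_layers is skipped (with a warning) when the awq_ext extension is absent.
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

tokens = tokenizer("What does AWQ quantize?", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
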
8 changes: 7 additions & 1 deletion awq/models/base.py
@@ -1,6 +1,7 @@
import os
import gc
import json
import warnings
import logging
import torch
import transformers
@@ -30,6 +31,7 @@
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
try_import,
)
from awq.utils.utils import get_best_device, qbits_available
from transformers import (
@@ -530,8 +532,12 @@ def from_quantized(
)

# Dispatch to devices
awq_ext, msg = try_import("awq_ext")
if fuse_layers:
self.fuse_layers(model)
if awq_ext is None:
warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg)
else:
self.fuse_layers(model)

if use_cpu_qbits:
dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
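The `try_import` helper used above (and in the linear modules below) lives in `awq/utils/module.py`, which is not expanded in this view. A minimal sketch of what such a helper could look like, assuming it returns the imported module (or `None`) plus a message that call sites append to their warnings and errors:

```
import importlib


def try_import(module_name: str):
    """Attempt to import an optional kernel extension.

    Returns (module, "") on success, or (None, explanation) on failure, so
    callers can choose to warn (e.g. skip layer fusing) or raise
    (e.g. a kernel-only forward pass).
    """
    try:
        module = importlib.import_module(module_name)
        return module, ""
    except Exception as ex:
        # Leading space because call sites concatenate: warnings.warn("..." + msg)
        return None, f" Details: failed to import {module_name} ({ex})."
```
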
21 changes: 4 additions & 17 deletions awq/modules/linear/exllama.py
@@ -1,15 +1,9 @@
import torch
import warnings
import torch.nn as nn
from awq.utils.module import try_import
from awq.utils.packing_utils import unpack_reorder_pack

try:
import exl_ext # with CUDA kernels (AutoAWQ_kernels)

EXL_INSTALLED = True
except Exception as ex:
EXL_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlama kernels extension. Details: {ex}")
exl_ext, msg = try_import("exl_ext")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -106,15 +100,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
assert EXL_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)

assert EXL_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
if exl_ext is None:
raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
15 changes: 4 additions & 11 deletions awq/modules/linear/exllamav2.py
@@ -2,15 +2,10 @@
import warnings
import torch.nn as nn
from typing import Dict
from awq.utils.module import try_import
from awq.utils.packing_utils import unpack_reorder_pack

try:
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)

EXLV2_INSTALLED = True
except Exception as ex:
EXLV2_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlamaV2 kernels extension. Details: {ex}")
exlv2_ext, msg = try_import("exlv2_ext")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -133,10 +128,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
assert EXLV2_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
if exlv2_ext is None:
raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
52 changes: 40 additions & 12 deletions awq/modules/linear/gemm.py
@@ -2,16 +2,24 @@
import warnings
import torch.nn as nn
from torch.autograd import Function
from awq.utils.module import try_import
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed.

awq_ext, msg = try_import("awq_ext")
user_has_been_warned = False

try:
import awq_ext # with CUDA kernels (AutoAWQ_kernels)
from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMM kernels extension. Details: {ex}")
# covers both CUDA and ROCm
if torch.cuda.is_available():
TRITON_AVAILABLE = True

except ImportError:
TRITON_AVAILABLE = False

# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@@ -35,7 +43,7 @@ def forward(
out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16)

if AWQ_INSTALLED:
if awq_ext is not None:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
@@ -47,7 +55,22 @@ def forward(
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
)

elif TRITON_AVAILABLE:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_dequantize_triton(qweight, scales, qzeros)
out = torch.matmul(x, out)
else:
out = awq_gemm_triton(
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8,
)

else:
if not user_has_been_warned:
warnings.warn("Using naive (slow) implementation." + msg)
user_has_been_warned = True
out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
out = torch.matmul(x, out)

@@ -64,16 +87,21 @@ def forward(
def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors

if not AWQ_INSTALLED:
if awq_ext is None and not TRITON_AVAILABLE:
raise ValueError(
"auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
"either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
" by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
)

# Cast to correct dtype for mixed precision training
weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False
).to(grad_output.dtype)
if awq_ext is not None:
weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False
).to(grad_output.dtype)
else:
weights = awq_dequantize_triton(
qweight, scales, qzeros
).to(grad_output.dtype)

if ctx.needs_input_grad[0]:
# 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
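A hedged sketch of exercising the two Triton paths that `WQLinearMMFunction.forward` now chooses between: the fused `awq_gemm_triton` kernel when `batch * seq_len` is below the `FP16_MATMUL_HEURISTIC_CONDITION` threshold of 1024, and `awq_dequantize_triton` followed by a plain fp16 matmul above it. The shapes assume the packed 4-bit GEMM layout (eight 4-bit values per int32); the tensor contents are random placeholders, so the outputs are meaningless and this only illustrates the call pattern.

```
import torch
from awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton

in_features, out_features, group_size = 4096, 4096, 128
device = "cuda"

# Assumed packed layout: 8 x 4-bit weights per int32.
qweight = torch.randint(0, 2**31 - 1, (in_features, out_features // 8), dtype=torch.int32, device=device)
qzeros = torch.randint(0, 2**31 - 1, (in_features // group_size, out_features // 8), dtype=torch.int32, device=device)
scales = torch.randn(in_features // group_size, out_features, dtype=torch.float16, device=device)

# 1 * 32 = 32 tokens < 1024: fused Triton GEMM kernel.
x_small = torch.randn(1, 32, in_features, dtype=torch.float16, device=device)
out_small = awq_gemm_triton(x_small.reshape(-1, in_features), qweight, scales, qzeros, split_k_iters=8)

# 2 * 1024 = 2048 tokens >= 1024: dequantize once, then a regular fp16 matmul.
x_large = torch.randn(2, 1024, in_features, dtype=torch.float16, device=device)
weight = awq_dequantize_triton(qweight, scales, qzeros)
out_large = torch.matmul(x_large, weight)
```
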
24 changes: 12 additions & 12 deletions awq/modules/linear/gemm_qbits.py
@@ -1,13 +1,11 @@
import torch
import torch.nn as nn
from awq.utils.module import try_import
from ...utils.packing_utils import reverse_awq_order, unpack_awq

try:
from intel_extension_for_transformers import qbits # with QBits kernels ()

QBITS_INSTALLED = True
except:
QBITS_INSTALLED = False
intel_extension_for_transformers, msg = try_import("intel_extension_for_transformers")
if intel_extension_for_transformers is not None:
qbits = getattr(intel_extension_for_transformers, 'qbits')

BITS_DTYPE_MAPPING = {
4: "int4_clip",
@@ -34,8 +32,8 @@ class WQLinear_QBits(nn.Module):

def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
super().__init__()
assert QBITS_INSTALLED, \
"Please install ITREX qbits package with `pip install intel-extension-for-transformers`."
if intel_extension_for_transformers is None:
raise ModuleNotFoundError("Please install ITREX qbits package with `pip install intel-extension-for-transformers`." + msg)

self.use_bf16 = qbits.check_isa_supported("AMX")

@@ -118,10 +116,12 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze

@torch.no_grad()
def forward(self, x):
assert QBITS_INSTALLED, (
"QBits kernels could not be loaded. "
"Please install with `pip install intel-extension-for-transformers` and "
"refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md")
if intel_extension_for_transformers is None:
raise ModuleNotFoundError(
"QBits kernels could not be loaded. "
"Please install with `pip install intel-extension-for-transformers` and "
"refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md"
)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
15 changes: 4 additions & 11 deletions awq/modules/linear/gemv.py
@@ -1,14 +1,9 @@
import torch
import warnings
import torch.nn as nn
from awq.utils.module import try_import

try:
import awq_ext # with CUDA kernels

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMV kernels extension. Details: {ex}")
awq_ext, msg = try_import("awq_ext")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
@@ -160,10 +155,8 @@ def from_linear(

@torch.no_grad()
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
if awq_ext is None:
raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg)

out_shape = x.shape[:-1] + (self.out_features,)
inputs = x.reshape(-1, x.shape[-1])
11 changes: 4 additions & 7 deletions awq/modules/linear/gemv_fast.py
@@ -1,13 +1,8 @@
import torch
import warnings
from awq.utils.module import try_import

try:
import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels)

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMVFast kernels extension. Details: {ex}")
awq_v2_ext, msg = try_import("awq_v2_ext")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
@@ -189,6 +184,8 @@ def from_linear(

@torch.no_grad()
def forward(self, x):
if awq_v2_ext is None:
raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg)
inputs = x
batch_size, n_tokens, _ = inputs.shape
if batch_size < 8 and n_tokens == 1:
15 changes: 4 additions & 11 deletions awq/modules/linear/marlin.py
@@ -1,14 +1,9 @@
import torch
import torch.nn as nn
import numpy as np
from awq.utils.module import try_import

try:
import marlin_cuda # with CUDA kernels (AutoAWQ_kernels)

MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False

marlin_cuda, msg = try_import("marlin_cuda")

def _get_perms():
perm = []
@@ -179,10 +174,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
if marlin_cuda is None:
raise ModuleNotFoundError("External Marlin kernels are not properly installed." + msg)

out_shape = x.shape[:-1] + (self.out_features,)

Empty file added awq/modules/triton/__init__.py
Empty file.
