Skip to content

Commit

Permalink
Print warning if AutoAWQ cannot load extensions (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Jul 23, 2024
1 parent 1716748 commit 47f64ac
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 8 deletions.
4 changes: 3 additions & 1 deletion awq/modules/linear/exllama.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import torch
import warnings
import torch.nn as nn
from awq.utils.packing_utils import unpack_reorder_pack

try:
import exl_ext # with CUDA kernels (AutoAWQ_kernels)

EXL_INSTALLED = True
except:
except Exception as ex:
EXL_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlama kernels extension. Details: {ex}")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
Expand Down
4 changes: 3 additions & 1 deletion awq/modules/linear/exllamav2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import warnings
import torch.nn as nn
from typing import Dict
from awq.utils.packing_utils import unpack_reorder_pack
Expand All @@ -7,8 +8,9 @@
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)

EXLV2_INSTALLED = True
except:
except Exception as ex:
EXLV2_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlamaV2 kernels extension. Details: {ex}")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
Expand Down
5 changes: 3 additions & 2 deletions awq/modules/linear/gemm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import warnings
import torch.nn as nn
from torch.autograd import Function
from awq.utils.utils import get_best_device
Expand All @@ -8,9 +9,9 @@
import awq_ext # with CUDA kernels (AutoAWQ_kernels)

AWQ_INSTALLED = True
except:
except Exception as ex:
AWQ_INSTALLED = False

warnings.warn(f"AutoAWQ could not load GEMM kernels extension. Details: {ex}")

# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
Expand Down
5 changes: 3 additions & 2 deletions awq/modules/linear/gemv.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch
import warnings
import torch.nn as nn

try:
import awq_ext # with CUDA kernels

AWQ_INSTALLED = True
except:
except Exception as ex:
AWQ_INSTALLED = False

warnings.warn(f"AutoAWQ could not load GEMV kernels extension. Details: {ex}")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
Expand Down
5 changes: 3 additions & 2 deletions awq/modules/linear/gemv_fast.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
import warnings

try:
import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels)

AWQ_INSTALLED = True
except:
except Exception as ex:
AWQ_INSTALLED = False

warnings.warn(f"AutoAWQ could not load GEMVFast kernels extension. Details: {ex}")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
Expand Down

0 comments on commit 47f64ac

Please sign in to comment.