-
Notifications
You must be signed in to change notification settings - Fork 215
/
gemm.py
296 lines (251 loc) · 9.59 KB
/
gemm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import torch
import warnings
import torch.nn as nn
from torch.autograd import Function
from awq.utils.module import try_import
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm
# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed.
awq_ext, msg = try_import("awq_ext")

# Flipped to True after the "naive implementation" warning has been emitted,
# so that the warning is only shown once per process.
user_has_been_warned = False

# TRITON_AVAILABLE must be bound on every path: it is read unconditionally by
# the forward/backward code below. Previously it stayed undefined when the
# triton kernels imported fine but no CUDA/ROCm device was present, which
# surfaced later as a NameError.
try:
    from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton
except ImportError:
    TRITON_AVAILABLE = False
else:
    # covers both CUDA and ROCm
    TRITON_AVAILABLE = torch.cuda.is_available()
# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
    """Autograd function multiplying activations by AWQ-packed weights.

    The forward pass picks the first available backend in this order:
    the ``awq_ext`` extension, the triton kernels, and finally a pure
    PyTorch dequantize-then-matmul fallback. The backward pass only
    produces a gradient for the activations; it needs ``awq_ext`` or
    triton to dequantize the weights.
    """

    @staticmethod
    # ctx is the first argument to forward
    def forward(
        ctx,
        x,
        qweight,
        qzeros,
        scales,
        w_bit=4,
        group_size=128,
        bias=None,
        out_features=0,
    ):
        """Compute ``x @ dequant(qweight) (+ bias)``.

        Args:
            ctx: autograd context; the tensors needed by ``backward`` are
                stored on it.
            x: input activations; cast to ``torch.float16`` before use.
            qweight, qzeros, scales: packed weights, packed zero points and
                scales, passed through to the selected backend.
            w_bit: bit width, only used by the naive fallback.
            group_size: quantization group size, only used by the naive
                fallback.
            bias: optional tensor added to the product.
            out_features: size of the last dimension of the result.

        Returns:
            The result reshaped to ``x.shape[:-1] + (out_features,)``; a 2D
            result gets a leading dimension of size 1 added.
        """
        # Without this declaration the assignment further down would make the
        # name local and the `if not user_has_been_warned` read would raise
        # UnboundLocalError the first time the naive fallback is taken.
        global user_has_been_warned

        # The forward pass can use ctx.
        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
        ctx.out_features = out_features

        out_shape = x.shape[:-1] + (out_features,)
        x = x.to(torch.float16)

        if awq_ext is not None:
            # Above the size threshold: dequantize once and use a plain fp16
            # matmul; below it: call the fused quantized GEMM kernel.
            use_fp16_matmul = x.shape[0] * x.shape[1] >= 1024
            if use_fp16_matmul:
                out = awq_ext.dequantize_weights_cuda(
                    qweight, scales, qzeros, 0, 0, 0, False
                )
                out = torch.matmul(x, out)
            else:
                out = awq_ext.gemm_forward_cuda(
                    x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
                )
        elif TRITON_AVAILABLE:
            use_fp16_matmul = x.shape[0] * x.shape[1] >= 1024
            if use_fp16_matmul:
                out = awq_dequantize_triton(qweight, scales, qzeros)
                out = torch.matmul(x, out)
            else:
                out = awq_gemm_triton(
                    x.reshape(-1, x.shape[-1]),
                    qweight,
                    scales,
                    qzeros,
                    split_k_iters=8,
                )
        else:
            if not user_has_been_warned:
                warnings.warn("Using naive (slow) implementation." + msg)
                user_has_been_warned = True
            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
            out = torch.matmul(x, out)

        if bias is not None:
            out = out + bias
        out = out.reshape(out_shape)

        # always want 3D tensor if tensor is 2D
        if len(out.shape) == 2:
            out = out.unsqueeze(0)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to the activations.

        Returns:
            An 8-tuple matching the non-``ctx`` arguments of ``forward``.
            Only the first entry (gradient w.r.t. ``x``) can be a tensor, and
            it is ``None`` when autograd does not need that gradient.

        Raises:
            ValueError: if neither ``awq_ext`` nor triton is available.
        """
        _x, qweight, qzeros, scales, _bias = ctx.saved_tensors

        if awq_ext is None and not TRITON_AVAILABLE:
            raise ValueError(
                "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
                " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
            )

        # Bound up front so the return below is valid even when no input
        # gradient is requested (previously an UnboundLocalError).
        grad_input = None

        if ctx.needs_input_grad[0]:
            # Cast to correct dtype for mixed precision training
            if awq_ext is not None:
                weights = awq_ext.dequantize_weights_cuda(
                    qweight, scales, qzeros, 1, 0, 0, False
                ).to(grad_output.dtype)
            else:
                weights = awq_dequantize_triton(
                    qweight, scales, qzeros
                ).to(grad_output.dtype)

            # matmul broadcasts the 2D transposed weight over the leading
            # batch dimension(s) of grad_output, so there is no need to
            # materialise one copy of the weight per batch element.
            grad_input = grad_output.matmul(weights.transpose(0, 1))

        return grad_input, None, None, None, None, None, None, None
class WQLinear_GEMM(nn.Module):
    """Linear layer holding 4-bit AWQ weights packed into int32 (GEMM layout).

    Buffers registered by ``__init__``:
        qweight: int32, shape ``(in_features, out_features // pack_num)``
        qzeros:  int32, shape ``(in_features // group_size, out_features // pack_num)``
        scales:  float16, shape ``(in_features // group_size, out_features)``
        bias:    float16, shape ``(out_features,)`` or ``None``
    where ``pack_num = 32 // w_bit``.
    """

    # Position (within a group of 8 columns) of the value stored in each
    # successive 4-bit slot of a packed int32, lowest slot first.
    _ORDER_MAP_4BIT = (0, 2, 4, 6, 1, 3, 5, 7)

    def __init__(
        self, w_bit, group_size, in_features, out_features, bias, dev, training=False
    ):
        """Allocate zero-filled quantized buffers on ``dev``.

        Args:
            w_bit: bit width; only 4 is accepted.
            group_size: quantization group size; ``-1`` means one group
                spanning all of ``in_features``.
            in_features: input dimension; must be divisible by the group size.
            out_features: output dimension; must be divisible by ``32 // w_bit``.
            bias: truthy to allocate a bias buffer, falsy for no bias.
            dev: device for the buffers.
            training: initial value of ``self.training``.

        Raises:
            NotImplementedError: if ``w_bit`` is not 4.
        """
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.training = training

        pack_num = 32 // self.w_bit

        # quick sanity check (make sure alignment)
        assert (
            self.in_features % self.group_size == 0
        ), "in_features must be divisible by group_size"
        assert (
            out_features % pack_num == 0
        ), "out_features must be divisible by 32 // w_bit"

        num_groups = in_features // self.group_size
        packed_out = out_features // pack_num

        self.register_buffer(
            "qweight",
            torch.zeros((in_features, packed_out), dtype=torch.int32, device=dev),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros((num_groups, packed_out), dtype=torch.int32, device=dev),
        )
        self.register_buffer(
            "scales",
            torch.zeros((num_groups, out_features), dtype=torch.float16, device=dev),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros((out_features,), dtype=torch.float16, device=dev),
            )
        else:
            self.bias = None

    @staticmethod
    def _pack_columns(values, w_bit):
        """Pack groups of ``32 // w_bit`` int32 columns into single int32 columns.

        Slot ``i`` of packed column ``c`` (bits ``i * w_bit`` and up) receives
        ``values[:, c * pack_num + order_map[i]]``. The result has
        ``values.shape[1] // pack_num`` columns and lives on ``values.device``.
        """
        if w_bit != 4:
            raise NotImplementedError("Only 4-bit are supported for now.")

        pack_num = 32 // w_bit
        order_map = WQLinear_GEMM._ORDER_MAP_4BIT

        # Width is derived from pack_num so that it matches the buffers made
        # in __init__ for every width divisible by pack_num. The previous
        # `n // 32 * w_bit` form truncated for widths not a multiple of 32.
        packed_cols = values.shape[1] // pack_num
        packed = torch.zeros(
            (values.shape[0], packed_cols),
            dtype=torch.int32,
            device=values.device,
        )
        for col in range(packed_cols):
            base = col * pack_num
            for slot, offset in enumerate(order_map):
                packed[:, col] |= values[:, base + offset] << (slot * w_bit)
        return packed

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        """Build a quantized layer from ``linear``.

        Args:
            linear: source layer; its ``in_features``, ``out_features``,
                ``weight`` and ``bias`` attributes are read.
            w_bit: bit width forwarded to the constructor.
            group_size: group size forwarded to the constructor.
            init_only: when true, return the freshly constructed layer with
                zero-filled buffers (e.g. to load a state dict into).
            scales: per-group scales, indexed by group along dim 0; required
                unless ``init_only``.
            zeros: per-group zero points, indexed by group along dim 0;
                required unless ``init_only``.

        Returns:
            The new ``WQLinear_GEMM`` instance.
        """
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:  # just prepare for loading sd
            return awq_linear

        # need scales and zeros info for real quantization
        assert (
            scales is not None and zeros is not None
        ), "scales and zeros are required unless init_only=True"

        scale_zeros = zeros * scales

        awq_linear.scales = scales.clone().half()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        # Use the resolved group size: the raw argument may be -1 ("one group
        # for everything"), and `idx // -1` would index groups negatively.
        resolved_group_size = awq_linear.group_size

        quantized_columns = [
            torch.round(
                (linear.weight.data[:, idx] + scale_zeros[idx // resolved_group_size])
                / awq_linear.scales[idx // resolved_group_size]
            ).to(torch.int)[:, None]
            for idx in range(awq_linear.in_features)
        ]
        intweight = torch.cat(quantized_columns, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.to(dtype=torch.int32)

        best_device = get_best_device()
        # Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device
        pack_on_cpu = "mps" in best_device

        if pack_on_cpu:
            intweight = intweight.to("cpu")
        awq_linear.qweight = cls._pack_columns(intweight, awq_linear.w_bit)

        zeros = zeros.to(dtype=torch.int32, device=best_device)
        if pack_on_cpu:
            zeros = zeros.to("cpu")
        awq_linear.qzeros = cls._pack_columns(zeros, awq_linear.w_bit)

        return awq_linear

    def forward(self, x):
        """Apply the quantized linear transform to ``x``.

        The input is converted to float16 for the computation and the result
        is converted back to the input dtype. Gradients are only tracked when
        ``self.training`` is true; otherwise the call runs under
        ``torch.no_grad()``.

        Returns:
            Tensor of shape ``x.shape[:-1] + (out_features,)``.
        """
        out_shape = x.shape[:-1] + (self.out_features,)

        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()

        call_args = (
            x,
            self.qweight,
            self.qzeros,
            self.scales,
            self.w_bit,
            self.group_size,
            self.bias,
            self.out_features,
        )
        if self.training:
            out = WQLinearMMFunction.apply(*call_args)
        else:
            with torch.no_grad():
                out = WQLinearMMFunction.apply(*call_args)

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        """Return the layer's configuration for the module repr."""
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )