diff --git a/auto_gptq/nn_modules/qlinear/qlinear_hpu.py b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
index 0c998da2..56dc5402 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_hpu.py
@@ -102,14 +102,96 @@ def post_init(self):
         self._preprocessing()
 
     def pack(self, linear, scales, zeros, g_idx):
-        #TODO: implement
-        raise NotImplementedError("QuantLinear HPU currently doesn't support packing")
-
-    def set_packed(self, qlinear_cls):
-        self.qweight = qlinear_cls.qweight
-        self.qzeros = qlinear_cls.qzeros
-        self.scales = qlinear_cls.scales
-        self.bias = qlinear_cls.bias
+        W = linear.weight.data.clone()
+        if isinstance(linear, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(linear, transformers.pytorch_utils.Conv1D):
+            W = W.t()
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().to(dtype=linear.weight.dtype)
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)
+
+        intweight = []
+        for idx in range(self.infeatures):
+            g_idx = idx // self.group_size
+            intweight.append(torch.round((W[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+
+        i = 0
+        row = 0
+        qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i))
+                i += 10
+                qweight[row] |= intweight[i] << 30
+                row += 1
+                qweight[row] |= (intweight[i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
+                i += 10
+                qweight[row] |= intweight[i] << 31
+                row += 1
+                qweight[row] |= (intweight[i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
+                i += 10
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
 
     def forward(self, x):
         x_dtype = x.dtype
diff --git a/tests/test_hpu_linear.py b/tests/test_hpu_linear.py
index 0f141d01..4ceb020c 100644
--- a/tests/test_hpu_linear.py
+++ b/tests/test_hpu_linear.py
@@ -156,16 +156,14 @@ def test_qlinear_hpu(bits, group_size, infeatures, outfeatures, bias, scales_val
 
     zeros = torch.full((infeatures // group_size, outfeatures), 1, dtype=torch.int32)
     htcore.mark_step()
+    quant_hpu.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
+    htcore.mark_step()
+    quant_hpu.to("hpu")
 
     quant_ref_cuda_old.pack(linear, s.clone().detach().T, zeros.clone().detach().T, g_idx=None)
     htcore.mark_step()
     quant_ref_cuda_old.to("hpu")
 
-    #TODO: pack independently
-    quant_hpu.set_packed(quant_ref_cuda_old)
-    htcore.mark_step()
-    quant_hpu.to("hpu")
-
     out_ref_cuda_old = quant_ref_cuda_old(input)
     htcore.mark_step()
     quant_hpu.post_init()