From 58c22f356b1255b3456e1d07163f93125e75afb0 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 24 Oct 2023 13:38:47 +0100
Subject: [PATCH] Small cleanup

---
 src/brevitas_examples/llm/llm_quant/export.py | 101 ++++++------------
 1 file changed, 32 insertions(+), 69 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 068e75266..44d17b77e 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -259,83 +259,46 @@ def __init__(self):
         register_custom_op_symbolic('::MatMulNBitsFn', MatMulNBitsFn.symbolic, 1)
 
     def pack_int_weights(self, bit_width, int_weights, zero_point):
-        assert int_weights.dtype in [torch.uint8], "Packing requires (u)int8 input."
-        zero_point = zero_point.to(torch.uint8).flatten()
+        assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
+        assert bit_width == 4, "Only 4 bit quantization export is supported at the moment"
+
+        is_symmetric = torch.sum(zero_point) == 0
+        zero_point = zero_point.to(torch.uint8)
         rows, cols = int_weights.shape
         block_size = self.group_size
         blob_size = block_size // 2
         k_blocks = (rows + block_size - 1) // block_size
         padded_rows = k_blocks * block_size
         pad_len = padded_rows - rows
+
+        # The ONNX MatMulNBits operator assumes an implicit zero point of 8 (the offset of the
+        # most negative signed 4-bit value). In the "symmetric" quantized case we add it here;
+        # otherwise it has already been added during the conversion to integer.
+        zp = 8 if int_weights.dtype == torch.int8 else 0
+        int_weights += zp
         if pad_len > 0:
             int_weights = torch.nn.functional.pad(int_weights, (0, 0, 0, pad_len))
-        if bit_width == 8:
-            return int_weights
-        elif bit_width == 4 or bit_width == 2:
-            packed_int_weights = torch.zeros((k_blocks * blob_size, cols),
-                                             device=int_weights.device,
-                                             dtype=torch.uint8)
-            packed_zp = torch.zeros((zero_point.shape[0] + 1) // 2,
-                                    device=int_weights.device,
-                                    dtype=torch.uint8)
-            i = 0
-            for column in range(packed_int_weights.shape[0]):
-                # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
-                # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
-                for j in range(i, i + (8 // bit_width)):
-                    shift_factor = (bit_width * (j - i))
-                    packed_int_weights[column, :] |= int_weights[j, :] << shift_factor
-                i += 8 // bit_width
-            packed_int_weights = packed_int_weights.t()
-            packed_int_weights = packed_int_weights.reshape(-1, k_blocks, blob_size)
-            i = 0
-            for column in range(packed_zp.shape[0]):
-                # Compared to the reference below we don't transpose the matrix and we pack into 8b data rather than 32b
-                # https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/05781593c818d4dc8adc2d32c975e83d17d2b9a8/quant/quant_linear.py#L346
-                for j in range(i, i + (8 // bit_width)):
-                    shift_factor = (bit_width * (j - i))
-                    packed_zp[column] |= zero_point[j] << shift_factor
-                i += 8 // bit_width
-            return packed_int_weights, packed_zp
-        else:
-            raise RuntimeError("Only 4 and 8 bit quantization export is supported at the moment")
-
-        # # pack 3b values into 3 bytes, 5b values into 5 bytes, 6b values into 4 bytes
-        # elif bit_width == 3 or bit_width == 5 or bit_width == 6:
-        #     padding = (int_weights.shape[1] * bit_width) % 8
-        #     if padding > 0:
-        #         warnings.warn(
-        #             f"Weight tensor does not divide by {bit_width}, zero-padding columns by {padding}."
-        #         )
-        #     packed_int_weights = torch.zeros(
-        #         (int_weights.shape[0], (int_weights.shape[1] * bit_width + padding) // 8),
-        #         device=int_weights.device,
-        #         dtype=int_weights.dtype)
-
-        #     def lcm(x, y):
-        #         from fractions import gcd
-        #         return x * y // gcd(x, y)
-
-        #     num_packed_bits = lcm(bit_width, 8)
-        #     num_packed_bytes = num_packed_bits // 8
-        #     num_packed_elems = num_packed_bits // bit_width
-
-        #     i = 0
-        #     for column in range(0, packed_int_weights.shape[1], num_packed_bytes):
-        #         # cast to uint8 since it's the only dtype supported by unpackbits
-        #         # the bit-wise representation of int8 values isn't affected
-        #         bits_to_unpack = int_weights[:, i:i + num_packed_elems].numpy().astype(np.uint8)
-        #         unpacked_bits = np.unpackbits(bits_to_unpack, axis=1)
-        #         unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1, 8)
-        #         unpacked_bits = unpacked_bits[:, :, -bit_width:]
-        #         unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1)
-        #         packed_bits = np.packbits(unpacked_bits, axis=1)
-        #         packed_int_weights[:, column:column +
-        #                            num_packed_bytes] |= torch.from_numpy(packed_bits)
-        #         i += num_packed_elems
-        #     return packed_int_weights
-        # else:
-        #     raise ValueError(f"Bit width {bit_width} not supported.")
+        packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
+        rows, cols = int_weights.shape
+        int_weights = int_weights.t()
+        for n in range(cols):
+            for k_id in range(0, rows, block_size):
+                blk_int0 = (int_weights[n, k_id:k_id + block_size:2].numpy()).astype("uint8")
+                blk_int1 = (int_weights[n, k_id + 1:k_id + block_size:2].numpy()).astype("uint8")
+                packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))
+
+        zero_point = zero_point.to(torch.uint8).flatten()
+        base_zp = 136 if is_symmetric else 0  # 0x88: implicit zero point of 8 in both nibbles
+        packed_zp = base_zp * torch.ones(
+            (zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)
+
+        i = 0
+        for column in range(packed_zp.shape[0]):
+            for j in range(i, i + (8 // bit_width)):
+                shift_factor = (bit_width * (j - i))
+                packed_zp[column] |= zero_point[j] << shift_factor
+            i += 8 // bit_width
+        return torch.tensor(packed), packed_zp
 
     def prepare_for_export(self, module):
         self.bit_width = self.bit_width_impl(module.weight_quant)()
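
Note for reviewers: a minimal standalone sketch of the nibble-packing scheme used in the new pack_int_weights path, assuming 4-bit weights and ignoring padding. The helper name pack_4bit_example and the toy tensor below are illustrative only, not part of the patch or the Brevitas API. Pairs of consecutive values along K are packed into one byte, even index in the low nibble and odd index in the high nibble, after shifting symmetric int8 weights by the implicit zero point of 8 expected by MatMulNBits.

# Illustrative sketch only; mirrors the bit_width == 4 packing above, padding omitted for brevity.
import numpy as np
import torch


def pack_4bit_example(int_weights: torch.Tensor, block_size: int = 4) -> np.ndarray:
    # Shift signed values into the unsigned range expected by MatMulNBits
    # (implicit zero point of 8 when the input is symmetric int8).
    if int_weights.dtype == torch.int8:
        int_weights = int_weights + 8
    rows, cols = int_weights.shape
    k_blocks = (rows + block_size - 1) // block_size
    blob_size = block_size // 2
    packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
    w = int_weights.t()
    for n in range(cols):
        for k_id in range(0, rows, block_size):
            lo = w[n, k_id:k_id + block_size:2].numpy().astype("uint8")  # even K -> low nibble
            hi = w[n, k_id + 1:k_id + block_size:2].numpy().astype("uint8")  # odd K -> high nibble
            packed[n, k_id // block_size] = np.bitwise_or(lo, np.left_shift(hi, 4))
    return packed


# A single 4x1 column of symmetric int8 values in [-8, 7]:
w = torch.tensor([[-8], [7], [0], [3]], dtype=torch.int8)
print(pack_4bit_example(w))  # [[[240 184]]] -> 0xF0 = (7+8)<<4 | (-8+8), 0xB8 = (3+8)<<4 | (0+8)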