Commit 42f9ab3
Giuseppe5 committed Dec 4, 2023
1 parent 9ffc4c9 commit 42f9ab3
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion src/brevitas/export/onnx/standard/function.py
@@ -29,10 +29,15 @@ def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size):

@staticmethod
def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size):
dtype = x.dtype
device = x.device
shape = x.shape
out_shape = list(shape)
out_shape[-1] = N
return torch.empty(out_shape)
# Only tensor metadata (shape, dtype, device) is preserved in the forward pass during
# tracing, not the actual values.
out = torch.empty(out_shape, dtype=dtype, device=device)
return out


AXIS_OPSET = 13
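
A minimal sketch, not part of the commit, of how an uninitialised torch.empty behaves under tracing (the input shape and the output width N below are made up for illustration):

import torch

x = torch.randn(2, 3, dtype=torch.float32)
N = 8  # hypothetical output width, standing in for the N argument of forward

out_shape = list(x.shape)
out_shape[-1] = N

# torch.empty allocates without initialising, so the values are arbitrary,
# but shape, dtype and device match what the real kernel would produce,
# which is all the tracer needs for the nodes downstream of this op.
out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
assert out.shape == (2, N) and out.dtype == x.dtype and out.device == x.device
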
6 changes: 5 additions & 1 deletion src/brevitas_examples/llm/llm_quant/export.py
@@ -274,7 +274,8 @@ def pack_int_weights(self, bit_width, int_weights, zero_point):

# The ONNX operator assumes an implicit zero point of 8 (largest negative number in Po2).
# If we are in a "symmetric" quantized scenario, we need to add this implicit zero point.
# Otherwise it has already been added during the conversion to integer.
# This allows the weights to always be packed as unsigned integers.
zp = 0 if not int_weights.dtype == torch.int8 else 8
int_weights += zp
if pad_len > 0:
@@ -289,6 +290,9 @@ def pack_int_weights(self, bit_width, int_weights, zero_point):
packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

zero_point = zero_point.to(torch.uint8).flatten()

# The constant value 136 is derived from the source code of the ORT test suite:
# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py
base_zp = 136 if is_symmetric else 0
packed_zp = base_zp * torch.ones(
(zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8)
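
A short sketch, not part of the commit, tying together the implicit zero point of 8 and the base_zp constant of 136 (the weight values below are made up for illustration):

import numpy as np
import torch

# Symmetric 4-bit weights arrive as signed values in [-8, 7].
w = torch.tensor([-8, -3, 0, 7], dtype=torch.int8)

# Adding the implicit zero point of 8 shifts them into [0, 15], so every
# weight fits in an unsigned nibble and can be packed two per byte.
w_u = (w + 8).to(torch.uint8)
assert int(w_u.min()) >= 0 and int(w_u.max()) <= 15

# Packing mirrors np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)):
# the first value occupies the low nibble, the second the high nibble.
blk_int0, blk_int1 = np.uint8(8), np.uint8(8)
packed = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

# Two implicit zero points of 8 packed into one byte give 0x88 == 136,
# which matches the base_zp constant used above for the symmetric case.
assert packed == 136
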
