
Commit

fix
Giuseppe5 committed Oct 27, 2023
1 parent 58c22f3 commit 9ffc4c9
Showing 2 changed files with 7 additions and 4 deletions.
10 changes: 7 additions & 3 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -204,10 +204,12 @@ def set_export_handler(cls, module):
         _set_proxy_export_handler(cls, module)


-def block_quant_layer_level_manager(export_handlers):
+def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None):

     class BlockQuantLayerLevelManager(BaseManager):
         handlers = export_handlers
+        target_name = '' if target is None else target
+        custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register

         @classmethod
         def set_export_handler(cls, module):
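For readers new to this pattern: block_quant_layer_level_manager is a class factory, so each call builds a fresh manager class whose attributes capture the caller's configuration instead of mutating shared module-level state. A minimal, self-contained sketch of the idea (all names below are illustrative, not Brevitas API):

# Minimal sketch of the class-factory pattern above; names are illustrative.
def make_manager(handlers, target=None, custom_fns_to_register=None):

    class Manager:
        handler_classes = handlers
        target_name = '' if target is None else target
        custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register

    return Manager

# Each call returns an independent class, so configurations never collide.
a = make_manager(handlers=['h1'], target='onnx')
b = make_manager(handlers=['h2'])
assert a.target_name == 'onnx' and b.target_name == ''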
@@ -256,7 +258,6 @@ class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHand
     def __init__(self):
         super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__()
         self.group_size = None
-        register_custom_op_symbolic('::MatMulNBitsFn', MatMulNBitsFn.symbolic, 1)

     def pack_int_weights(self, bit_width, int_weights, zero_point):
         assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input."
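The (u)int8 assertion above exists because sub-byte weights get packed several values per byte before export. The exact layout used by pack_int_weights is not shown in this diff; the following is only a minimal low-nibble-first sketch of the general idea for 4-bit values:

import torch

# Illustrative 4-bit packing: two values per uint8 byte, low nibble first.
# This shows the general idea only, not pack_int_weights' actual layout.
def pack_uint4(int_weights: torch.Tensor) -> torch.Tensor:
    assert int_weights.dtype == torch.uint8, "Packing requires uint8 input."
    flat = int_weights.flatten()
    assert flat.numel() % 2 == 0, "Need an even count to pair nibbles."
    return (flat[0::2] & 0xF) | ((flat[1::2] & 0xF) << 4)

w = torch.tensor([1, 2, 3, 4], dtype=torch.uint8)
assert pack_uint4(w).tolist() == [0x21, 0x43]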
@@ -329,6 +330,9 @@ def symbolic_execution(self, x):
 def export_packed_onnx(model, input, export_path):
     export_class = block_quant_layer_level_manager(
-        export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd])
+        export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd],
+        target='',
+        custom_fns_to_register=MatMulNBitsFn)

     with torch.inference_mode(), brevitas_layer_export_mode(model, export_class):
         torch.onnx.export(model, input, export_path)
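Passing custom_fns_to_register=MatMulNBitsFn here replaces the per-handler-instance register_custom_op_symbolic call deleted in the previous hunk, so the manager can perform the registration once at export setup instead of on every handler construction. A hedged sketch of the underlying torch.onnx API, with an illustrative symbolic body (the com.microsoft::MatMulNBits node and its input list are assumptions for the sketch, not taken from this commit):

import torch
from torch.onnx import register_custom_op_symbolic

# Illustrative symbolic: maps the autograd Function to a custom-domain node.
def matmulnbits_symbolic(g, x, packed_weight, scales, zero_points):
    # The node name and inputs are assumptions, not the real op schema.
    return g.op('com.microsoft::MatMulNBits', x, packed_weight, scales, zero_points)

# Register once, matching the call removed from __init__ above.
register_custom_op_symbolic('::MatMulNBitsFn', matmulnbits_symbolic, 1)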
1 change: 0 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -321,7 +321,6 @@ def main():
         seqlen=args.seqlen)
     # Tie back first/last layer weights in case they got untied
     model.tie_weights()
-    print(model)
     print("Model quantization applied.")

     if args.act_calibration:
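For readers unfamiliar with the tie_weights() call kept above: it is the standard Hugging Face PreTrainedModel method for re-linking the input embedding and the output head when the config ties them. A hedged usage sketch (the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM with tied embeddings behaves the same.
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.tie_weights()  # re-ties embedding and LM head; a no-op if already tied
assert model.get_input_embeddings().weight is model.get_output_embeddings().weight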
