From 220aae64e3371a721bce471a9eb966304736794e Mon Sep 17 00:00:00 2001
From: yumin
Date: Wed, 16 Oct 2024 10:25:18 +0800
Subject: [PATCH] Temporary fix for QAT when linear layer bias is True

Signed-off-by: yumin
---
 torchao/quantization/GPTQ.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 23c87141c7..bf1116f811 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -617,7 +617,7 @@ def _replace_linear_int4(
     copy_weights: bool = False,
 ):
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
+        if isinstance(child, nn.Linear) and child.bias is None and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 new_linear = linear_class(
                     child.in_features,
@@ -979,7 +979,7 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
     def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
-        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        return isinstance(child, nn.Linear) and child.bias is None and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
 
     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(