Temporary fix for QAT when linear layer bias is True
Signed-off-by: yumin <[email protected]>
yumin committed Oct 16, 2024
1 parent 6ea36c5 commit 220aae6
Showing 1 changed file with 2 additions and 2 deletions.
torchao/quantization/GPTQ.py
@@ -617,7 +617,7 @@ def _replace_linear_int4(
     copy_weights: bool = False,
 ):
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
+        if isinstance(child, nn.Linear) and child.bias is None and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 new_linear = linear_class(
                     child.in_features,
@@ -979,7 +979,7 @@ def _replace_linear_8da4w(
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

     def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
-        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+        return isinstance(child, nn.Linear) and child.bias is None and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)

     def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
         new_linear = linear_class(
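Why the guard helps: the int4 replacement linear classes used here do not support a bias term at the time of this commit, so swapping one in for an nn.Linear constructed with bias=True would break or silently drop the bias. The temporary fix instead leaves such layers unreplaced. A minimal sketch of the new eligibility test (the helper name _is_swappable is hypothetical, not from the commit):

    import torch.nn as nn

    def _is_swappable(child: nn.Module) -> bool:
        # Mirrors the guard added in this commit: an nn.Linear has
        # bias set to None only when constructed with bias=False.
        return isinstance(child, nn.Linear) and child.bias is None

    model = nn.Sequential(
        nn.Linear(256, 256, bias=False),  # eligible for the int4 swap
        nn.Linear(256, 256, bias=True),   # now skipped, bias preserved
    )
    print([_is_swappable(m) for m in model])  # [True, False]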
