Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporary fix for QAT quantizer when linear layer bias is True #1087

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from torchao.quantization.prototype.qat.linear import (
FakeQuantizedLinear,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
Expand Down Expand Up @@ -63,6 +65,10 @@
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.GPTQ import (
_replace_linear_8da4w,
_replace_linear_int4
)

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
Expand Down Expand Up @@ -851,7 +857,49 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
fq_out = fq_linear(x)
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_replace_linear_8da4w(self):
module = torch.nn.ModuleList([
torch.nn.Linear(in_features=256, out_features=50, bias=True)
])
_replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear))
module = torch.nn.ModuleList([
torch.nn.Linear(in_features=256, out_features=50, bias=False)
])
_replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True)
assert(isinstance(module[0], Int8DynActInt4WeightQATLinear))

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_replace_linear_int4(self):
module = torch.nn.ModuleList([
torch.nn.Linear(in_features=256, out_features=50, bias=True)
])
_replace_linear_int4(
module,
256,
8,
padding_allowed=True,
precision=torch.bfloat16,
scales_precision=torch.bfloat16,
linear_class=Int4WeightOnlyQATLinear,
copy_weights=True)
assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear))
module = torch.nn.ModuleList([
torch.nn.Linear(in_features=256, out_features=50, bias=False)
])
_replace_linear_int4(
module,
256,
8,
padding_allowed=True,
precision=torch.bfloat16,
scales_precision=torch.bfloat16,
linear_class=Int4WeightOnlyQATLinear,
copy_weights=True)
assert(isinstance(module[0], Int4WeightOnlyQATLinear))


if __name__ == "__main__":
unittest.main()
unittest.main()
6 changes: 4 additions & 2 deletions torchao/quantization/GPTQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ def _replace_linear_int4(
copy_weights: bool = False,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
# TODO: support linear bias
if isinstance(child, nn.Linear) and child.bias is None and (skip_layer_func is None or not skip_layer_func(child.weight)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also add a TODO: support linear bias (here and L982)

if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
new_linear = linear_class(
child.in_features,
Expand Down Expand Up @@ -979,7 +980,8 @@ def _replace_linear_8da4w(
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
# TODO: support linear bias
return isinstance(child, nn.Linear) and child.bias is None and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = linear_class(
Expand Down
Loading