Skip to content

Commit

Permalink
Fix RTN supported layer checking condition (#1705)
Browse files Browse the repository at this point in the history
Signed-off-by: Kaihui-intel <[email protected]>
  • Loading branch information
Kaihui-intel authored Apr 2, 2024
1 parent 14868c0 commit 0791776
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/torch_utils/weight_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def rtn_quantize(
model: fake quantized torch module
"""
assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = ["Linear"]
supported_layers = (torch.nn.Linear,)
if return_int:
compression_dtype = kwargs.get("compression_dtype", torch.int32)
compression_dim = kwargs.get("compression_dim", 1)
Expand All @@ -412,7 +412,7 @@ def rtn_quantize(
use_optimum_format = kwargs.get("use_optimum_format", True)
with torch.no_grad():
for name, m in model.named_modules():
if m.__class__.__name__ not in supported_layers:
if not isinstance(m, supported_layers):
continue
orig_dtype = next(m.parameters()).dtype
if orig_dtype != torch.float:
Expand Down

0 comments on commit 0791776

Please sign in to comment.