diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index 8a1a683483f..ed36d257f2a 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -403,7 +403,7 @@ def rtn_quantize( model: fake quantized torch module """ assert isinstance(model, torch.nn.Module), "only support torch module" - supported_layers = ["Linear"] + supported_layers = (torch.nn.Linear,) if return_int: compression_dtype = kwargs.get("compression_dtype", torch.int32) compression_dim = kwargs.get("compression_dim", 1) @@ -412,7 +412,7 @@ def rtn_quantize( use_optimum_format = kwargs.get("use_optimum_format", True) with torch.no_grad(): for name, m in model.named_modules(): - if m.__class__.__name__ not in supported_layers: + if not isinstance(m, supported_layers): continue orig_dtype = next(m.parameters()).dtype if orig_dtype != torch.float: