diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py index 9f2e8fc8dce..9783d913070 100644 --- a/neural_compressor/torch/algorithms/weight_only/teq.py +++ b/neural_compressor/torch/algorithms/weight_only/teq.py @@ -16,8 +16,7 @@ # limitations under the License. # -import copy -from typing import Any +from typing import Any, List import torch @@ -36,10 +35,10 @@ class TrainableEquivalentTransformation: """Weight-only quantization, Trainable Equivalent Transformation (TEQ).""" - _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"] + _PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"] _PREPARE_ATTRS_PREFIX = "_prepare_" - def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None): + def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None): """ :param model: the model for quantization :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}. @@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex self.absorb_to_layer = absorb_to_layer self._post_initialized = False + def _detect_absorb_to_layer(self, model, folding, example_inputs): + # If user not provide the layers to absorb the quantization, detect layers automatically + supported_layers = ["Linear"] + detected_absorb_layers = {} + # Detect the layers that can be absorbed automatically + if folding: + from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace + + tg = GraphTrace() + detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers) + else: # pragma: no cover + for name, module in model.named_modules(): + if module.__class__.__name__ in supported_layers: + detected_absorb_layers[name] = [name] + logger.info("Detected **absorb layer**: **absorbed layers**") + logger.info(detected_absorb_layers) + return detected_absorb_layers + def _post_init(self): self.dtype = self._get_dtype() self.model.to(self.device) @@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False): to the paper for more details :param sqrt_w_init: use sqrt weight to init.""" + if not self.absorb_to_layer: + self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs) if not self._post_initialized: self._post_init() # freeze model. @@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False): self.trained_alphas[layer_norm] = alpha for layer_name in self.absorb_to_layer[layer_norm]: - if self.weight_config.get(layer_name) is None: # pragma: no cover + if not self.weight_config.get(layer_name): # pragma: no cover logger.info(f"layer {layer_name} not in weight config, skip.") continue num_bits = self.weight_config[layer_name]["bits"] @@ -117,10 +136,10 @@ def add_tuning_scale(self, sqrt_w_init=False): ) set_module(self.model, layer_name, wrapper_module) - for n, m in self.model.named_modules(): + for layer_name, m in self.model.named_modules(): if isinstance(m, torch.nn.Linear) and "orig_layer" not in n: - if self.weight_config.get(n) is None: # pragma: no cover - logger.info(f"out of absorbed layer {n} not in weight config, skip.") + if not self.weight_config.get(layer_name): # pragma: no cover + logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.") continue num_bits = self.weight_config[layer_name]["bits"] group_size = self.weight_config[layer_name]["group_size"] @@ -131,7 +150,7 @@ def add_tuning_scale(self, sqrt_w_init=False): wrapper_module = TEQLinearFakeQuant( orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme ) - set_module(self.model, n, wrapper_module) + set_module(self.model, layer_name, wrapper_module) # Attach the weight config captured at prepare stage to the model self.model._weight_config = self.weight_config self.model._trained_alphas = self.trained_alphas @@ -190,7 +209,9 @@ def _absorb_scales(self, layer, scale, layer_name=""): scale = scale.view(scale.shape[0], 1) layer.weight *= scale - elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky + elif ( + layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm" + ): # pragma: no cover layer.weight *= scale else: # pragma: no cover @@ -222,7 +243,7 @@ def _scale_layer_weight(self, layer, scale): ##input channel @torch.no_grad() def transform(self): """Apply alpha/scale.""" - if not self._post_initialized: + if not self._post_initialized: # pragma: no cover self._post_init() for ln_name, layer_names in self.absorb_to_layer.items(): module = get_module(self.model, ln_name) @@ -272,7 +293,7 @@ def save(self, save_scale_file="", save_state_dict_file=""): class TEQuantizer(Quantizer): - def __init__(self, quant_config, folding, absorb_to_layer, example_inputs): + def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None): super().__init__(quant_config=quant_config) self.folding = folding self.absorb_to_layer = absorb_to_layer diff --git a/test/3x/torch/algorithms/weight_only/test_teq_quantizer.py b/test/3x/torch/algorithms/weight_only/test_teq_quantizer.py index a27ce5ec0f2..4e06cedb284 100644 --- a/test/3x/torch/algorithms/weight_only/test_teq_quantizer.py +++ b/test/3x/torch/algorithms/weight_only/test_teq_quantizer.py @@ -82,8 +82,21 @@ def setUpClass(self): ) self.gptj.seqlen = 512 - def train_func(self): - pass + def test_teq_detect_absorb_layers(self): + example_inputs = torch.ones([1, 512], dtype=torch.long) + test_input = torch.ones([1, 512], dtype=torch.long) + model = copy.deepcopy(self.gptj) + out0 = model(test_input) + + weight_config = { + # 'op_name': (bit, group_size, scheme) + "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"}, + "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"}, + } + quantizer = TEQuantizer(quant_config=weight_config, folding=True, example_inputs=example_inputs) + model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train) + out1 = model(test_input) + self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03)) def test_teq(self): example_inputs = torch.ones([1, 512], dtype=torch.long)