Port auto-detect absorb layers for TEQ #1895

Merged · 3 commits · Jul 4, 2024
45 changes: 33 additions & 12 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -16,8 +16,7 @@
# limitations under the License.
#

-import copy
-from typing import Any
+from typing import Any, List

import torch

@@ -36,10 +35,10 @@
class TrainableEquivalentTransformation:
    """Weight-only quantization, Trainable Equivalent Transformation (TEQ)."""

-    _PREPARE_ATTRS: list[str] = ["weight_config", "trained_alphas"]
+    _PREPARE_ATTRS: List[str] = ["weight_config", "trained_alphas"]
    _PREPARE_ATTRS_PREFIX = "_prepare_"

-    def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, example_inputs=None):
+    def __init__(self, model, weight_config={}, absorb_to_layer=None, folding=True, example_inputs=None):
        """
        :param model: the model for quantization
        :param weight_config (dict, optional): contains all info required by RTN. Defaults to {}.
@@ -54,6 +53,24 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
        self.absorb_to_layer = absorb_to_layer
        self._post_initialized = False

+    def _detect_absorb_to_layer(self, model, folding, example_inputs):
+        # If user not provide the layers to absorb the quantization, detect layers automatically
+        supported_layers = ["Linear"]
+        detected_absorb_layers = {}
+        # Detect the layers that can be absorbed automatically
+        if folding:
+            from neural_compressor.torch.algorithms.weight_only.utility import GraphTrace
+
+            tg = GraphTrace()
+            detected_absorb_layers, _ = tg.get_absorb_to_layer(model, example_inputs, supported_layers)
+        else:  # pragma: no cover
+            for name, module in model.named_modules():
+                if module.__class__.__name__ in supported_layers:
+                    detected_absorb_layers[name] = [name]
+        logger.info("Detected **absorb layer**: **absorbed layers**")
+        logger.info(detected_absorb_layers)
+        return detected_absorb_layers
+
    def _post_init(self):
        self.dtype = self._get_dtype()
        self.model.to(self.device)
@@ -75,6 +92,8 @@ def add_tuning_scale(self, sqrt_w_init=False):
        to the paper for more details
        :param sqrt_w_init: use sqrt weight to init."""

+        if not self.absorb_to_layer:
+            self.absorb_to_layer = self._detect_absorb_to_layer(self.model, self.folding, self.example_inputs)
        if not self._post_initialized:
            self._post_init()
        # freeze model.
@@ -104,7 +123,7 @@ def add_tuning_scale(self, sqrt_w_init=False):

            self.trained_alphas[layer_norm] = alpha
            for layer_name in self.absorb_to_layer[layer_norm]:
-                if self.weight_config.get(layer_name) is None:  # pragma: no cover
+                if not self.weight_config.get(layer_name):  # pragma: no cover
                    logger.info(f"layer {layer_name} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
@@ -117,10 +136,10 @@
                )
                set_module(self.model, layer_name, wrapper_module)

-        for n, m in self.model.named_modules():
+        for layer_name, m in self.model.named_modules():
            if isinstance(m, torch.nn.Linear) and "orig_layer" not in n:
-                if self.weight_config.get(n) is None:  # pragma: no cover
-                    logger.info(f"out of absorbed layer {n} not in weight config, skip.")
+                if not self.weight_config.get(layer_name):  # pragma: no cover
+                    logger.info(f"out of absorbed layer {layer_name} not in weight config, skip.")
                    continue
                num_bits = self.weight_config[layer_name]["bits"]
                group_size = self.weight_config[layer_name]["group_size"]
@@ -131,7 +150,7 @@ def add_tuning_scale(self, sqrt_w_init=False):
                wrapper_module = TEQLinearFakeQuant(
                    orig_layer=m, alpha=alpha, num_bits=num_bits, group_size=group_size, scheme=scheme
                )
-                set_module(self.model, n, wrapper_module)
+                set_module(self.model, layer_name, wrapper_module)
        # Attach the weight config captured at prepare stage to the model
        self.model._weight_config = self.weight_config
        self.model._trained_alphas = self.trained_alphas
@@ -190,7 +209,9 @@ def _absorb_scales(self, layer, scale, layer_name=""):
            scale = scale.view(scale.shape[0], 1)
            layer.weight *= scale

-        elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm":  ##quite tricky
+        elif (
+            layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm"
+        ):  # pragma: no cover
            layer.weight *= scale

        else:  # pragma: no cover
@@ -222,7 +243,7 @@ def _scale_layer_weight(self, layer, scale): ##input channel
    @torch.no_grad()
    def transform(self):
        """Apply alpha/scale."""
-        if not self._post_initialized:
+        if not self._post_initialized:  # pragma: no cover
            self._post_init()
        for ln_name, layer_names in self.absorb_to_layer.items():
            module = get_module(self.model, ln_name)
@@ -272,7 +293,7 @@ def save(self, save_scale_file="", save_state_dict_file=""):

class TEQuantizer(Quantizer):

-    def __init__(self, quant_config, folding, absorb_to_layer, example_inputs):
+    def __init__(self, quant_config, folding, example_inputs, absorb_to_layer=None):
        super().__init__(quant_config=quant_config)
        self.folding = folding
        self.absorb_to_layer = absorb_to_layer
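A minimal sketch of the non-folding fallback added in _detect_absorb_to_layer above (the else branch marked pragma: no cover): when folding is disabled and no absorb_to_layer mapping is supplied, every torch.nn.Linear is simply mapped to itself. The ToyBlock module and the helper name below are hypothetical and only illustrate the detection logic; the folding path instead discovers the mapping with GraphTrace.get_absorb_to_layer as shown in the diff, so the keys there are the preceding layers (e.g. norms) whose weights can absorb the trained scales.

import torch

class ToyBlock(torch.nn.Module):
    # Hypothetical two-layer module used only to illustrate the detection output.
    def __init__(self):
        super().__init__()
        self.fc_in = torch.nn.Linear(8, 16)
        self.fc_out = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.fc_out(torch.relu(self.fc_in(x)))

def detect_absorb_to_layer_no_folding(model, supported_layers=("Linear",)):
    # Mirrors the else branch above: each supported layer absorbs its own scale.
    return {
        name: [name]
        for name, module in model.named_modules()
        if module.__class__.__name__ in supported_layers
    }

print(detect_absorb_to_layer_no_folding(ToyBlock()))
# -> {'fc_in': ['fc_in'], 'fc_out': ['fc_out']}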
17 changes: 15 additions & 2 deletions test/3x/torch/algorithms/weight_only/test_teq_quantizer.py
@@ -82,8 +82,21 @@ def setUpClass(self):
        )
        self.gptj.seqlen = 512

-    def train_func(self):
-        pass
+    def test_teq_detect_absorb_layers(self):
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = copy.deepcopy(self.gptj)
+        out0 = model(test_input)
+
+        weight_config = {
+            # 'op_name': (bit, group_size, scheme)
+            "transformer.h.0.mlp.fc_in": {"bits": 8, "group_size": -1, "scheme": "sym"},
+            "transformer.h.0.mlp.fc_out": {"bits": 4, "group_size": 32, "scheme": "asym"},
+        }
+        quantizer = TEQuantizer(quant_config=weight_config, folding=True, example_inputs=example_inputs)
+        model = quantizer.quantize(copy.deepcopy(self.gptj), run_fn=train)
+        out1 = model(test_input)
+        self.assertTrue(torch.allclose(out1[0], out0[0], atol=0.03))

    def test_teq(self):
        example_inputs = torch.ones([1, 512], dtype=torch.long)