Add export support for TEQ #1910

Merged · 5 commits · Jul 11, 2024
Changes from 1 commit
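This PR teaches the TEQ weight-only algorithm to export its result: instead of only quantize-dequantizing weights in place via `quant_tensor`, the new `quantize(**kwargs)` path packs each quantized layer into a `WeightOnlyLinear` module, by default in the HuggingFace optimum (GPTQ-style) layout (`use_optimum_format=True`). The sketch below shows roughly how that path might be driven end to end; the `TEQConfig`/`prepare`/`convert` entry points, their argument names, the model ID, and the calibration function are assumptions about the 3.x weight-only API and are not part of this diff.

# Hypothetical usage sketch; entry points and arguments are assumptions, not from this PR.
import torch
from transformers import AutoModelForCausalLM

# Assumed 3.x weight-only quantization API.
from neural_compressor.torch.quantization import TEQConfig, convert, prepare

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative model
example_inputs = torch.ones([1, 32], dtype=torch.long)

# 4-bit group-wise TEQ; exact parameter names are assumptions.
quant_config = TEQConfig(bits=4, group_size=128)

def calib_fn(model):
    # TEQ learns equivalent-transformation scales, so it needs forward
    # passes over some calibration data here.
    model(example_inputs)

model = prepare(model, quant_config, example_inputs=example_inputs)
calib_fn(model)
# After convert(), the quantized layers are packed WeightOnlyLinear modules
# (optimum format by default, per this PR).
model = convert(model)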
70 changes: 59 additions & 11 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -21,7 +21,7 @@
 import torch
 
 from neural_compressor.torch.algorithms.base_algorithm import Quantizer
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger
+from neural_compressor.torch.utils import get_accelerator, get_model_device, is_transformers_imported, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
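The only import change above is the new `get_model_device` helper, which the reworked `quantize()` uses to record the model's original device so each freshly packed module can be moved back after packing on the accelerator. As a rough illustration of the semantics the code relies on (not the library's actual implementation), such a helper could behave like this:

import torch

def get_model_device_sketch(model: torch.nn.Module) -> str:
    """Illustrative stand-in for get_model_device: report the device
    the model's parameters currently sit on."""
    for p in model.parameters():
        return p.device.type  # e.g. "cpu" or "cuda"
    return "cpu"  # parameter-less models default to CPU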
@@ -265,18 +265,66 @@ def transform(self):
                 set_module(self.model, n, m.orig_layer)
 
     @torch.no_grad()
-    def quantize(self):
+    def quantize(self, **kwargs):
         """quantization."""
 
-        for n, m in self.model.named_modules():
-            if self.weight_config.get(n) is None:  # pragma: no cover
-                logger.info(f"quantize layer {n} not in weight config, skip.")
+        use_optimum_format = kwargs.get("use_optimum_format", True)
+        device = get_accelerator().current_device_name()
+        model_device = get_model_device(self.model)  # return model on the same device
+        model = self.model
+        for name, m in model.named_modules():
+            if self.weight_config.get(name) is None:  # pragma: no cover
+                logger.info(f"quantize layer {name} not in weight config, skip.")
                 continue
-            num_bits = self.weight_config[n]["bits"]
-            group_size = self.weight_config[n]["group_size"]
-            scheme = self.weight_config[n]["scheme"]
+            num_bits = self.weight_config[name]["bits"]
+            group_size = self.weight_config[name]["group_size"]
+            scheme = self.weight_config[name]["scheme"]
+            group_dim = self.weight_config[name].get("group_dim", 1)
+            # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
+            if is_transformers_imported():
+                transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
+            else:
+                transpose = group_dim == 0
+            if transpose:
+                weight = m.weight.detach().T.contiguous()
+            else:
+                weight = m.weight.detach()
-            if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                quant_tensor(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)
+            int_weight, scale, zp = quant_tensor(
+                weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme
+            )
+            int_weight = int_weight.t_().contiguous() if transpose else int_weight
+            scale = scale.t_().contiguous() if transpose else scale
+            zp = zp.t_().contiguous() if transpose and zp is not None else zp
+            if isinstance(m, torch.nn.Linear):
+                in_features = m.in_features
+                out_features = m.out_features
+            elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
+                in_features = m.weight.shape[0]
+                out_features = m.weight.shape[1]
+                int_weight = int_weight.t_().contiguous()
+                scale = scale.t_().contiguous()
+                zp = zp.t_().contiguous() if zp is not None else zp
+            from .modules import WeightOnlyLinear
+
+            new_module = WeightOnlyLinear(
+                in_features,
+                out_features,
+                bits=num_bits,
+                group_size=group_size,
+                zp=zp is not None,
+                bias=m.bias is not None,
+                use_optimum_format=use_optimum_format,
+                device=device,
+            )
+            new_module.pack(int_weight, scale, zp, m.bias)
+            if name == "":
+                return new_module
+            else:
+                set_module(model, name, new_module)
+            # Move modules back to the model device layer-by-layer
+            m.to(model_device)
+            new_module.to(model_device)
+        self.model = model
 
     def save(self, save_scale_file="", save_state_dict_file=""):
         """
@@ -328,6 +376,6 @@ def convert(self, model, *args: Any, **kwargs: Any):
             setattr(self._quantizer, attr, getattr(model, self._quantizer._PREPARE_ATTRS_PREFIX + attr, None))
         self._quantizer.model = model
         self._quantizer.transform()
-        self._quantizer.quantize()
+        self._quantizer.quantize(**kwargs)
         logger.info("TEQ quantizing done.")
         return self._quantizer.model
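The subtlest part of the new `quantize()` body is the transpose rule: quantization groups run along dimension 1, so the weight is transposed when `group_dim == 0` or when the layer is a `transformers.Conv1D` (whose weight is stored as `(in_features, out_features)`, as the `in_features = m.weight.shape[0]` branch in the hunk confirms), but not when both conditions hold at once, which is what the XOR expresses. A small standalone sketch of just that rule; the helper name and framing are illustrative, the logic mirrors the diff:

# Illustrative sketch of the transpose rule in quantize().
def needs_transpose(group_dim: int, is_conv1d: bool) -> bool:
    # Transpose when exactly one of the two conditions holds (XOR):
    # - grouping is requested along dim 0 instead of the default dim 1, or
    # - the layer is transformers.Conv1D, whose weight is stored as
    #   (in_features, out_features) rather than (out_features, in_features).
    return (group_dim == 0) ^ is_conv1d

# Truth table:
#   group_dim=1, Linear  -> no transpose (already (out, in), group along dim 1)
#   group_dim=0, Linear  -> transpose
#   group_dim=1, Conv1D  -> transpose (bring the weight to (out, in) first)
#   group_dim=0, Conv1D  -> no transpose (Conv1D layout already matches)
assert needs_transpose(1, False) is False
assert needs_transpose(0, False) is True
assert needs_transpose(1, True) is True
assert needs_transpose(0, True) is False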