Support transformers.Conv1D WOQ quantization (#1796)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored May 17, 2024
1 parent 2e1cdc5 commit b6237cf
Showing 8 changed files with 192 additions and 39 deletions.
44 changes: 30 additions & 14 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -25,14 +25,19 @@

import torch
import torch.nn as nn
import transformers
from tqdm import tqdm

from neural_compressor.torch.utils import fetch_module, get_device, logger, set_module
from neural_compressor.torch.utils import fetch_module, get_device, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .modules import WeightOnlyLinear

if is_transformers_imported():
import transformers

SUPPORTED_LAYERS = [nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D]
else:
SUPPORTED_LAYERS = [nn.Conv2d, nn.Conv1d, nn.Linear]
DEBUG = False
accelerator = auto_detect_accelerator()

@@ -131,7 +136,7 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn
return gptq_related_blocks


def find_layers(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""):
def find_layers(module, layers=SUPPORTED_LAYERS, name=""):
"""Get all layers with target types."""
if type(module) in layers:
return {name: module}
@@ -147,7 +152,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Co
return res


def find_layers_name(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""):
def find_layers_name(module, layers=SUPPORTED_LAYERS, name=""):
"""Get all layers with target types."""
if type(module) in layers:
return [name]
@@ -157,9 +162,7 @@ def find_layers_name(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transforme
return res


def log_quantizable_layers_per_transformer(
transformer_blocks, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D]
):
def log_quantizable_layers_per_transformer(transformer_blocks, layers=SUPPORTED_LAYERS):
"""Print all layers which will be quantized in GPTQ algorithm."""
logger.info("* * Layer to be quantized * *")

@@ -734,6 +737,8 @@ def tmp(_, inp, out):
Q = sub_layers[layer_name].weight.data
if weight_config_this_layer["act_order"]:
Q.copy_(Q[:, gptq_perm])
if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
Q = Q.t_().contiguous()
from .utility import quant_weight_w_scale

quant_weight_w_scale(
@@ -743,15 +748,24 @@
weight_config_this_layer["group_size"],
dtype=weight_config_this_layer["dtype"],
)
# import pdb;pdb.set_trace()
if weight_config_this_layer["act_order"]:
invperm = torch.argsort(gptq_perm)
Q.copy_(Q[:, invperm])
int_weight = Q.type(torch.int32) # copy_ is not workable for different types.
# replace module
if isinstance(sub_layers[layer_name], torch.nn.Linear):
in_features = sub_layers[layer_name].in_features
out_features = sub_layers[layer_name].out_features
elif is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
in_features = sub_layers[layer_name].weight.shape[0]
out_features = sub_layers[layer_name].weight.shape[1]
int_weight = sub_layers[layer_name].weight.t_().contiguous()
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp

new_module = WeightOnlyLinear(
sub_layers[layer_name].in_features,
sub_layers[layer_name].out_features,
in_features,
out_features,
dtype=weight_config_this_layer["dtype"],
bits=weight_config_this_layer["bits"],
group_size=weight_config_this_layer["group_size"],
@@ -790,7 +804,7 @@ def __init__(self, layer, W, device="cpu"):
# W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
W = W.t()
self.rows = W.shape[0] # output channels
self.columns = W.shape[1] # input channels
@@ -806,7 +820,9 @@ def add_batch(self, inp, out):
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if isinstance(self.layer, nn.Linear) or (
is_transformers_imported() and isinstance(self.layer, transformers.Conv1D)
):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
@@ -833,7 +849,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
weight_shape, weight_dtype = W.shape, W.data.dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()

@@ -937,7 +953,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
invperm = torch.argsort(perm)
Q = Q[:, invperm]

if isinstance(self.layer, transformers.Conv1D):
if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
# self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
Q = Q.reshape(weight_shape).to(weight_dtype)
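Note on the repeated transposes above: `torch.nn.Linear` stores its weight as `(out_features, in_features)`, while `transformers.Conv1D` (the GPT-2-style projection layer) stores it as `(in_features, out_features)`. GPTQ packs weights in the Linear orientation, which is why Conv1D weights, scales, and zero points are transposed before being handed to `WeightOnlyLinear`. A minimal sketch of the layout difference (the sizes are arbitrary):

import torch
import transformers

linear = torch.nn.Linear(in_features=8, out_features=4)
conv1d = transformers.Conv1D(nf=4, nx=8)  # nf = output width, nx = input width

print(linear.weight.shape)  # torch.Size([4, 8]) -> (out_features, in_features)
print(conv1d.weight.shape)  # torch.Size([8, 4]) -> (in_features, out_features)

# Hence in_features/out_features for a Conv1D module are read from
# weight.shape[0] / weight.shape[1], and the weight is transposed before packing.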
41 changes: 31 additions & 10 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -24,10 +24,13 @@
import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_device, logger, set_module
from neural_compressor.torch.utils import get_device, is_transformers_imported, logger, set_module

from .utility import cast_fp8, quant_tensor, search_clip

if is_transformers_imported():
import transformers


class RTNQuantizer(Quantizer):
def __init__(self, quant_config: OrderedDict = {}):
@@ -94,7 +97,10 @@ def convert(
model.to(device)

assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = (torch.nn.Linear,)
if is_transformers_imported():
supported_layers = (torch.nn.Linear, transformers.Conv1D)
else:
supported_layers = (torch.nn.Linear,)
# initialize global configuration
double_quant_config = {
"double_quant": kwargs.get("use_double_quant", False),
@@ -153,7 +159,12 @@ def convert(
continue
logger.debug(f"RTN quantized module:{name, m}")
logger.debug(log_msg)
if group_dim == 0:
# transpose the weight when only group_dim is 0 or only the module is `transformers.Conv1D` (not both)
if is_transformers_imported():
transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
else:
transpose = group_dim == 0
if transpose:
weight = m.weight.t_().contiguous()
else:
weight = m.weight
@@ -171,14 +182,23 @@ def convert(
full_range=use_full_range,
**double_quant_config,
)
int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight
scale = scale.t_().contiguous() if group_dim == 0 else scale
zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp
int_weight = int_weight.t_().contiguous() if transpose else int_weight
scale = scale.t_().contiguous() if transpose else scale
zp = zp.t_().contiguous() if transpose and zp is not None else zp
if isinstance(m, torch.nn.Linear):
in_features = m.in_features
out_features = m.out_features
elif is_transformers_imported() and isinstance(m, transformers.Conv1D):
in_features = m.weight.shape[1]
out_features = m.weight.shape[0]
int_weight = int_weight.t_().contiguous()
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp
from .modules import WeightOnlyLinear

new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
in_features,
out_features,
dtype=dtype,
bits=bits,
group_size=group_size,
@@ -203,8 +223,9 @@ def convert(
full_range=use_full_range,
**double_quant_config,
)
if group_dim == 0:
# for group_dim is 0, we need to transpose the quantized tensor and module's weight back
if transpose:
# when only group_dim is 0 or only the module is `transformers.Conv1D` (not both),
# transpose the quantized tensor and the module's weight back
weight = weight.t_().contiguous()
m.weight.t_().contiguous()
m.weight.data.copy_(weight)
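The `transpose` flag in rtn.py combines two independent reasons to transpose the weight: quantization grouped along dim 0 (`group_dim == 0`) and the Conv1D layout. When both apply they cancel out, which is what the XOR expresses. A small sketch of the decision, reusing the names from the diff:

# Sketch of the transpose decision in rtn.py (illustrative only).
for group_dim in (0, 1):
    for is_conv1d in (False, True):
        transpose = (group_dim == 0) ^ is_conv1d
        print(f"group_dim={group_dim}, Conv1D={is_conv1d} -> transpose={transpose}")

# group_dim=0, Conv1D=False -> transpose=True   (quantize along dim 0)
# group_dim=0, Conv1D=True  -> transpose=False  (the two transposes cancel out)
# group_dim=1, Conv1D=False -> transpose=False
# group_dim=1, Conv1D=True  -> transpose=True   (bring the weight into Linear layout)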
6 changes: 4 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/teq.py
@@ -20,14 +20,16 @@
from typing import Any

import torch
import transformers

from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import get_device, logger
from neural_compressor.torch.utils import get_device, is_transformers_imported, logger

from .modules import MulLinear, TEQLinearFakeQuant
from .utility import get_module, quant_tensor, set_module

if is_transformers_imported():
import transformers

__all__ = ["TrainableEquivalentTransformation", "TEQuantizer"]


26 changes: 13 additions & 13 deletions neural_compressor/torch/quantization/config.py
@@ -43,7 +43,7 @@
STATIC_QUANT,
TEQ,
)
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger
from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger
from neural_compressor.torch.utils.constants import (
PRIORITY_AUTOROUND,
PRIORITY_AWQ,
@@ -66,6 +66,12 @@


FRAMEWORK_NAME = "torch"
if is_transformers_imported():
import transformers

WOQ_WHITE_LIST = (torch.nn.Linear, transformers.Conv1D)
else:
WOQ_WHITE_LIST = (torch.nn.Linear,)


class OperatorConfig(NamedTuple):
@@ -193,10 +199,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
@@ -339,10 +344,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
@@ -472,10 +476,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
@@ -597,10 +600,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
@@ -743,10 +745,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
logger.debug(f"Get model info: {filter_result}")
@@ -1071,10 +1072,9 @@ def __init__(

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
white_list = (torch.nn.Linear,)
filter_result = []
for op_name, module in model.named_modules():
if isinstance(module, white_list):
if isinstance(module, WOQ_WHITE_LIST):
pair = (op_name, type(module).__name__)
filter_result.append(pair)
return filter_result
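With the shared `WOQ_WHITE_LIST`, every `get_model_info` above now reports `transformers.Conv1D` modules as quantizable ops, not only `torch.nn.Linear`. A rough illustration of that filtering loop, assuming the same tiny GPT-2 checkpoint used in the tests:

import torch
import transformers

WOQ_WHITE_LIST = (torch.nn.Linear, transformers.Conv1D)

model = transformers.GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
filter_result = [
    (op_name, type(module).__name__)
    for op_name, module in model.named_modules()
    if isinstance(module, WOQ_WHITE_LIST)
]
print(filter_result[:3])
# e.g. [('h.0.attn.c_attn', 'Conv1D'), ('h.0.attn.c_proj', 'Conv1D'), ('h.0.mlp.c_fc', 'Conv1D')]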
7 changes: 7 additions & 0 deletions neural_compressor/torch/utils/environ.py
@@ -74,6 +74,13 @@ def is_ipex_imported() -> bool:
return False


def is_transformers_imported() -> bool:
for name, _ in sys.modules.items():
if name == "transformers":
return True
return False


def get_device(device_name="auto"):
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

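`is_transformers_imported` only inspects `sys.modules`; it never imports the package itself, so transformers must already have been imported by the caller (or the user's script) for Conv1D support to activate. A rough functional equivalent:

import sys

def is_transformers_imported() -> bool:
    # True only if some earlier code has already executed `import transformers`.
    return "transformers" in sys.modules

print(is_transformers_imported())  # False in a fresh interpreter
import transformers  # noqa: F401
print(is_transformers_imported())  # True afterwards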
39 changes: 39 additions & 0 deletions test/3x/torch/quantization/weight_only/test_autoround.py
@@ -4,6 +4,7 @@
import pytest
import torch
import transformers
from packaging.version import Version

from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
from neural_compressor.torch.quantization import (
@@ -18,6 +19,10 @@
try:
import auto_round

AUTO_ROUND_VERSION_0_11 = Version("0.11")

auto_round_version = auto_round.__version__.split("+")[0]
auto_round_version = Version(auto_round_version)
auto_round_installed = True
except ImportError:
auto_round_installed = False
@@ -146,3 +151,37 @@ def test_save_and_load(self):
loaded_model = load("saved_results")
loaded_out = loaded_model(self.inp)[0]
assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."

@pytest.mark.skipif(auto_round_version <= AUTO_ROUND_VERSION_0_11, reason="Requires auto_round>=0.11")
def test_conv1d(self):
input = torch.randn(1, 32)
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
out1 = model(**encoded_input)[0]
run_fn = get_autoround_default_run_fn
run_args = (
tokenizer,
"NeelNanda/pile-10k",
20,
10,
)
weight_config = {
"*": {
"data_type": "int",
"bits": 4,
"group_size": 32,
"sym": False,
}
}
quantizer = AutoRoundQuantizer(quant_config=weight_config)

# quantizer execute
model = quantizer.prepare(model=model)
run_fn(model, *run_args)
q_model = quantizer.convert(model)
out2 = q_model(**encoded_input)[0]
assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
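For context, a hypothetical end-to-end usage sketch of the new Conv1D support via RTN; it assumes the 3.x-style `RTNConfig`/`prepare`/`convert` entry points and the same tiny GPT-2 checkpoint, and is not taken from this commit:

import torch
from transformers import GPT2Model

from neural_compressor.torch.quantization import RTNConfig, convert, prepare

# GPT-2 uses transformers.Conv1D for its attention and MLP projections.
model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
quant_config = RTNConfig(bits=4, group_size=32)  # hypothetical settings

model = prepare(model, quant_config)
model = convert(model)  # Conv1D layers are repacked into WeightOnlyLinear

with torch.no_grad():
    out = model(input_ids=torch.tensor([[1, 2, 3, 4]]))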