diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py
index 51cc819a3c6..916275e4db6 100644
--- a/neural_compressor/torch/algorithms/weight_only/gptq.py
+++ b/neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -25,14 +25,19 @@
 import torch
 import torch.nn as nn
-import transformers
 from tqdm import tqdm
 
-from neural_compressor.torch.utils import fetch_module, get_device, logger, set_module
+from neural_compressor.torch.utils import fetch_module, get_device, is_transformers_imported, logger, set_module
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .modules import WeightOnlyLinear
 
+if is_transformers_imported():
+    import transformers
+
+    SUPPORTED_LAYERS = [nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D]
+else:
+    SUPPORTED_LAYERS = [nn.Conv2d, nn.Conv1d, nn.Linear]
 DEBUG = False
 
 accelerator = auto_detect_accelerator()
@@ -131,7 +136,7 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn
     return gptq_related_blocks
 
 
-def find_layers(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""):
+def find_layers(module, layers=SUPPORTED_LAYERS, name=""):
     """Get all layers with target types."""
     if type(module) in layers:
         return {name: module}
@@ -147,7 +152,7 @@ def find_layers(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Co
     return res
 
 
-def find_layers_name(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""):
+def find_layers_name(module, layers=SUPPORTED_LAYERS, name=""):
     """Get all layers with target types."""
     if type(module) in layers:
         return [name]
@@ -157,9 +162,7 @@ def find_layers_name(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transforme
     return res
 
 
-def log_quantizable_layers_per_transformer(
-    transformer_blocks, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D]
-):
+def log_quantizable_layers_per_transformer(transformer_blocks, layers=SUPPORTED_LAYERS):
     """Print all layers which will be quantized in GPTQ algorithm."""
     logger.info("* * Layer to be quantized * *")
 
@@ -734,6 +737,8 @@ def tmp(_, inp, out):
                 Q = sub_layers[layer_name].weight.data
                 if weight_config_this_layer["act_order"]:
                     Q.copy_(Q[:, gptq_perm])
+                if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
+                    Q = Q.t_().contiguous()
                 from .utility import quant_weight_w_scale
 
                 quant_weight_w_scale(
@@ -743,15 +748,24 @@ def tmp(_, inp, out):
                     weight_config_this_layer["group_size"],
                     dtype=weight_config_this_layer["dtype"],
                 )
-                # import pdb;pdb.set_trace()
                 if weight_config_this_layer["act_order"]:
                     invperm = torch.argsort(gptq_perm)
                     Q.copy_(Q[:, invperm])
                 int_weight = Q.type(torch.int32)  # copy_ is not workable for different types.
                 # replace module
+                if isinstance(sub_layers[layer_name], torch.nn.Linear):
+                    in_features = sub_layers[layer_name].in_features
+                    out_features = sub_layers[layer_name].out_features
+                elif is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
+                    in_features = sub_layers[layer_name].weight.shape[0]
+                    out_features = sub_layers[layer_name].weight.shape[1]
+                    int_weight = sub_layers[layer_name].weight.t_().contiguous()
+                    scale = scale.t_().contiguous()
+                    zp = zp.t_().contiguous() if zp is not None else zp
+
                 new_module = WeightOnlyLinear(
-                    sub_layers[layer_name].in_features,
-                    sub_layers[layer_name].out_features,
+                    in_features,
+                    out_features,
                     dtype=weight_config_this_layer["dtype"],
                     bits=weight_config_this_layer["bits"],
                     group_size=weight_config_this_layer["group_size"],
@@ -790,7 +804,7 @@ def __init__(self, layer, W, device="cpu"):
         # W = layer.weight.data.clone()
         if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
             W = W.flatten(1)
-        if isinstance(self.layer, transformers.Conv1D):
+        if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
             W = W.t()
         self.rows = W.shape[0]  # output channels
         self.columns = W.shape[1]  # input channels
@@ -806,7 +820,9 @@ def add_batch(self, inp, out):
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
         tmp = inp.shape[0]
-        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
+        if isinstance(self.layer, nn.Linear) or (
+            is_transformers_imported() and isinstance(self.layer, transformers.Conv1D)
+        ):
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
             inp = inp.t()
@@ -833,7 +849,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
         weight_shape, weight_dtype = W.shape, W.data.dtype
         if isinstance(self.layer, nn.Conv2d):
             W = W.flatten(1)
-        if isinstance(self.layer, transformers.Conv1D):
+        if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
             W = W.t()
         W = W.float()
 
@@ -937,7 +953,7 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
 
-        if isinstance(self.layer, transformers.Conv1D):
+        if is_transformers_imported() and isinstance(self.layer, transformers.Conv1D):
             Q = Q.t()
         # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
         Q = Q.reshape(weight_shape).to(weight_dtype)
diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py
index e6ad03f1351..9092169b4b1 100644
--- a/neural_compressor/torch/algorithms/weight_only/rtn.py
+++ b/neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -24,10 +24,13 @@
 import torch
 
 from neural_compressor.torch.algorithms import Quantizer
-from neural_compressor.torch.utils import get_device, logger, set_module
+from neural_compressor.torch.utils import get_device, is_transformers_imported, logger, set_module
 
 from .utility import cast_fp8, quant_tensor, search_clip
 
+if is_transformers_imported():
+    import transformers
+
 
 class RTNQuantizer(Quantizer):
     def __init__(self, quant_config: OrderedDict = {}):
@@ -94,7 +97,10 @@ def convert(
         model.to(device)
 
         assert isinstance(model, torch.nn.Module), "only support torch module"
-        supported_layers = (torch.nn.Linear,)
+        if is_transformers_imported():
+            supported_layers = (torch.nn.Linear, transformers.Conv1D)
+        else:
+            supported_layers = (torch.nn.Linear,)
         # initialize global configuration
         double_quant_config = {
             "double_quant": kwargs.get("use_double_quant", False),
kwargs.get("use_double_quant", False), @@ -153,7 +159,12 @@ def convert( continue logger.debug(f"RTN quantized module:{name, m}") logger.debug(log_msg) - if group_dim == 0: + # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight. + if is_transformers_imported(): + transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D)) + else: + transpose = group_dim == 0 + if transpose: weight = m.weight.t_().contiguous() else: weight = m.weight @@ -171,14 +182,23 @@ def convert( full_range=use_full_range, **double_quant_config, ) - int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight - scale = scale.t_().contiguous() if group_dim == 0 else scale - zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp + int_weight = int_weight.t_().contiguous() if transpose else int_weight + scale = scale.t_().contiguous() if transpose else scale + zp = zp.t_().contiguous() if transpose and zp is not None else zp + if isinstance(m, torch.nn.Linear): + in_features = m.in_features + out_features = m.out_features + elif is_transformers_imported() and isinstance(m, transformers.Conv1D): + in_features = m.weight.shape[1] + out_features = m.weight.shape[0] + int_weight = int_weight.t_().contiguous() + scale = scale.t_().contiguous() + zp = zp.t_().contiguous() if zp is not None else zp from .modules import WeightOnlyLinear new_module = WeightOnlyLinear( - m.in_features, - m.out_features, + in_features, + out_features, dtype=dtype, bits=bits, group_size=group_size, @@ -203,8 +223,9 @@ def convert( full_range=use_full_range, **double_quant_config, ) - if group_dim == 0: - # for group_dim is 0, we need to transpose the quantized tensor and module's weight back + if transpose: + # for only group_dim is 0 or only `transformers.Conv1D`, + # we need to transpose the quantized tensor and module's weight back weight = weight.t_().contiguous() m.weight.t_().contiguous() m.weight.data.copy_(weight) diff --git a/neural_compressor/torch/algorithms/weight_only/teq.py b/neural_compressor/torch/algorithms/weight_only/teq.py index 90671c14f12..827203367ae 100644 --- a/neural_compressor/torch/algorithms/weight_only/teq.py +++ b/neural_compressor/torch/algorithms/weight_only/teq.py @@ -20,14 +20,16 @@ from typing import Any import torch -import transformers from neural_compressor.torch.algorithms.base_algorithm import Quantizer -from neural_compressor.torch.utils import get_device, logger +from neural_compressor.torch.utils import get_device, is_transformers_imported, logger from .modules import MulLinear, TEQLinearFakeQuant from .utility import get_module, quant_tensor, set_module +if is_transformers_imported(): + import transformers + __all__ = ["TrainableEquivalentTransformation", "TEQuantizer"] diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index b909cc0fe66..dc850b04d2c 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -43,7 +43,7 @@ STATIC_QUANT, TEQ, ) -from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, logger +from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger from neural_compressor.torch.utils.constants import ( PRIORITY_AUTOROUND, PRIORITY_AWQ, @@ -66,6 +66,12 @@ FRAMEWORK_NAME = "torch" +if is_transformers_imported(): + import transformers + + WOQ_WHITE_LIST = (torch.nn.Linear, transformers.Conv1D) +else: + WOQ_WHITE_LIST = (torch.nn.Linear,) 
 
 
 class OperatorConfig(NamedTuple):
@@ -193,10 +199,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         logger.debug(f"Get model info: {filter_result}")
@@ -339,10 +344,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         logger.debug(f"Get model info: {filter_result}")
@@ -472,10 +476,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         logger.debug(f"Get model info: {filter_result}")
@@ -597,10 +600,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         logger.debug(f"Get model info: {filter_result}")
@@ -743,10 +745,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         logger.debug(f"Get model info: {filter_result}")
@@ -1071,10 +1072,9 @@ def __init__(
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
-        white_list = (torch.nn.Linear,)
         filter_result = []
         for op_name, module in model.named_modules():
-            if isinstance(module, white_list):
+            if isinstance(module, WOQ_WHITE_LIST):
                 pair = (op_name, type(module).__name__)
                 filter_result.append(pair)
         return filter_result
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index b787febf40d..090897d5e5e 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -74,6 +74,13 @@ def is_ipex_imported() -> bool:
     return False
 
 
+def is_transformers_imported() -> bool:
+    for name, _ in sys.modules.items():
+        if name == "transformers":
+            return True
+    return False
+
+
 def get_device(device_name="auto"):
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index d8c0027c844..6ff08696a01 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 import transformers
+from packaging.version import Version
 
 from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
 from neural_compressor.torch.quantization import (
@@ -18,6 +19,10 @@
 try:
     import auto_round
 
+    AUTO_ROUND_VERSION_0_11 = Version("0.11")
+
+    auto_round_version = auto_round.__version__.split("+")[0]
+    auto_round_version = Version(auto_round_version)
     auto_round_installed = True
 except ImportError:
     auto_round_installed = False
@@ -146,3 +151,37 @@ def test_save_and_load(self):
         loaded_model = load("saved_results")
         loaded_out = loaded_model(self.inp)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
+
+    @pytest.mark.skipif(auto_round_version <= AUTO_ROUND_VERSION_0_11, reason="Requires auto_round>=0.11")
+    def test_conv1d(self):
+        input = torch.randn(1, 32)
+        from transformers import GPT2Model, GPT2Tokenizer
+
+        tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
+        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors="pt")
+        out1 = model(**encoded_input)[0]
+        run_fn = get_autoround_default_run_fn
+        run_args = (
+            tokenizer,
+            "NeelNanda/pile-10k",
+            20,
+            10,
+        )
+        weight_config = {
+            "*": {
+                "data_type": "int",
+                "bits": 4,
+                "group_size": 32,
+                "sym": False,
+            }
+        }
+        quantizer = AutoRoundQuantizer(quant_config=weight_config)
+
+        # quantizer execute
+        model = quantizer.prepare(model=model)
+        run_fn(model, *run_args)
+        q_model = quantizer.convert(model)
+        out2 = q_model(**encoded_input)[0]
+        assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py
index 3aa63635eaa..9a38234b349 100644
--- a/test/3x/torch/quantization/weight_only/test_gptq.py
+++ b/test/3x/torch/quantization/weight_only/test_gptq.py
@@ -192,6 +192,27 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_
         except:
             assert torch.allclose(atol_false, atol_true, atol=0.008), "atol is very close, double checked the logic."
 
+    def test_conv1d(self):
+        from transformers import GPT2Model, GPT2Tokenizer
+
+        tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
+        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors="pt")
+
+        def run_fn_conv1d(model):
+            with pytest.raises(ValueError):
+                for i in range(2):
+                    model(**encoded_input)
+
+        quant_config = get_default_gptq_config()
+        out1 = model(**encoded_input)[0]
+        model = prepare(model, quant_config)
+        run_fn_conv1d(model)
+        q_model = convert(model)
+        out2 = q_model(**encoded_input)[0]
+        assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
+
     def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_gptq_config()
diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py
index 80bd512463c..ae92e4a935c 100644
--- a/test/3x/torch/quantization/weight_only/test_rtn.py
+++ b/test/3x/torch/quantization/weight_only/test_rtn.py
@@ -16,6 +16,20 @@
 )
 
 
+class ModelConv1d(torch.nn.Module):
+    def __init__(self):
+        super(ModelConv1d, self).__init__()
+        self.fc1 = transformers.Conv1D(50, 32)
+        self.fc2 = torch.nn.Linear(50, 32)
+        self.fc3 = torch.nn.Linear(32, 5)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.fc2(out)
+        out = self.fc3(out)
+        return out
+
+
 class TestRTNQuant:
     def setup_class(self):
         self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
@@ -223,6 +237,39 @@ def test_rtn_with_quantize_API(self):
             output_1.eq(output_2)
         ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."
 
+
+    # TODO: (4, True, 32, 0), group_dim=0, format not supported
+    @pytest.mark.parametrize(
+        "bits, use_sym, group_size, group_dim",
+        [
+            (8, True, 128, 1),
+            (4, True, 128, 1),
+            (4, False, 32, 1),
+            (4, False, -1, 1),
+            (2, True, 8, 1),
+        ],
+    )
+    def test_conv1d(self, bits, use_sym, group_size, group_dim):
+        model = ModelConv1d()
+        input = torch.randn(1, 32)
+        quant_config = RTNConfig(
+            bits=bits,
+            use_sym=use_sym,
+            group_size=group_size,
+            group_dim=group_dim,
+        )
+        out1 = model(input)
+        model = prepare(model, quant_config)
+        model = convert(model)
+        out2 = model(input)
+        # assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
+        assert (out2 != out1).all(), "WOQ output should be different from the raw output"
+        if (bits, use_sym, group_size, group_dim) == (8, True, 128, 1):
+            assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
+        if (bits, use_sym, group_size, group_dim) in [(4, True, 128, 0), (4, True, 32, 1)]:
+            assert torch.allclose(out2, out1, atol=0.1), "Accuracy gap atol > 0.1 is unexpected."
+        if (bits, use_sym, group_size, group_dim) in [(4, False, 32, 0), (4, False, -1, 1), (2, True, 8, 1)]:
+            assert torch.allclose(out2, out1, atol=0.5), "Accuracy gap atol > 0.5 is unexpected."
+
     def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_rtn_config()
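
Reviewer note (not part of the patch): every transpose in this diff stems from the fact that `transformers.Conv1D` stores its weight as (in_features, out_features), the opposite of `torch.nn.Linear`'s (out_features, in_features). Below is a minimal standalone sketch of that layout difference and of the XOR transpose rule introduced in rtn.py; it assumes `torch` and `transformers` are installed, and the file name and `needs_transpose` helper are illustrative only, not code from this patch.

# conv1d_layout_sketch.py -- illustrative sketch, not part of the patch.
import torch
import transformers

# torch.nn.Linear keeps weight as (out_features, in_features);
# transformers.Conv1D (GPT-2 style) keeps it as (in_features, out_features).
linear = torch.nn.Linear(32, 50)
conv1d = transformers.Conv1D(50, 32)  # Conv1D(nf=out_features, nx=in_features)
assert tuple(linear.weight.shape) == (50, 32)
assert tuple(conv1d.weight.shape) == (32, 50)


def needs_transpose(module, group_dim):
    # Mirrors the XOR rule in rtn.py: transpose the weight before group-wise
    # quantization when exactly one of the two conditions holds.
    return (group_dim == 0) ^ isinstance(module, transformers.Conv1D)


print(needs_transpose(linear, 1))  # False: weight already (out, in), grouping along dim 1
print(needs_transpose(linear, 0))  # True:  grouping along dim 0 requires the transpose
print(needs_transpose(conv1d, 1))  # True:  Conv1D weight is (in, out), so transpose first
print(needs_transpose(conv1d, 0))  # False: the two conditions cancel out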