diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh
index 2937c154300..db9a7a0d43a 100644
--- a/.azure-pipelines/scripts/ut/env_setup.sh
+++ b/.azure-pipelines/scripts/ut/env_setup.sh
@@ -97,6 +97,11 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
     # so test distribute cases in the env with single fw installed
     pip install horovod
 fi
+
+if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
+    pip install git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+fi
+
 # test deps
 pip install coverage
 pip install pytest
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 65886791d48..0c21263e116 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -4615,6 +4615,9 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
             q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func)
         if "RTN" in all_algo:
             q_model._model = self.rtn_quantize(q_model._model, tune_cfg)
+        if "AUTOROUND" in all_algo:
+            q_model._model, autoround_config = self.autoround_quantize(q_model._model, tune_cfg, dataloader)
+            q_model.autoround_config = autoround_config
 
         q_model.q_config = copy.deepcopy(self.tune_cfg)
         q_model.is_quantized = True
@@ -4911,6 +4914,93 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
         )
         return model
 
+    def autoround_quantize(self, model, tune_cfg, dataloader):
+        logger.info("Quantizing with the AutoRound algorithm")
+        from .torch_utils.weight_only import autoround_quantize
+
+        # Build the per-layer weight_config, e.g.:
+        #     weight_config = {
+        #         "layer1": {  # layer name
+        #             "data_type": "int",
+        #             "bits": 4,
+        #             "group_size": 32,
+        #             "scheme": "asym",  # or "sym"
+        #         },
+        #         ...
+        #     }
+        weight_config = {}
+        for key, config in tune_cfg["op"].items():
+            if config["weight"]["dtype"] == "fp32":
+                continue
+            op_name, op_type = key
+            weight_config[op_name] = {
+                "data_type": config["weight"]["dtype"],
+                "bits": config["weight"]["bits"],
+                "group_size": config["weight"]["group_size"],
+                "scheme": config["weight"]["scheme"],
+            }
+
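+        # AutoRound knobs arrive through the user-facing `recipes` argument of
+        # PostTrainingQuantConfig, e.g. (values illustrative, mirroring the unit
+        # test added in this PR):
+        #     recipes={"autoround_args": {"n_samples": 20, "amp": False, "seqlen": 10,
+        #                                 "iters": 10, "scale_dtype": "fp32", "device": "cpu"}}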
+        autoround_args = self.recipes.get("autoround_args", {})
+        enable_full_range = autoround_args.get("enable_full_range", False)
+        bs = autoround_args.get("bs", 8)
+        amp = autoround_args.get("amp", True)
+        device = autoround_args.get("device", "cpu")
+        lr_scheduler = autoround_args.get("lr_scheduler", None)
+        dataset_name = autoround_args.get("dataset_name", "NeelNanda/pile-10k")
+        dataset_split = autoround_args.get("dataset_split", "train")
+        use_quant_input = autoround_args.get("use_quant_input", True)
+        enable_minmax_tuning = autoround_args.get("enable_minmax_tuning", True)
+        lr = autoround_args.get("lr", None)
+        minmax_lr = autoround_args.get("minmax_lr", None)
+        low_gpu_mem_usage = autoround_args.get("low_gpu_mem_usage", True)
+        iters = autoround_args.get("iters", 200)
+        seqlen = autoround_args.get("seqlen", 2048)
+        n_samples = autoround_args.get("n_samples", 512)
+        sampler = autoround_args.get("sampler", "rand")
+        seed = autoround_args.get("seed", 42)
+        n_blocks = autoround_args.get("n_blocks", 1)
+        gradient_accumulate_steps = autoround_args.get("gradient_accumulate_steps", 1)
+        not_use_best_mse = autoround_args.get("not_use_best_mse", False)
+        dynamic_max_gap = autoround_args.get("dynamic_max_gap", -1)
+        data_type = autoround_args.get("data_type", "int")  # only "int" is currently supported
+        scale_dtype = autoround_args.get("scale_dtype", "fp16")
+
+        model, autoround_config = autoround_quantize(
+            model=model,
+            tokenizer=None,
+            bits=4,
+            group_size=128,
+            scheme="asym",
+            weight_config=weight_config,
+            enable_full_range=enable_full_range,
+            bs=bs,
+            amp=amp,
+            device=device,
+            lr_scheduler=lr_scheduler,
+            dataloader=dataloader,
+            dataset_name=dataset_name,
+            dataset_split=dataset_split,
+            use_quant_input=use_quant_input,
+            enable_minmax_tuning=enable_minmax_tuning,
+            lr=lr,
+            minmax_lr=minmax_lr,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            iters=iters,
+            seqlen=seqlen,
+            n_samples=n_samples,
+            sampler=sampler,
+            seed=seed,
+            n_blocks=n_blocks,
+            gradient_accumulate_steps=gradient_accumulate_steps,
+            not_use_best_mse=not_use_best_mse,
+            dynamic_max_gap=dynamic_max_gap,
+            data_type=data_type,
+            scale_dtype=scale_dtype,
+        )
+        return model, autoround_config
+
     def _dump_model_op_stats(self, model, tune_cfg):
         """This is a function to dump quantizable ops of model to user.
diff --git a/neural_compressor/adaptor/pytorch_cpu.yaml b/neural_compressor/adaptor/pytorch_cpu.yaml
index b7f2ce79013..f815c5c7f18 100644
--- a/neural_compressor/adaptor/pytorch_cpu.yaml
+++ b/neural_compressor/adaptor/pytorch_cpu.yaml
@@ -267,7 +267,7 @@
           # group_size=-1 means per-channel, others means per-group
           'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32
           'scheme': ['sym', 'asym'], # sym, no ZP
-          'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
+          'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ', 'AUTOROUND'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
         },
         'activation': {
           'dtype': ['fp32'],
diff --git a/neural_compressor/adaptor/torch_utils/auto_round.py b/neural_compressor/adaptor/torch_utils/auto_round.py
new file mode 100644
index 00000000000..1bb8df54d06
--- /dev/null
+++ b/neural_compressor/adaptor/torch_utils/auto_round.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
+
+
+def get_dataloader(
+    tokenizer, seqlen=2048, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
+):
+    """Build a calibration dataloader for AutoRound from auto-round's registered datasets.
+
+    Unknown dataset names fall back to the default "NeelNanda/pile-10k" dataset.
+    """
+    dataloader_builder = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
+    dataloader = dataloader_builder(
+        tokenizer, seqlen=seqlen, seed=seed, bs=train_bs, split=dataset_split, dataset_name=dataset_name
+    )
+    return dataloader
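+
+
+# Minimal usage sketch (mirrors the adaptor unit test added in this PR; the tiny
+# GPT-J checkpoint is illustrative only):
+#
+#     from transformers import AutoTokenizer
+#
+#     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
+#     dataloader = get_dataloader(tokenizer, seqlen=10, seed=42, train_bs=8)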
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index 4a4fcf19d95..e5099490ab3 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -670,3 +670,121 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
         int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
         int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
     return int_weight
+
+
+def autoround_quantize(
+    model,
+    tokenizer,
+    bits: int = 4,
+    group_size: int = 128,
+    scheme: str = "asym",
+    weight_config: dict = {},
+    enable_full_range: bool = False,  # for symmetric quantization, TODO: support later
+    bs: int = 8,
+    amp: bool = True,
+    device="cuda:0",
+    lr_scheduler=None,
+    dataloader=None,  # to be supported later
+    dataset_name: str = "NeelNanda/pile-10k",
+    dataset_split: str = "train",
+    use_quant_input: bool = True,
+    enable_minmax_tuning: bool = True,
+    lr: float = None,
+    minmax_lr: float = None,
+    low_gpu_mem_usage: bool = True,
+    iters: int = 200,
+    seqlen: int = 2048,
+    n_samples: int = 512,
+    sampler: str = "rand",
+    seed: int = 42,
+    n_blocks: int = 1,
+    gradient_accumulate_steps: int = 1,
+    not_use_best_mse: bool = False,
+    dynamic_max_gap: int = -1,
+    data_type: str = "int",  # only "int" is currently supported
+    scale_dtype="fp16",
+    **kwargs,
+):
+    """Run AutoRound weight-only quantization.
+
+    Args:
+        model: The PyTorch model to be quantized.
+        tokenizer: Tokenizer for processing input data. Temporarily set as a mandatory parameter.
+        bits (int): Number of bits for quantization (default is 4).
+        group_size (int): Size of the quantization group (default is 128).
+        scheme (str): The quantization scheme to be used (default is "asym").
+        weight_config (dict): Per-layer quantization configuration (default is an empty dictionary), e.g.
+            weight_config = {
+                "layer1": {  # layer name
+                    "data_type": "int",
+                    "bits": 4,
+                    "group_size": 32,
+                    "scheme": "asym",  # or "sym"
+                },
+                ...
+            }
+        enable_full_range (bool): Whether to enable full-range quantization (default is False).
+        bs (int): Batch size for tuning (default is 8).
+        amp (bool): Whether to use automatic mixed precision (default is True).
+        device: The device to be used for tuning (default is "cuda:0").
+        lr_scheduler: The learning rate scheduler to be used.
+        dataloader: The dataloader for input data (to be supported in the future).
+        dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
+        dataset_split (str): The split of the dataset to be used (default is "train").
+        use_quant_input (bool): Whether to use the quantized output of the previous block as input (default is True).
+        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+        lr (float): The learning rate (default is None; auto-round then derives a value from iters).
+        minmax_lr (float): The learning rate for min-max tuning (default is None).
+        low_gpu_mem_usage (bool): Whether to reduce GPU memory usage (default is True).
+        iters (int): Number of tuning iterations (default is 200).
+        seqlen (int): Length of the calibration sequences (default is 2048).
+        n_samples (int): Number of calibration samples (default is 512).
+        sampler (str): The sampling method (default is "rand").
+        seed (int): The random seed (default is 42).
+        n_blocks (int): Number of blocks tuned together (default is 1).
+        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+        not_use_best_mse (bool): Whether to keep the last-iteration weights instead of the best-MSE ones (default is False).
+        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+        data_type (str): The data type to be used (default is "int"; currently the only supported value).
+        scale_dtype (str): The data type of the quantization scales (default is "fp16").
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        The quantized (QDQ) model and the resolved per-layer weight configuration.
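+
+    Example:
+        A minimal sketch, assuming the auto-round commit pinned by this PR
+        (arguments mirror the unit tests; the tiny model is illustrative only)::
+
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+
+            name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+            model = AutoModelForCausalLM.from_pretrained(name)
+            tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
+            qdq_model, resolved_config = autoround_quantize(
+                model, tokenizer, n_samples=20, device="cpu", amp=False, seqlen=10, iters=10, scale_dtype="fp32"
+            )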
+ """ + from auto_round import AutoRound # pylint: disable=E0401 + + rounder = AutoRound( + model=model, + tokenizer=tokenizer, + bits=bits, + group_size=group_size, + scheme=scheme, + weight_config=weight_config, + enable_full_range=enable_full_range, ##for symmetric, TODO support later + bs=bs, + amp=amp, + device=device, + lr_scheduler=lr_scheduler, + dataloader=dataloader, ## to support later + dataset_name=dataset_name, + dataset_split=dataset_split, + use_quant_input=use_quant_input, + enable_minmax_tuning=enable_minmax_tuning, + lr=lr, + minmax_lr=minmax_lr, + low_gpu_mem_usage=low_gpu_mem_usage, + iters=iters, + seqlen=seqlen, + n_samples=n_samples, + sampler=sampler, + seed=seed, + n_blocks=n_blocks, + gradient_accumulate_steps=gradient_accumulate_steps, + not_use_best_mse=not_use_best_mse, + dynamic_max_gap=dynamic_max_gap, + data_type=data_type, ## only support data_type + scale_dtype=scale_dtype, + **kwargs, + ) + qdq_model, weight_config = rounder.quantize() + return qdq_model, weight_config diff --git a/neural_compressor/config.py b/neural_compressor/config.py index a8042b85d18..2503a78744e 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -60,7 +60,7 @@ ), Optional("algorithm"): And( list, # TODO: allow AWQ+GPTQ algo - lambda s: all(i in ["minmax", "RTN", "AWQ", "GPTQ", "TEQ"] for i in s), + lambda s: all(i in ["minmax", "RTN", "AWQ", "GPTQ", "TEQ", "AUTOROUND"] for i in s), ), Optional("bits"): And(list, lambda s: all(0 < i <= 8 and type(i) == int for i in s)), Optional("group_size"): And(list, lambda s: all(i >= -1 and i != 0 and type(i) == int for i in s)), @@ -941,6 +941,12 @@ def teq_args(val=None): else: return {} + def autoround_args(val=None): + if val is not None: + return _check_value("autoround_args", val, dict) + else: + return {} + def fast_bias_correction(val=None): if val is not None: return _check_value("fast_bias_correction", val, bool) @@ -1025,6 +1031,7 @@ def dedicated_qdq_pair(val=None): "awq_args": awq_args, "gptq_args": gptq_args, "teq_args": teq_args, + "autoround_args": autoround_args, } self._recipes = {} for k in RECIPES.keys(): diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py index a2da94ac822..7d75a19840d 100644 --- a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -12,6 +12,13 @@ from neural_compressor.utils.load_huggingface import export_compressed_model from neural_compressor.utils.pytorch import load +try: + import auto_round + + auto_round_installed = True +except ImportError: + auto_round_installed = False + class Model(torch.nn.Module): def __init__(self): @@ -738,6 +745,62 @@ def __iter__(self): out2 = q_model.model(input) self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01)) + @unittest.skipIf(not auto_round_installed, "auto_round module is not installed") + def test_AutoRound_quant(self): + from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader + + tokenizer = transformers.AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True + ) + dataloader = get_dataloader( + tokenizer, seqlen=10, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k" + ) + fp32_model = copy.deepcopy(self.gptj) + + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "dtype": "int", + "bits": 4, + "group_size": 32, # -1 
+                        "scheme": "sym",
+                        "algorithm": "AUTOROUND",
+                    },
+                },
+            },
+            op_name_dict={
+                ".*lm_head": {  # re.match
+                    "weight": {"dtype": "fp32"},
+                },
+            },
+            recipes={
+                "autoround_args": {
+                    "n_samples": 20,
+                    "amp": False,
+                    "seqlen": 10,
+                    "iters": 10,
+                    "scale_dtype": "fp32",
+                    "device": "cpu",
+                },
+            },
+        )
+
+        input = torch.ones([1, 512], dtype=torch.long)
+        out1 = fp32_model(input)
+        q_model = quantization.fit(
+            fp32_model,
+            conf,
+            calib_dataloader=dataloader,
+        )
+        out2 = q_model.model(input)
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
+        self.assertTrue("transformer.h.0.attn.k_proj" in q_model.autoround_config.keys())
+        self.assertTrue("scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys())
+        self.assertTrue(torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py
index 5c991b74c5a..a43fea96e4b 100644
--- a/test/quantization/test_weight_only_quantization.py
+++ b/test/quantization/test_weight_only_quantization.py
@@ -1,4 +1,5 @@
 import copy
+import shutil
 import unittest
 
 import torch
@@ -6,7 +7,20 @@
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from neural_compressor.adaptor.torch_utils.smooth_quant import GraphTrace
-from neural_compressor.adaptor.torch_utils.weight_only import awq_quantize, gptq_quantize, rtn_quantize, teq_quantize
+from neural_compressor.adaptor.torch_utils.weight_only import (
+    autoround_quantize,
+    awq_quantize,
+    gptq_quantize,
+    rtn_quantize,
+    teq_quantize,
+)
+
+try:
+    import auto_round
+
+    auto_round_installed = True
+except ImportError:
+    auto_round_installed = False
 
 
 class Model(torch.nn.Module):
@@ -221,5 +235,63 @@ def test_teq(self):
         self.assertTrue(isinstance(model, torch.nn.Module))
 
 
+class LLMDataLoader:
+    def __init__(self):
+        self.batch_size = 1
+
+    def __iter__(self):
+        for i in range(2):
+            yield torch.ones([1, 10], dtype=torch.long)
+
+
+@unittest.skipIf(not auto_round_installed, "auto_round module is not installed")
+class TestAutoRoundWeightOnlyQuant(unittest.TestCase):
+    approach = "weight_only"
+
+    @classmethod
+    def setUpClass(self):
+        self.dataloader = SimpleDataLoader()
+        self.gptj = transformers.AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM",
+            torchscript=True,
+        )
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
+        )
+        self.gptj_no_jit = transformers.AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        )
+        self.llm_dataloader = LLMDataLoader()
+        self.lm_input = torch.ones([1, 10], dtype=torch.long)
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree("runs", ignore_errors=True)
+
+    def test_autoround_int_quant(self):
+        model = copy.deepcopy(self.gptj)
+        device = "cpu"
+        out1 = model(self.lm_input)
+        q_model, weight_config1 = autoround_quantize(
+            model=model,
+            tokenizer=self.tokenizer,
+            n_samples=20,
+            device=device,
+            amp=False,
+            seqlen=10,
+            iters=10,
+            scale_dtype="fp32",
+        )
+        out2 = model(self.lm_input)
+        out3 = q_model(self.lm_input)
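+        # autoround_quantize tunes `model` in place and returns the same object, so
+        # the two tuned outputs must match each other while differing from the fp32 out1.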
+        self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1)))
+        self.assertFalse(torch.all(out1[0] == out2[0]))
+        self.assertTrue(torch.all(out2[0] == out3[0]))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/requirements.txt b/test/requirements.txt
index 616dbe385dc..3b71fd8af68 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,6 +1,7 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 accelerate==0.21.0
 dynast==1.6.0rc1
+git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
 horovod
 intel-extension-for-pytorch
 intel-tensorflow>=2.12.0