From e3c736fd910690faf08bf4609cc3b65529d79252 Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Mon, 29 Apr 2024 14:03:55 +0800
Subject: [PATCH] Migrate AutoRound to Torch new 3x API (#1763)

Signed-off-by: Kaihui-intel
---
 .../torch/algorithms/weight_only/autoround.py | 355 +++++++++++-------
 .../torch/quantization/algorithm_entry.py     |  17 +-
 .../weight_only/test_autoround.py             |  83 ++--
 3 files changed, 288 insertions(+), 167 deletions(-)

diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index 00d89426811..8ba9d5fb637 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -12,13 +12,151 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+
 import torch
 from auto_round import AutoRound  # pylint: disable=E0401
 from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
+from auto_round.utils import get_block_names  # pylint: disable=E0401
 
+from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import logger
 
 
+class AutoRoundQuantizer(Quantizer):
+    def __init__(
+        self,
+        weight_config: dict = {},
+        enable_full_range: bool = False,
+        batch_size: int = 8,
+        amp: bool = True,
+        device=None,
+        lr_scheduler=None,
+        use_quant_input: bool = True,
+        enable_minmax_tuning: bool = True,
+        lr: float = None,
+        minmax_lr: float = None,
+        low_gpu_mem_usage: bool = True,
+        iters: int = 200,
+        seqlen: int = 2048,
+        n_samples: int = 512,
+        sampler: str = "rand",
+        seed: int = 42,
+        n_blocks: int = 1,
+        gradient_accumulate_steps: int = 1,
+        not_use_best_mse: bool = False,
+        dynamic_max_gap: int = -1,
+        scale_dtype="fp32",
+    ):
+        """Initialize an AutoRoundQuantizer object.
+
+        Args:
+            weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
+                weight_config={
+                    'layer1':  ## layer_name
+                    {
+                        'data_type': 'int',
+                        'bits': 4,
+                        'group_size': 32,
+                        'sym': False,
+                    }
+                    ...
+                }
+                keys:
+                    data_type (str): The data type to be used (default is "int").
+                    bits (int): Number of bits for quantization (default is 4).
+                    group_size (int): Size of the quantization group (default is 128).
+                    sym (bool): Whether to use symmetric quantization (default is None).
+            enable_full_range (bool): Whether to enable full-range quantization (default is False).
+            batch_size (int): Batch size for training (default is 8).
+            amp (bool): Whether to use automatic mixed precision (default is True). Automatically detected and set.
+            device: The device to be used for tuning (default is None). Automatically detected and set.
+            lr_scheduler: The learning rate scheduler to be used.
+            use_quant_input (bool): Whether to use quantized input data (default is True).
+            enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+            lr (float): The learning rate (default is 0.005).
+            minmax_lr (float): The learning rate for min-max tuning (default is None).
+            low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+            iters (int): Number of iterations (default is 200).
+            seqlen (int): Length of the calibration sequence (default is 2048).
+            n_samples (int): Number of samples (default is 512).
+            sampler (str): The sampling method (default is "rand").
+            seed (int): The random seed (default is 42).
+            n_blocks (int): Number of blocks (default is 1).
+            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+            not_use_best_mse (bool): Whether to skip the iteration with the best mean squared error
+                and use the last iteration instead (default is False).
+            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+            scale_dtype (str): The data type of the quantization scale to be used (default is "fp32");
+                different kernels support different choices.
+        """
+
+        self.tokenizer = None
+        self.weight_config = weight_config
+        self.enable_full_range = enable_full_range
+        self.batch_size = batch_size
+        self.amp = amp
+        self.device = device
+        self.lr_scheduler = lr_scheduler
+        self.use_quant_input = use_quant_input
+        self.enable_minmax_tuning = enable_minmax_tuning
+        self.lr = lr
+        self.minmax_lr = minmax_lr
+        self.low_gpu_mem_usage = low_gpu_mem_usage
+        self.iters = iters
+        self.seqlen = seqlen
+        self.n_samples = n_samples
+        self.sampler = sampler
+        self.seed = seed
+        self.n_blocks = n_blocks
+        self.gradient_accumulate_steps = gradient_accumulate_steps
+        self.not_use_best_mse = not_use_best_mse
+        self.dynamic_max_gap = dynamic_max_gap
+        self.data_type = "int"
+        self.scale_dtype = scale_dtype
+
+    def prepare(self, model: torch.nn.Module, *args, **kwargs):
+        """Prepares a given model for quantization.
+
+        Args:
+            model (torch.nn.Module): The model to be prepared.
+
+        Returns:
+            A prepared model.
+        """
+        self.rounder = AutoRoundProcessor(
+            model=model,
+            tokenizer=None,
+            weight_config=self.weight_config,
+            enable_full_range=self.enable_full_range,
+            batch_size=self.batch_size,
+            amp=self.amp,
+            device=self.device,
+            lr_scheduler=self.lr_scheduler,
+            use_quant_input=self.use_quant_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            n_samples=self.n_samples,
+            sampler=self.sampler,
+            seed=self.seed,
+            n_blocks=self.n_blocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            data_type=self.data_type,
+            scale_dtype=self.scale_dtype,
+        )
+        self.rounder.prepare()
+        return model
+
+    def convert(self, model: torch.nn.Module, *args, **kwargs):
+        """Converts the prepared model to a quantized model and attaches the tuned config."""
+        model, weight_config = self.rounder.convert()
+        model.autoround_config = weight_config
+        return model
+
+
 @torch.no_grad()
 def get_autoround_default_run_fn(
     model,
@@ -94,140 +232,95 @@ def get_autoround_default_run_fn(
     )
 
 
-class InputCaptureModule(torch.nn.Module):
+class AutoRoundProcessor(AutoRound):
 
-    def __init__(self) -> None:
-        super().__init__()
-        self.data_pairs = []
-        self.device = "cpu"
+    def prepare(self):
+        """Prepares a given model for quantization."""
+        # logger.info("cache block input")
+        self.start_time = time.time()
+        self.block_names = get_block_names(self.model)
+        if len(self.block_names) == 0:
+            logger.warning("could not find blocks, exit with original model")
+            return
+        if self.amp:
+            self.model = self.model.to(self.amp_dtype)
+        if not self.low_gpu_mem_usage:
+            self.model = self.model.to(self.device)
+        # inputs = self.cache_block_input(block_names[0], self.n_samples)
 
-    def forward(self, *args, **kwargs):
-        if kwargs and len(args) == 0:
-            # Handle cases where input data is a dict
-            self.data_pairs.append(kwargs)
-        elif args and len(args) == 1:
-            # Handle cases where input data is a Tensor
-            self.data_pairs.append(args[0])
-        else:
-            logger.error("Handle cases where input data is neither a Tensor nor a dict")
+        # cache block input
+        self.inputs = {}
+        self.tmp_block_name = self.block_names[0]
+        self._replace_forward()
 
+    def convert(self):
+        """Converts a prepared model to a quantized model."""
+ self._recover_forward() + inputs = self.inputs[self.tmp_block_name] + del self.tmp_block_name -def recover_dataloader_from_calib_fn(run_fn, run_args): - input_capture_model = InputCaptureModule() - input_capture_model.eval() - run_fn(input_capture_model, *run_args) - dataloader = torch.utils.data.DataLoader(input_capture_model.data_pairs) - return dataloader + del self.inputs + if "input_ids" in inputs.keys(): + dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type)) + total_samples = inputs["input_ids"].shape[dim] + self.n_samples = total_samples + if total_samples < self.train_bs: + self.train_bs = total_samples + logger.warning(f"force the train batch size to {total_samples} ") + self.model = self.model.to("cpu") + torch.cuda.empty_cache() + self.qdq_weight_round( + self.model, + inputs, + self.block_names, + n_blocks=self.n_blocks, + device=self.device, + ) + for n, m in self.model.named_modules(): + if n in self.weight_config.keys(): + if hasattr(m, "scale"): + self.weight_config[n]["scale"] = m.scale + self.weight_config[n]["zp"] = m.zp + if self.group_size <= 0: + self.weight_config[n]["g_idx"] = torch.tensor( + [0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu" + ) + else: + self.weight_config[n]["g_idx"] = torch.tensor( + [i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu" + ) + delattr(m, "scale") + delattr(m, "zp") + else: + self.weight_config[n]["data_type"] = "float" + if self.amp_dtype == torch.bfloat16: + self.weight_config[n]["data_type"] = "bfloat" + self.weight_config[n]["bits"] = 16 + self.weight_config[n]["group_size"] = None + self.weight_config[n]["sym"] = None + end_time = time.time() + cost_time = end_time - self.start_time + logger.info(f"quantization tuning time {cost_time}") + ## dump a summary + quantized_layers = [] + unquantized_layers = [] + for n, m in self.model.named_modules(): + if isinstance(m, tuple(self.supported_types)): + if self.weight_config[n]["bits"] == 16: + unquantized_layers.append(n) + else: + quantized_layers.append(n) + summary_info = ( + f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model" + ) + if len(unquantized_layers) > 0: + summary_info += f", {unquantized_layers} have not been quantized" -def autoround_quantize( - model, - weight_config: dict = {}, - enable_full_range: bool = False, ##for symmetric, TODO support later - batch_size: int = 8, - amp: bool = True, - device=None, - lr_scheduler=None, - use_quant_input: bool = True, - enable_minmax_tuning: bool = True, - lr: float = None, - minmax_lr: float = None, - low_gpu_mem_usage: bool = True, - iters: int = 200, - seqlen: int = 2048, - n_samples: int = 512, - sampler: str = "rand", - seed: int = 42, - n_blocks: int = 1, - gradient_accumulate_steps: int = 1, - not_use_best_mse: bool = False, - dynamic_max_gap: int = -1, - scale_dtype="fp16", - run_fn=None, - run_args=None, -): - """The entry point of the autoround weight-only quantization. - Args: - model: The PyTorch model to be quantized. - weight_config (dict): Configuration for weight quantization (default is an empty dictionary). - weight_config={ - 'layer1':##layer_name - { - 'data_type': 'int', - 'bits': 4, - 'group_size': 32, - 'sym': False, - } - ... - } - keys: - data_type (str): The data type to be used (default is "int"). - bits (int): Number of bits for quantization (default is 4). - group_size (int): Size of the quantization group (default is 128). 
- sym (bool): Whether to use symmetric quantization. (default is None). - enable_full_range (bool): Whether to enable full range quantization (default is False). - batch_size (int): Batch size for training (default is 8). - amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set. - device: The device to be used for tuning (default is None). Automatically detect and set. - lr_scheduler: The learning rate scheduler to be used. - use_quant_input (bool): Whether to use quantized input data (default is True). - enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). - lr (float): The learning rate (default is 0.005). - minmax_lr (float): The learning rate for min-max tuning (default is None). - low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True). - iters (int): Number of iterations (default is 200). - seqlen (int): Length of the sequence. - n_samples (int): Number of samples (default is 512). - sampler (str): The sampling method (default is "rand"). - seed (int): The random seed (default is 42). - n_blocks (int): Number of blocks (default is 1). - gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1). - not_use_best_mse (bool): Whether to use mean squared error (default is False). - dynamic_max_gap (int): The dynamic maximum gap (default is -1). - scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels - have different choices. - run_fn: a calibration function for calibrating the model. Defaults to None. - run_args: positional arguments for `run_fn`. Defaults to None. - - Returns: - The quantized model. - """ - if run_fn is None or run_fn == get_autoround_default_run_fn: - assert run_args is not None, "Please provide tokenizer for AutoRound default calibration." 
-        run_fn = get_autoround_default_run_fn
-    dataloader = recover_dataloader_from_calib_fn(run_fn, run_args)
-
-    rounder = AutoRound(
-        model=model,
-        tokenizer=None,
-        bits=4,
-        group_size=128,
-        sym=False,
-        weight_config=weight_config,
-        enable_full_range=enable_full_range,  ##for symmetric, TODO support later
-        batch_size=batch_size,
-        amp=amp,
-        device=device,
-        lr_scheduler=lr_scheduler,
-        dataloader=dataloader,
-        use_quant_input=use_quant_input,
-        enable_minmax_tuning=enable_minmax_tuning,
-        lr=lr,
-        minmax_lr=minmax_lr,
-        low_gpu_mem_usage=low_gpu_mem_usage,
-        iters=iters,
-        seqlen=seqlen,
-        n_samples=n_samples,
-        sampler=sampler,
-        seed=seed,
-        n_blocks=n_blocks,
-        gradient_accumulate_steps=gradient_accumulate_steps,
-        not_use_best_mse=not_use_best_mse,
-        dynamic_max_gap=dynamic_max_gap,
-        data_type="int",
-        scale_dtype=scale_dtype,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    qdq_model, weight_config = rounder.quantize()
-    return qdq_model, weight_config
+    logger.info(summary_info)
+    if len(unquantized_layers) > 0:
+        logger.info(f"Summary: {unquantized_layers} have not been quantized")
+
+    self.quantized = True
+    self.model = self.model.to(self.model_orig_dtype)
+    return self.model, self.weight_config
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 2834b42949a..243a1bbfb99 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -334,12 +334,15 @@ def teq_quantize_entry(
 ###################### AUTOROUND Algo Entry ##################################
 @register_algo(name=AUTOROUND)
 def autoround_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs
 ) -> torch.nn.Module:
-    from neural_compressor.torch.algorithms.weight_only.autoround import autoround_quantize
+    from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer
 
     logger.info("Quantize model with the AutoRound algorithm.")
-    calib_func = kwargs.get("run_fn", None)
     weight_config = {}
     for (op_name, op_type), quant_config in configs_mapping.items():
         if quant_config.name != AUTOROUND or quant_config.dtype == "fp32":
@@ -371,9 +374,8 @@ def autoround_quantize_entry(
         scale_dtype = quant_config.scale_dtype
 
     kwargs.pop("example_inputs")
-    kwargs.pop("mode")  # TODO: will be removed after auto_round refactoring
-    model, autoround_config = autoround_quantize(
-        model=model,
+
+    quantizer = AutoRoundQuantizer(
         weight_config=weight_config,
         enable_full_range=enable_full_range,
         batch_size=batch_size,
@@ -393,9 +395,8 @@ def autoround_quantize_entry(
         not_use_best_mse=not_use_best_mse,
         dynamic_max_gap=dynamic_max_gap,
         scale_dtype=scale_dtype,
-        **kwargs
     )
-    model.autoround_config = autoround_config
+    model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     logger.info("AutoRound quantization done.")
     return model
 
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index cd147ec62c5..7c96490292f 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -1,8 +1,8 @@
-import unittest
-
+import pytest
 import torch
 import transformers
 
+from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
 from neural_compressor.torch.quantization import AutoRoundConfig, quantize
 from neural_compressor.torch.utils import logger
 
@@ -14,7 +14,8 @@
     auto_round_installed = False
 
 
-def get_gpt_j():
+@pytest.fixture(scope="module")
+def gpt_j():
     tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
         "hf-internal-testing/tiny-random-GPTJForCausalLM",
         torchscript=True,
@@ -22,34 +23,29 @@ def get_gpt_j():
     return tiny_gptj
 
 
-@unittest.skipIf(not auto_round_installed, "auto_round module is not installed")
-class TestAutoRound(unittest.TestCase):
-    @classmethod
-    def setUpClass(self):
-        self.gptj = get_gpt_j()
-
-    @classmethod
-    def tearDownClass(self):
-        pass
+@pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
+class TestAutoRound:
+    @staticmethod
+    @pytest.fixture(scope="class", autouse=True)
+    def gpt_j_model(gpt_j):
+        yield gpt_j
 
-    def setUp(self):
-        # print the test name
-        logger.info(f"Running TestAutoRound test: {self.id()}")
+    def setup_method(self, method):
+        logger.info(f"Running TestAutoRound test: {method.__name__}")
 
-    def test_autoround(self):
+    def test_autoround(self, gpt_j_model):
         inp = torch.ones([1, 10], dtype=torch.long)
 
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
 
-        out1 = self.gptj(inp)
+        out1 = gpt_j_model(inp)
 
         quant_config = AutoRoundConfig(n_samples=20, seqlen=10, iters=10, scale_dtype="fp32")
         logger.info(f"Test AutoRound with config {quant_config}")
 
-        from neural_compressor.torch.algorithms.weight_only.autoround import get_autoround_default_run_fn
-
         qdq_model = quantize(
-            model=self.gptj,
+            model=gpt_j_model,
             quant_config=quant_config,
             run_fn=get_autoround_default_run_fn,
             run_args=(
@@ -70,16 +66,47 @@ def test_autoround(self):
             dataloader=None,
         """
 
-        out2 = qdq_model(inp)
-        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-1))
-
         q_model = qdq_model
         out2 = q_model(inp)
-        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
-        self.assertTrue("transformer.h.0.attn.k_proj" in q_model.autoround_config.keys())
-        self.assertTrue("scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys())
-        self.assertTrue(torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"])
+        assert torch.allclose(out1[0], out2[0], atol=1e-1)
+        assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
+        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
+        assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
 
+    def test_new_api(self, gpt_j_model):
+        inp = torch.ones([1, 10], dtype=torch.long)
 
-if __name__ == "__main__":
-    unittest.main()
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
+        )
+
+        out1 = gpt_j_model(inp)
+
+        run_fn = get_autoround_default_run_fn
+        run_args = (
+            tokenizer,
+            "NeelNanda/pile-10k",
+            20,
+            10,
+        )
+        weight_config = {
+            "*": {
+                "data_type": "int",
+                "bits": 4,
+                "group_size": 32,
+                "sym": False,
+            }
+        }
+        quantizer = AutoRoundQuantizer(weight_config=weight_config)
+        fp32_model = gpt_j_model
+
+        # quantizer execute
+        model = quantizer.prepare(model=fp32_model)
+        run_fn(model, *run_args)
+        q_model = quantizer.convert(model)
+
+        out2 = q_model(inp)
+        assert torch.allclose(out1[0], out2[0], atol=1e-1)
+        assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
+        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
"scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys() + assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]