From bfa27e422dc4760f6a9b1783eee7dae10fe5324f Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Wed, 17 Jul 2024 09:33:13 +0800 Subject: [PATCH] Integrate AutoRound v0.3 (#1925) Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/autoround.py | 134 +++++++++++------- .../torch/quantization/algorithm_entry.py | 16 ++- .../torch/quantization/config.py | 36 +++-- .../weight_only/test_autoround.py | 10 +- test/3x/torch/requirements.txt | 2 +- 5 files changed, 130 insertions(+), 68 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 2e97533c0bb..6f5a022cfee 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -31,69 +31,95 @@ class AutoRoundQuantizer(Quantizer): def __init__( self, quant_config: dict = {}, - enable_full_range: bool = False, + enable_full_range: bool = False, ##for symmetric, TODO support later batch_size: int = 8, amp: bool = True, - device=None, + device: str = None, lr_scheduler=None, + dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", enable_quanted_input: bool = True, enable_minmax_tuning: bool = True, lr: float = None, minmax_lr: float = None, - low_gpu_mem_usage: bool = True, + low_gpu_mem_usage: bool = False, iters: int = 200, seqlen: int = 2048, - n_samples: int = 512, + nsamples: int = 128, sampler: str = "rand", seed: int = 42, - n_blocks: int = 1, + nblocks: int = 1, gradient_accumulate_steps: int = 1, not_use_best_mse: bool = False, dynamic_max_gap: int = -1, data_type: str = "int", scale_dtype: str = "fp16", + multimodal: bool = False, + act_bits: int = 32, + act_group_size: int = None, + act_sym: bool = None, + act_dynamic: bool = True, + low_cpu_mem_usage: bool = False, **kwargs, ): """Init a AutQRoundQuantizer object. Args: - quant_config (dict): Configuration for weight quantization (default is None). - quant_config={ - 'layer1':##layer_name - { - 'data_type': 'int', - 'bits': 4, - 'group_size': 32, - 'sym': False, + quant_config (dict): Configuration for weight quantization (default is None). + quant_config={ + 'layer1':##layer_name + { + 'data_type': 'int', + 'bits': 4, + 'group_size': 32, + 'sym': False, + 'act_data_type': None, + 'act_bits': 32, + 'act_sym': None, + 'act_dynamic': True, + } + ..., } - ... - } - keys: - data_type (str): The data type to be used (default is "int"). - bits (int): Number of bits for quantization (default is 4). - group_size (int): Size of the quantization group (default is 128). - sym (bool): Whether to use symmetric quantization. (default is None). - enable_full_range (bool): Whether to enable full range quantization (default is False). - batch_size (int): Batch size for training (default is 8). - amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set. - device: The device to be used for tuning (default is None). Automatically detect and set. - lr_scheduler: The learning rate scheduler to be used. - use_quant_input (bool): Whether to use quantized input data (default is True). - enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). - lr (float): The learning rate (default is 0.005). - minmax_lr (float): The learning rate for min-max tuning (default is None). - low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True). - iters (int): Number of iterations (default is 200). 
- seqlen (int): Length of the sequence.
- n_samples (int): Number of samples (default is 512).
- sampler (str): The sampling method (default is "rand").
- seed (int): The random seed (default is 42).
- n_blocks (int): Number of blocks (default is 1).
- gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
- not_use_best_mse (bool): Whether to use mean squared error (default is False).
- dynamic_max_gap (int): The dynamic maximum gap (default is -1).
- scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
- have different choices.
+ keys:
+ data_type (str): The data type to be used (default is "int").
+ bits (int): Number of bits for quantization (default is 4).
+ group_size (int): Size of the quantization group (default is 128).
+ sym (bool): Whether to use symmetric quantization (default is False).
+ enable_full_range (bool): Whether to enable full range quantization (default is False).
+ batch_size (int): Batch size for training (default is 8).
+ amp (bool): Whether to use automatic mixed precision (default is True).
+ device: The device to be used for tuning (default is "auto").
+ lr_scheduler: The learning rate scheduler to be used.
+ dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
+ enable_quanted_input (bool): Whether to use the output of the previous quantized block as
+ the input for the current block (default is True).
+ enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
+ lr (float): The learning rate (default is None, will be set to 1.0/iters).
+ minmax_lr (float): The learning rate for min-max tuning
+ (default is None, it will be set to lr automatically).
+ low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
+ iters (int): Number of iterations (default is 200).
+ seqlen (int): Data length of the sequence for tuning (default is 2048).
+ nsamples (int): Number of samples (default is 128).
+ sampler (str): The sampling method (default is "rand").
+ seed (int): The random seed (default is 42).
+ nblocks (int): Number of blocks (default is 1).
+ gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+ not_use_best_mse (bool): Whether to use mean squared error (default is False).
+ dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+ data_type (str): The data type to be used (default is "int").
+ scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
+ have different choices.
+ multimodal (bool): Whether to enable multimodal model quantization (default is False).
+ act_bits (int): Number of bits for activation quantization. Default is 32.
+ act_group_size (int): Group size for activation quantization. Default is None.
+ act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+ act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
+ low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
""" super().__init__(quant_config) self.tokenizer = None @@ -109,15 +135,21 @@ def __init__( self.low_gpu_mem_usage = low_gpu_mem_usage self.iters = iters self.seqlen = seqlen - self.n_samples = n_samples + self.nsamples = nsamples self.sampler = sampler self.seed = seed - self.n_blocks = n_blocks + self.nblocks = nblocks self.gradient_accumulate_steps = gradient_accumulate_steps self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap self.data_type = data_type self.scale_dtype = scale_dtype + self.multimodal = multimodal + self.act_bits = act_bits + self.act_group_size = act_group_size + self.act_sym = act_sym + self.act_dynamic = act_dynamic + self.low_cpu_mem_usage = low_cpu_mem_usage def prepare(self, model: torch.nn.Module, *args, **kwargs): """Prepares a given model for quantization. @@ -137,7 +169,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model=model, tokenizer=None, dataset=dataloader, - weight_config=self.quant_config or {}, + layer_config=self.quant_config or {}, enable_full_range=self.enable_full_range, batch_size=self.batch_size, amp=self.amp, @@ -150,15 +182,21 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): low_gpu_mem_usage=self.low_gpu_mem_usage, iters=self.iters, seqlen=self.seqlen, - n_samples=self.n_samples, + nsamples=self.nsamples, sampler=self.sampler, seed=self.seed, - n_blocks=self.n_blocks, + nblocks=self.nblocks, gradient_accumulate_steps=self.gradient_accumulate_steps, not_use_best_mse=self.not_use_best_mse, dynamic_max_gap=self.dynamic_max_gap, data_type=self.data_type, scale_dtype=self.scale_dtype, + multimodal=self.multimodal, + act_bits=self.act_bits, + act_group_size=self.act_group_size, + act_sym=self.act_sym, + act_dynamic=self.act_dynamic, + low_cpu_mem_usage=self.low_cpu_mem_usage, ) model, weight_config = rounder.quantize() model.autoround_config = weight_config @@ -166,7 +204,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): return model -def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512): +def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128): """Generate a DataLoader for calibration using specified parameters. 
Args: @@ -186,6 +224,6 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42 from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401 dataloader = get_dataloader( - tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples + tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples ) return dataloader diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index e582363126c..0f65633c58e 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -572,6 +572,10 @@ def autoround_quantize_entry( "bits": quant_config.bits, "sym": quant_config.use_sym, "group_size": quant_config.group_size, + "act_bits": quant_config.act_bits, + "act_group_size": quant_config.act_group_size, + "act_sym": quant_config.act_sym, + "act_dynamic": quant_config.act_dynamic, } enable_full_range = quant_config.enable_full_range batch_size = quant_config.batch_size @@ -583,14 +587,16 @@ def autoround_quantize_entry( low_gpu_mem_usage = quant_config.low_gpu_mem_usage iters = quant_config.iters seqlen = quant_config.seqlen - n_samples = quant_config.n_samples + nsamples = quant_config.nsamples sampler = quant_config.sampler seed = quant_config.seed - n_blocks = quant_config.n_blocks + nblocks = quant_config.nblocks gradient_accumulate_steps = quant_config.gradient_accumulate_steps not_use_best_mse = quant_config.not_use_best_mse dynamic_max_gap = quant_config.dynamic_max_gap scale_dtype = quant_config.scale_dtype + multimodal = quant_config.multimodal + low_cpu_mem_usage = quant_config.use_layer_wise kwargs.pop("example_inputs") @@ -608,14 +614,16 @@ def autoround_quantize_entry( low_gpu_mem_usage=low_gpu_mem_usage, iters=iters, seqlen=seqlen, - n_samples=n_samples, + nsamples=nsamples, sampler=sampler, seed=seed, - n_blocks=n_blocks, + nblocks=nblocks, gradient_accumulate_steps=gradient_accumulate_steps, not_use_best_mse=not_use_best_mse, dynamic_max_gap=dynamic_max_gap, scale_dtype=scale_dtype, + multimodal=multimodal, + low_cpu_mem_usage=low_cpu_mem_usage, ) model = quantizer.execute(model=model, mode=mode, *args, **kwargs) model.qconfig = configs_mapping diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 047fb8a2c13..49238ec2ee5 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -735,8 +735,8 @@ class AutoRoundConfig(TorchBaseConfig): "minmax_lr", "iters", "seqlen", - "n_samples", - "n_blocks", + "nsamples", + "nblocks", "gradient_accumulate_steps", "not_use_best_mse", "dynamic_max_gap", @@ -750,6 +750,10 @@ def __init__( use_sym: bool = False, group_size: int = 128, # AUTOROUND + act_bits: int = 32, + act_group_size: int = None, + act_sym: bool = None, + act_dynamic: bool = True, enable_full_range: bool = False, batch_size: int = 8, lr_scheduler=None, @@ -759,16 +763,17 @@ def __init__( minmax_lr: float = None, low_gpu_mem_usage: bool = True, iters: int = 200, - seqlen: int = 512, - n_samples: int = 512, + seqlen: int = 2048, + nsamples: int = 128, sampler: str = "rand", seed: int = 42, - n_blocks: int = 1, + nblocks: int = 1, gradient_accumulate_steps: int = 1, not_use_best_mse: bool = False, dynamic_max_gap: int = -1, scale_dtype: str = "fp16", use_layer_wise: bool = False, + multimodal: bool = False, white_list: 
Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init AUTOROUND weight-only quantization config.
@@ -778,6 +783,10 @@ def __init__(
bits (int): Number of bits used to represent weights, default is 4.
use_sym (bool): Indicates whether weights are symmetric, default is False.
group_size (int): Size of weight groups, default is 128.
+ act_bits (int): Number of bits for activation quantization. Default is 32.
+ act_group_size (int): Group size for activation quantization. Default is None.
+ act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+ act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
lr_scheduler: The learning rate scheduler to be used.
@@ -788,21 +797,27 @@ def __init__(
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
- n_samples (int): Number of samples (default is 512).
+ nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
- n_blocks (int): Number of blocks (default is 1).
+ nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
+ use_layer_wise (bool): Whether to quantize the model layer by layer (default is False).
+ multimodal (bool): Whether to enable multimodal model quantization (default is False).
"""
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
self.use_sym = use_sym
self.group_size = group_size
+ self.act_bits = act_bits
+ self.act_group_size = act_group_size
+ self.act_sym = act_sym
+ self.act_dynamic = act_dynamic
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.lr_scheduler = lr_scheduler
@@ -813,15 +828,16 @@ def __init__(
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
- self.n_samples = n_samples
+ self.nsamples = nsamples
self.sampler = sampler
self.seed = seed
- self.n_blocks = n_blocks
+ self.nblocks = nblocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.scale_dtype = scale_dtype
self.use_layer_wise = use_layer_wise
+ self.multimodal = multimodal
self._post_init()
@classmethod
@@ -1526,7 +1542,7 @@ def get_woq_tuning_config() -> list:
the list of WOQ quant config.
""" RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32) - AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32) + AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32, seqlen=512) GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32) AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32) return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM] diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index f1539b072b7..f5351656595 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -49,7 +49,7 @@ def setup_class(self): tokenizer = transformers.AutoTokenizer.from_pretrained( "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True ) - self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=10) + self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10) self.label = self.gptj(self.inp)[0] def teardown_class(self): @@ -61,7 +61,7 @@ def setup_method(self, method): @pytest.mark.parametrize("quant_lm_head", [True, False]) def test_autoround(self, quant_lm_head): fp32_model = copy.deepcopy(self.gptj) - quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32") + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32") if quant_lm_head is False: quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) logger.info(f"Test AutoRound with config {quant_config}") @@ -83,7 +83,7 @@ def test_autoround(self, quant_lm_head): def test_autoround_with_quantize_API(self): gpt_j_model = copy.deepcopy(self.gptj) - quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32") + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32") quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) logger.info(f"Test AutoRound with config {quant_config}") @@ -101,7 +101,7 @@ def test_autoround_with_quantize_API(self): def test_save_and_load(self): fp32_model = copy.deepcopy(self.gptj) - quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32") + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32") # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) logger.info(f"Test AutoRound with config {quant_config}") @@ -133,7 +133,7 @@ def test_conv1d(self): text = "Replace me by any text you'd like." encoded_input = tokenizer(text, return_tensors="pt") out1 = model(**encoded_input)[0] - quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32") + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32") model = prepare(model=model, quant_config=quant_config) run_fn(model, self.dataloader) q_model = convert(model) diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index bdf99d92cf0..cc1ee22fe83 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -1,4 +1,4 @@ -auto_round +auto_round @ git+https://github.com/intel/auto-round.git@24b2e74070f2b4e6f26ff069ec75af74cf5b177c expecttest intel_extension_for_pytorch numpy