From 501440ab560056e2e3a1a75c922361ebf614fc04 Mon Sep 17 00:00:00 2001 From: wenhuach21 <108330088+wenhuach21@users.noreply.github.com> Date: Thu, 29 Jun 2023 13:58:27 +0800 Subject: [PATCH] PT supports weight only (#1009) * add weight only naive code Signed-off-by: wenhuach21 * fix 1 bit issue Signed-off-by: wenhuach21 * fix bug Signed-off-by: wenhuach21 * align with torch * fix some issues * add weight_only approach Signed-off-by: Xin He * add stats dump API Signed-off-by: Xin He * add save API Signed-off-by: Xin He * add quant type Signed-off-by: yiliu30 * add init value for weight only Signed-off-by: yiliu30 * add debug mode for weight_config Signed-off-by: Xin He * fix bug Signed-off-by: Xin He * fix bug Signed-off-by: Xin He * add group_size=-1 and enhance UT Signed-off-by: Xin He * add doc, WIP * fix typos Signed-off-by: wenhuach21 * fix spelling Signed-off-by: Xin He * support no-list config, change bit to bits and doc Signed-off-by: Xin He * fix doc Signed-off-by: Xin He * fix doc Signed-off-by: Xin He * fix typo Signed-off-by: Xin He * fix docstring Signed-off-by: Xin He * fix load bug Signed-off-by: Xin He * add docsting; update doc Signed-off-by: wenhuach21 * add readme link Signed-off-by: Xin He --------- Signed-off-by: wenhuach21 Signed-off-by: Xin He Signed-off-by: yiliu30 Co-authored-by: Xin He Co-authored-by: yiliu30 --- .../scripts/codeScan/pyspelling/inc_dict.txt | 9 +- README.md | 3 + docs/source/quantization_weight_only.md | 72 ++++++ neural_compressor/adaptor/pytorch.py | 205 +++++++++++++++++- neural_compressor/adaptor/pytorch_cpu.yaml | 17 ++ neural_compressor/adaptor/torch_utils/util.py | 45 +++- .../adaptor/torch_utils/weight_only.py | 130 +++++++++++ neural_compressor/config.py | 68 +++++- neural_compressor/model/torch_model.py | 10 +- neural_compressor/strategy/strategy.py | 15 +- neural_compressor/strategy/utils/constant.py | 4 +- .../strategy/utils/tuning_space.py | 5 +- neural_compressor/strategy/utils/utility.py | 65 +++++- neural_compressor/utils/pytorch.py | 8 + .../test_weight_only_adaptor.py | 142 ++++++++++++ .../test_weight_only_quantization.py | 41 ++++ 16 files changed, 813 insertions(+), 26 deletions(-) create mode 100644 docs/source/quantization_weight_only.md create mode 100644 neural_compressor/adaptor/torch_utils/weight_only.py create mode 100644 test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py create mode 100644 test/quantization/test_weight_only_quantization.py diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt index 551771fdd28..1e464325a94 100644 --- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt +++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt @@ -2650,4 +2650,11 @@ CCE CCFF FFFFFF classDef -bdf \ No newline at end of file +bdf +bmm +AWQ +GPTQ +RTN +awq +gptq +percdamp \ No newline at end of file diff --git a/README.md b/README.md index a10884ed239..aca460e916a 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,9 @@ q_model = fit( Distillation for Quantization SmoothQuant + + Weight-Only Quantization + diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md new file mode 100644 index 00000000000..24c6d91b2a3 --- /dev/null +++ b/docs/source/quantization_weight_only.md @@ -0,0 +1,72 @@ +Weight Only Quantization +===== + +1. [Introduction](#introduction) + +2. [Supported Framework Model Matrix](#supported-framework-model-matrix) + +3. 
[Examples](#examples) + + +## Introduction + +As large language models (LLMs) become more prevalent, there is a growing need for new and improved quantization methods that can meet the computational demands of these modern architectures while maintaining the accuracy. Compared to normal quantization like W8A8, weight only quantization is probably a better trade-off to balance the performance and the accuracy, since we will see below that the bottleneck of deploying LLMs is the memory bandwidth and normally weight only quantization could lead to better accuracy. + +Model inference: Roughly speaking , two key steps are required to get the model's result. The first one is moving the model from the memory to the cache piece by piece, in which, memory bandwidth $B$ and parameter count $P$ are the key factors, theoretically the time cost is $P*4 /B$. The second one is computation, in which, the device's computation capacity $C$ measured in FLOPS and the forward FLOPs $F$ play the key roles, theoretically the cost is $F/C$. + +Text generation: The most famous application of LLMs is text generation, which predicts the next token/word based on the inputs/context. To generate a sequence of texts, we need to predict them one by one. In this scenario, $F\approx P$ if some operations like bmm are ignored and past key values have been saved. However, the $C/B$ of the modern device could be to **100X,** that makes the memory bandwidth as the bottleneck in this scenario. + +Besides, as mentioned in many papers[1][2], activation quantization is the main reason to cause the accuracy drop. So for text generation task, weight only quantization is a preferred option in most cases. + + +## Supported Framework Model Matrix + +| Framework | Weight-only | +| :---: | :---:| +| PyTorch | ✔ | +| ONNX | WIP | + + +## Examples + +The quantization capability of weight-only approach is as follows: +| Config | Capability | +| :---: | :---:| +| bits | [1-8] | +| group_size | [-1, 1-N] | +| scheme | ['asym', 'sym'] | +| algorithm | ['RTN', ] | + +**Note**: `group_size=-1` indicates the per-channel quantization per output channel. `group_size=[1-N]` indicates splitting the input channel elements per group_size. + +The use case code is as follows: +```python +conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bit': 8, # 1-8 bit + 'group_size': -1, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'RTN', + }, + }, + }, + ### AWQ and GPTQ is WIP + # recipes={ + # 'gptq_args':{'percdamp': 0.01}, + # 'awq_args':{'alpha': 'auto', 'clip': True}, + # }, +) +q_model = quantization.fit(model, conf, eval_func=eval_func) +q_model.save('saved_results') +``` + +The saved_results folder contains two files: `best_model.pt` and `weight_config.json`, and the generated q_model is a fake quantized model. + +## Reference + +[1]Xiao, Guangxuan, et al. "Smoothquant: Accurate and efficient post-training quantization for large language models." arXiv preprint arXiv:2211.10438 (2022). + +[2]Wei, Xiuying, et al. "Outlier suppression: Pushing the limit of low-bit transformer language models." arXiv preprint arXiv:2209.13325 (2022). 
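To make the bandwidth argument in the introduction concrete, the short sketch below plugs illustrative numbers into the $P*4/B$ and $F/C$ estimates; the parameter count, bandwidth, and FLOPS values are assumptions chosen only for illustration, not measurements of any particular device.

```python
# Back-of-the-envelope cost of generating one token, using the estimates above.
# All hardware numbers are illustrative assumptions.
P = 7e9      # parameter count (e.g. a 7B LLM)
B = 100e9    # memory bandwidth in bytes/s (~100 GB/s, assumed)
C = 50e12    # compute capacity in FLOPS (~50 TFLOPS, assumed)
F = P        # forward FLOPs per token, F ~= P as stated in the introduction

weight_load_time = P * 4 / B   # fp32 weights, 4 bytes per parameter -> ~0.28 s
compute_time = F / C           # -> ~0.00014 s

print(f"weight transfer: {weight_load_time:.3f} s, compute: {compute_time:.5f} s")
# Transfer dominates by roughly 2000x here, so shrinking the weights (e.g. to
# 4 bits, ~8x fewer bytes moved) attacks the actual bottleneck, which is what
# weight-only quantization does.
```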
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 499f56a9c02..9bce77d3a83 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -832,6 +832,8 @@ def __init__(self, framework_specific_info): else: self.q_mapping = \ tq.quantization_mappings.get_default_dynamic_quant_module_mappings() + elif framework_specific_info['approach'] == "post_training_weight_only": + pass else: if not self.benchmark: assert False, "Unsupport approach: {}".format(self.approach) @@ -1107,7 +1109,28 @@ def _get_quantizable_ops(self, model): q_capability['opwise'][q_op].append(fp32_config) if fp32_config not in q_capability['optypewise'][q_op[1]]: q_capability['optypewise'][q_op[1]].append(fp32_config) + elif self.approach == "post_training_weight_only": + capability_pair = [(self.query_handler.get_quantization_capability('weight_only_integer'), 'weight_only')] + fp32_config = {'activation': {'dtype': 'fp32'}, 'weight': {'dtype': 'fp32'}} + for pair in capability_pair: + capability, mode = pair + for q_op in quantizable_ops: + if q_op not in q_capability['opwise']: + q_capability['opwise'][q_op] = [] + if q_op[1] not in q_capability['optypewise']: + q_capability['optypewise'][q_op[1]] = [] + op_cfg = copy.deepcopy(capability[q_op[1]]) if q_op[1] in capability \ + else copy.deepcopy(capability['default']) + op_cfg['activation']['quant_mode'] = mode + if op_cfg not in q_capability['opwise'][q_op]: + q_capability['opwise'][q_op].append(op_cfg) + q_capability['opwise'][q_op].append(fp32_config) + if op_cfg not in q_capability['optypewise'][q_op[1]]: + q_capability['optypewise'][q_op[1]].append(op_cfg) + q_capability['optypewise'][q_op[1]].append(fp32_config) else: + if 'weight_only_integer' in quant_datatypes: # TODO: need to enhance + quant_datatypes.remove('weight_only_integer') for datatype in quant_datatypes: if self.approach == "post_training_dynamic_quant": capability_pair = [ @@ -4434,6 +4457,186 @@ def calculate_op_sensitivity(self, model, dataloader, tune_cfg, output_op_names, return ordered_ops +@adaptor_registry +class PyTorchWeightOnlyAdaptor(TemplateAdaptor): + """Adaptor of PyTorch framework, all PyTorch API is in this class. + + Args: + framework_specific_info (dict): dictionary of tuning configure from yaml file. + """ + def __init__(self, framework_specific_info): + super(PyTorchWeightOnlyAdaptor, self).__init__(framework_specific_info) + self.tune_cfg = None + if self.device == "cpu": + query_config_file = "pytorch_cpu.yaml" + else: # pragma: no cover + assert False, "Unsupport this device {}".format(self.device) + self.query_handler = PyTorchQuery( + local_config_file=os.path.join(os.path.dirname(__file__), query_config_file)) + + self.white_list = [torch.nn.Linear, torch.nn.Conv2d] + # Contains parameters for algorithms such as AWQ, GPTQ, etc. + self.recipes = framework_specific_info['recipes'] + self.optype_statistics = None + + @dump_elapsed_time("Pass quantize model") + def quantize(self, tune_cfg, model, dataloader, q_func=None): + """Execute the quantize process on the specified model. + + Args: + tune_cfg (dict): quantization config. + model (object): model need to do quantization. + dataloader (object): calibration dataset. + q_func (objext, optional): training function for quantization aware training mode. 
+ + Returns: + (object): quantized model + """ + assert isinstance(model._model, torch.nn.Module), \ + "The model passed in is not the instance of torch.nn.Module" + if self.performance_only: + q_model = model + else: + try: + q_model = copy.deepcopy(model) + except Exception as e: # pragma: no cover + logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format( + repr(e))) + q_model = model + + # For tensorboard display + self.tune_cfg = tune_cfg + self.tune_cfg["approach"] = self.approach + self.tune_cfg["framework"] = "pytorch" + assert self.approach=='post_training_weight_only', "Please make sure the approach is weight_only" + + q_model._model = self.rtn_quantize(q_model._model, tune_cfg) + q_model._model = self.gptq_quantize(q_model._model, tune_cfg, dataloader) + q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader) + + q_model.q_config = copy.deepcopy(self.tune_cfg) + q_model.is_quantized = True + self._dump_model_op_stats(q_model._model, q_model.q_config) + return q_model + + def rtn_quantize(self, model, tune_cfg): + logger.debug("quantizing with the round-to-nearest algorithm") + from .torch_utils.weight_only import rtn_quantize + from .torch_utils.util import fetch_module + for key, config in tune_cfg['op'].items(): + op_name, op_type = key + if config['weight']['dtype'] == 'fp32': + continue + else: + num_bits = config['weight']['bits'] + group_size = config['weight']['group_size'] + scheme = config['weight']['scheme'] + algorithm = config['weight']['algorithm'] + if algorithm != 'RTN': + continue + m = fetch_module(model, op_name) + rtn_quantize(m, num_bits, group_size, scheme) + return model + + def gptq_quantize(self, model, tune_cfg, dataloader): + logger.debug("quantizing with the GPTQ algorithm") + if 'gptq_args' in self.recipes: + percdamp = self.recipes['gptq_args'].get('percdamp', 0.01) + # GPTQ(model, dataloader, w_bit, group_size, percdamp=0.01) + # TODO: implementation + return model + + def awq_quantize(self, model, tune_cfg, dataloader): + logger.debug("quantizing with the AWQ algorithm") + # set default value if has args in recipes, else we use function + if 'awq_args' in self.recipes: + alpha = self.recipes['awq_args'].get('alpha', 'auto') + # AWQ(model, dataloader, w_bit, group_size, alpha='auto', clip=True) + # TODO: implementation + return model + + def _dump_model_op_stats(self, model, tune_cfg): + """This is a function to dump quantizable ops of model to user. + Args: + model (object): input model + tune_cfg (dict): quantization config + Returns: + None + """ + res = {} + # collect all dtype info and build empty results with existing op_type + dtype_set = set() + for op, config in tune_cfg['op'].items(): + op_type = op[1] + if not config['weight']['dtype'] == 'fp32': + num_bits = config['weight']['bits'] + group_size = config['weight']['group_size'] + dtype_str = "A32W{}G{}".format(num_bits, group_size) + dtype_set.add(dtype_str) + dtype_set.add('FP32') + dtype_list = list(dtype_set) + dtype_list.sort() + for op, config in tune_cfg['op'].items(): + op_type = op[1] + if op_type not in res.keys(): + res[op_type] = {dtype: 0 for dtype in dtype_list} + + # fill in results with op_type and dtype + for op, config in tune_cfg['op'].items(): + if config['weight']['dtype'] == 'fp32': + res[op_type]['FP32'] += 1 + else: + num_bits = config['weight']['bits'] + group_size = config['weight']['group_size'] + dtype_str = "A32W{}G{}".format(num_bits, group_size) + res[op_type][dtype_str] += 1 + + # update stats format for dump. 
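+        # Column labels encode the weight config, e.g. bits=4, group_size=32 -> 'A32W4G32'
+        # (fp32 activations, 4-bit weights, group size 32); the 'FP32' column counts ops left unquantized.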
+ field_names = ["Op Type", "Total"] + field_names.extend(dtype_list) + output_data = [] + for op_type in res.keys(): + field_results = [op_type, sum(res[op_type].values())] + field_results.extend([res[op_type][dtype] for dtype in dtype_list]) + output_data.append(field_results) + + Statistics(output_data, + header='Mixed Precision Statistics', + field_names=field_names).print_stat() + self.optype_statistics = field_names, output_data + + + def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops): + """This is a helper function for `query_fw_capability`, + and it will get all quantizable ops from model. + + Args: + model (object): input model + prefix (string): prefix of op name + quantizable_ops (list): list of quantizable ops from model include op name and type. + + Returns: + None + """ + + module_dict = dict(model.named_modules()) + for op_name, child in module_dict.items(): + if type(child) in self.white_list: + quantizable_ops.append((op_name, str(child.__class__.__name__))) + + @dump_elapsed_time("Pass query framework capability") + def query_fw_capability(self, model): + """This is a helper function to get all quantizable ops from model. + + Args: + model (object): input model which is Neural Compressor model + + Returns: + q_capability (dictionary): tuning capability for each op from model. + """ + self.pre_optimized_model = model + return self._get_quantizable_ops(model.model) + class PyTorchQuery(QueryBackendCapability): def __init__(self, local_config_file=None): super().__init__() @@ -4500,7 +4703,7 @@ def get_quant_datatypes(self): # TODO to handle other data types such FP8, FP8E4M3 datatype_lst = [] for key in self.cur_config: - if key.startswith('int'): + if key.startswith('int') or key == 'weight_only_integer': datatype_lst.append(key) return datatype_lst diff --git a/neural_compressor/adaptor/pytorch_cpu.yaml b/neural_compressor/adaptor/pytorch_cpu.yaml index f544aaf88b5..d62bd63d44c 100644 --- a/neural_compressor/adaptor/pytorch_cpu.yaml +++ b/neural_compressor/adaptor/pytorch_cpu.yaml @@ -259,6 +259,23 @@ }, } + weight_only_integer: &cap_weight_only_integer { + 'Linear': &cap_weight_only_integer_linear { # only Linear now + 'weight': { + 'dtype': ['int'], # no need to care uint + 'bits': [4, 1, 2, 3, 5, 6, 7, 8], # [1-8], # 4 + # group_size=-1 means per-channel, others means per-group + 'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32 + 'scheme': ['sym', 'asym'], # sym, no ZP + 'algorithm': ['RTN'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order + }, + 'activation': { + 'dtype': ['fp32'], + }, + }, + 'Conv2d': *cap_weight_only_integer_linear, + } + - version: diff --git a/neural_compressor/adaptor/torch_utils/util.py b/neural_compressor/adaptor/torch_utils/util.py index 0b0b1a62597..59f1be62b65 100644 --- a/neural_compressor/adaptor/torch_utils/util.py +++ b/neural_compressor/adaptor/torch_utils/util.py @@ -915,4 +915,47 @@ def get_op_type_by_name(op_name, quantizable_ops): for pair in quantizable_ops: if pair[0] == op_name: return pair[1] - return None \ No newline at end of file + return None + +def collect_weight_info(q_config): + """collect weight info from q_config for dumping into weight_config.json + + weight_config.json example: + ``` + { + 'fc': { + 'bits': 4, + 'group_size': 128, + 'scheme': 'asym', + 'algorithm': 'RTN' + } + ... 
+ } + ``` + + Args: + q_config (_type_): quantization configue + """ + weight_info = {} + from neural_compressor.utils.logger import level, DEBUG + for op, config in q_config['op'].items(): + op_name, op_type = op + if config['weight']['dtype'] == 'fp32': + weight_info[op_name] = {'dtype': 'fp32'} + else: + if level == DEBUG: + weight_info[op_name] = { + 'dtype': config['weight']['dtype'], + 'bits': config['weight']['bits'], + 'group_size': config['weight']['group_size'], + 'scheme': config['weight']['scheme'], + 'algorithm': config['weight']['algorithm'] + } + else: + weight_info[op_name] = { + 'dtype': config['weight']['dtype'], + 'bits': config['weight']['bits'], + 'group_size': config['weight']['group_size'], + 'scheme': config['weight']['scheme'], + } + return weight_info \ No newline at end of file diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py new file mode 100644 index 00000000000..cd802ace957 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -0,0 +1,130 @@ +from ...utils import logger +from ...utils.utility import LazyImport + +tqdm = LazyImport("tqdm") +torch = LazyImport("torch") + + +def qdq_weight_asym(weight, num_bits=4): + """quant and dequant tensor with asym schema + :param weight: input weight + :param num_bits: num_bits + :return: qdq weight + """ + maxq = torch.tensor(2 ** num_bits - 1) + zeros = torch.zeros(weight.shape[0], device=weight.device) + wmin = torch.minimum(weight.min(1)[0], zeros) + wmax = torch.maximum(weight.max(1)[0], zeros) + tmp = (wmin == 0) & (wmax == 0) + wmin[tmp] = -1 + wmax[tmp] = +1 + scale = (wmax - wmin) / maxq + zp = torch.round(-wmin / scale) + scale.unsqueeze_(dim=-1) + zp.unsqueeze_(dim=-1) + q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq) + return scale * (q - zp) + + +def qdq_weight_sym(weight, num_bits=4): + """quant and dequant tensor with sym schema + :param weight: input weight + :param num_bits: num_bits + :return: qdq weight + """ + # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" + maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) + minq = torch.tensor(-2 ** (num_bits - 1)).to(weight.device) + if num_bits == 1: + maxq = torch.tensor(2 ** (num_bits - 1)) + minq = torch.tensor(2 ** (num_bits - 1) - 1) + + wmax = torch.abs(weight).max(1)[0] + tmp = (wmax == 0) + wmax[tmp] = +1 + scale = wmax / ((maxq - minq) / 2) + scale.unsqueeze_(dim=-1) + q = torch.clamp(torch.round(weight / scale), minq, maxq) + return scale * q + + +def qdq_weight_actor(weight, num_bits, scheme): + """quant and dequant tensor per channel + :param weight: input weight + :param num_bits: num_bits + :param scheme: sym or asym + :return: qdq weight + """ + assert num_bits > 0, "num_bits should be larger than 0" + if scheme == "sym": + return qdq_weight_sym(weight, num_bits) + else: + return qdq_weight_asym(weight, num_bits) + +def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym"): + """quant and dequant tensor with group size + :param weight: input weight + :param num_bits: num_bits + :param group_size: how many elements share one scale/zp + :param scheme: sym or asym + :return: qdq weight + """ + if group_size == -1 or weight.shape[1] < group_size: + return qdq_weight_actor(weight, num_bits, scheme=scheme) + + orig_shape = weight.shape + if weight.shape[1] % group_size == 0: + weight = weight.reshape(-1, group_size) + weight = qdq_weight_actor(weight, num_bits, scheme=scheme) + weight = 
weight.reshape(orig_shape) + return weight + else: + split_index = weight.shape[1] // group_size * group_size + weight1 = weight[:, :split_index] + weight1 = weight1.reshape(-1, group_size) + weight1 = qdq_weight_actor(weight1, num_bits, scheme=scheme) + weight1 = weight1.reshape(orig_shape[0], split_index) + weight2 = weight[:, split_index:] + weight2 = qdq_weight_actor(weight2, num_bits, scheme=scheme) + weight = torch.cat([weight1, weight2], dim=1) + return weight + + +def rtn_quantize(model, num_bits, group_size=-1, scheme="asym", w_layers_config={}): + """ quant the model with round to nearst method + :param model: torch module + :param num_bits: num bits + :param group_size: how many elements share one scale/zp + :param scheme: sym or asym + :param w_layers_config: specific layer wise configirations {"layer_name":[num_bits,group_size,schema]} + :return: + """ + assert isinstance(model, torch.nn.Module), "only support torch module" + assert num_bits > 0, "bit for weight only should large than zero!" + ##supported_layers = ['Linear', 'Conv2d'] + supported_layers = ['Linear'] + for n, m in model.named_modules(): + if m.__class__.__name__ not in supported_layers: + continue + if n in w_layers_config: # pragma: no cover + num_bits = w_layers_config[n][0] + group_size = w_layers_config[n][1] + scheme = w_layers_config[n][2] + logger.debug(f"RTN quantized module:{n, m}") + logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, scheme={scheme}") + if num_bits <= 0: + logger.info(f"skip {n}") + continue + if m.__class__.__name__ == "Conv2d": + weight = m.weight + orig_shape = weight.shape + weight = weight.permute(1, 0, 2, 3) + weight = weight.reshape(weight.shape[0], -1) + else: + weight = m.weight + q_weight = quant_weight(weight, num_bits, group_size, scheme) + if m.__class__.__name__ == "Conv2d": + q_weight = q_weight.reshape(orig_shape[1], orig_shape[0], orig_shape[2], orig_shape[3]) + q_weight = q_weight.permute(1, 0, 2, 3) + m.weight.data.copy_(q_weight) + return model diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 7e05f6e8ca0..4528926ea71 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -17,7 +17,7 @@ """Configs for Neural Compressor 2.x.""" import datetime import logging -from schema import Schema, And, Optional +from schema import Schema, And, Optional, Or from .utils import alias_param logger = logging.getLogger("neural_compressor") @@ -28,7 +28,7 @@ "auto": "post_training_auto_quant", "dynamic": "post_training_dynamic_quant", "static": "post_training_static_quant", - "qat": "quant_aware_training", + "weight_only": "post_training_weight_only", } @@ -44,8 +44,15 @@ list, lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16'] for i in s)), Optional('algorithm'): And( + list, # TODO: allow AWQ+GPTQ algo + lambda s: all(i in ['minmax', 'RTN', 'AWQ', 'GPTQ',] for i in s)), + Optional('bits'): And( list, - lambda s: all(i in ['minmax'] for i in s))}, + lambda s: all(0 < i <= 8 and type(i)==int for i in s)), + Optional('group_size'): And( + list, + lambda s: all(i >= -1 and i != 0 and type(i)==int for i in s)), + }, Optional('activation', default=None): { Optional('granularity'): And( list, @@ -58,7 +65,9 @@ lambda s: all(i in ['int8', 'uint8', 'fp32', 'bf16', 'fp16', 'None'] for i in s)), Optional('algorithm'): And( list, - lambda s: all(i in ['minmax', 'kl', 'placeholder', 'percentile'] for i in s))}}) + lambda s: all(i in ['minmax', 'kl', 'placeholder', 'percentile'] for i in s)) + } +}) def 
_check_value(name, src, supported_type, supported_value=[]): @@ -91,8 +100,28 @@ def datatype(self, datatype): return True +def _list_wrapper(config): + """A help function to wrapper custom op_type_dict and op_name_dict items with list. + + Args: + config (dict): op_type_dict/op_name_dict. + for example: {'weight': {'dtype': 'fp32'}, ...} + + Returns: + config: new_config wrapped with list + for example: {'weight': {'dtype': ['fp32']}, ...} + """ + for k, v in config.items(): + # k = weight/activation + for m, n in v.items(): + # m = dtype, bits, etc. + if not isinstance(n, list): + v[m] = [n] + return config + + class DotDict(dict): - """access yaml using attributes instead of using the dictionary notation. + """Access yaml using attributes instead of using the dictionary notation. Args: value (dict): The dict object to access. @@ -692,7 +721,7 @@ class _BaseQuantizationConfig: } }, } - reduce_range: Whether use 7 bit to quantization. + reduce_range: Whether use 7 bits to quantization. example_inputs: Used to trace PyTorch model with torch.jit/torch.fx. excluded_precisions: Precisions to be excluded, Default value is empty list. Neural compressor enable the mixed precision with fp32 + bf16 + int8 by default. @@ -796,6 +825,18 @@ def smooth_quant_args(val=None): else: return {} + def awq_args(val=None): + if val is not None: + return _check_value("awq_args", val, dict) + else: + return {} + + def gptq_args(val=None): + if val is not None: + return _check_value("gptq_args", val, dict) + else: + return {} + def fast_bias_correction(val=None): if val is not None: return _check_value("fast_bias_correction", val, bool) @@ -868,7 +909,9 @@ def dedicated_qdq_pair(val=None): "pre_post_process_quantization": pre_post_process_quantization, "add_qdq_pair_to_weight": add_qdq_pair_to_weight, "optypes_to_exclude_output_quant": optypes_to_exclude_output_quant, - "dedicated_qdq_pair": dedicated_qdq_pair + "dedicated_qdq_pair": dedicated_qdq_pair, + "awq_args": awq_args, + "gptq_args": gptq_args, } self._recipes = {} for k in RECIPES.keys(): @@ -934,6 +977,7 @@ def op_name_dict(self, op_name_dict): self._op_name_dict = op_name_dict elif isinstance(op_name_dict, dict): for k, v in op_name_dict.items(): + v = _list_wrapper(v) ops_schema.validate(v) self._op_name_dict = op_name_dict else: @@ -950,6 +994,7 @@ def op_type_dict(self, op_type_dict): self._op_type_dict = op_type_dict elif isinstance(op_type_dict, dict): for k, v in op_type_dict.items(): + v = _list_wrapper(v) ops_schema.validate(v) self._op_type_dict = op_type_dict else: @@ -1060,7 +1105,8 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig): quant_format: Support 'default', 'QDQ' and 'QOperator', only required in ONNXRuntime. inputs: Inputs of model, only required in tensorflow. outputs: Outputs of model, only required in tensorflow. - approach: Post-Training Quantization method. Neural compressor support 'static', 'dynamic' and 'auto' method. + approach: Post-Training Quantization method. Neural compressor support 'static', 'dynamic', + 'weight_only' and 'auto' method. Default value is 'static'. For strategy 'basic', 'auto' method means neural compressor will quantize all OPs support PTQ static or PTQ dynamic. For OPs supporting both PTQ static and PTQ dynamic, @@ -1097,7 +1143,7 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig): } }, } - reduce_range: Whether use 7 bit to quantization. + reduce_range: Whether use 7 bits to quantization. excluded_precisions: Precisions to be excluded, Default value is empty list. 
Neural compressor enable the mixed precision with fp32 + bf16 + int8 by default. If you want to disable bf16 data type, you can specify excluded_precisions = ['bf16]. @@ -1176,7 +1222,7 @@ def approach(self, approach): approach = 'static' if 'dynamic' in approach: approach = 'dynamic' - if _check_value("approach", approach, str, ["static", "dynamic", "auto"]): + if _check_value("approach", approach, str, ["static", "dynamic", "auto", "weight_only"]): self._approach = QUANTMAPPING[approach] @property @@ -1225,7 +1271,7 @@ class QuantizationAwareTrainingConfig(_BaseQuantizationConfig): } }, } - reduce_range: Whether use 7 bit to quantization. + reduce_range: Whether use 7 bits to quantization. model_name: The name of the model. Default value is empty. excluded_precisions: Precisions to be excluded, Default value is empty list. Neural compressor enable the mixed precision with fp32 + bf16 + int8 by default. diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index f90fc0c444a..30841775e35 100644 --- a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -313,7 +313,15 @@ def save(self, root=None): try: stat_dict = self._model.state_dict() if self.q_config: - stat_dict['best_configure'] = self.q_config + if self.q_config['approach'] == 'post_training_weight_only': + from neural_compressor.adaptor.torch_utils.util import collect_weight_info + weight_config_path = os.path.join(root, "weight_config.json") + weight_config = collect_weight_info(self.q_config) + with open(weight_config_path, 'w') as f: + json.dump(weight_config, f, indent = 4) + f.close() + else: + stat_dict['best_configure'] = self.q_config torch.save(stat_dict, os.path.join(root, "best_model.pt")) logger.info("Save config file and weights of quantized model to {}.".format(root)) except IOError as e: # pragma: no cover diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 3d60ecd7a63..3624b342286 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -53,7 +53,7 @@ from .utils.tuning_space import TuningSpace from .utils.tuning_structs import OpTuningConfig from .utils.constant import FALLBACK_RECIPES_SET -from .utils.utility import build_slave_faker_model +from .utils.utility import build_slave_faker_model, quant_options @@ -139,6 +139,7 @@ def __init__(self, self.model = model self.conf = conf self.config = self._initialize_config(conf) + self._set_quant_type(self.config) self.history_path = self._create_path(options.workspace, './history.snapshot') self.deploy_path = self._create_path(options.workspace, 'deploy.yaml') self.calib_dataloader = q_dataloader @@ -301,6 +302,11 @@ def algo_scheduler(self, value): """ self._algo_scheduler = value + def _set_quant_type(self, config): + if config.approach == 'post_training_weight_only': + quant_options.quant_type = 3 + # TODO for future usage(other quantization type) + def _initialize_algo_scheduler(self): algo_scheduler = AlgorithmScheduler(self.config.recipes) # reuse the calibration iteration @@ -1021,7 +1027,8 @@ def initial_tuning_cfg(self): quant_mode_wise_items (OrderedDict): key is quant_mode/precision; value is item list. initial_op_tuning_cfg (OrderedDict): key is (op_name, op_type); value is the initialized tuning config. 
""" - from .utils.constant import auto_query_order, static_query_order, dynamic_query_order + from .utils.constant import auto_query_order, static_query_order, dynamic_query_order, \ + weight_only_query_order from .utils.tuning_space import initial_tuning_cfg_with_quant_mode if self.config.approach == 'post_training_auto_quant': query_order = auto_query_order @@ -1029,6 +1036,8 @@ def initial_tuning_cfg(self): query_order = dynamic_query_order elif self.config.approach == 'post_training_static_quant': query_order = static_query_order + elif self.config.approach == 'post_training_weight_only': + query_order = weight_only_query_order elif self.config.approach == 'quant_aware_training': query_order = auto_query_order @@ -1260,6 +1269,8 @@ def _set_framework_info(self, q_dataloader, q_func=None): {"default_qconfig": self.config.op_name_dict['default_qconfig']}) framework_specific_info.update({"q_func": q_func}) framework_specific_info.update({"example_inputs": self.config.example_inputs}) + if self.config.approach =='post_training_weight_only': + framework = 'pytorchweightonly' # use specific adaptor for weight_only approach return framework, framework_specific_info def _set_objectives(self): diff --git a/neural_compressor/strategy/utils/constant.py b/neural_compressor/strategy/utils/constant.py index 3538241cd52..771e27cbeda 100644 --- a/neural_compressor/strategy/utils/constant.py +++ b/neural_compressor/strategy/utils/constant.py @@ -22,7 +22,8 @@ LOWER_BIT_LIST = ['int4'] TUNING_ITEMS_LST = [('activation','scheme'), ('activation','algorithm'), ('activation','granularity'), - ('weight','scheme'), ('weight','algorithm'), ('weight','granularity'), 'sampling_size'] + ('weight','scheme'), ('weight','algorithm'), ('weight','granularity'), + ('weight','bits'), ('weight','group_size'), 'sampling_size'] PRECISION_SET_V2_0 = {'fp32', 'bf16'} @@ -30,6 +31,7 @@ static_query_order = ['static', 'bf16', 'fp16', 'fp32'] dynamic_query_order = ['dynamic', 'bf16', 'fp16', 'fp32'] auto_query_order_o0 = ['bf16', 'fp16', 'fp32', 'static', 'dynamic'] +weight_only_query_order = ['weight_only', 'fp32'] FALLBACK_RECIPES_SET = {'first_conv_or_matmul_quantization', 'last_conv_or_matmul_quantization', \ diff --git a/neural_compressor/strategy/utils/tuning_space.py b/neural_compressor/strategy/utils/tuning_space.py index bf322ee80b7..330ba73b07d 100644 --- a/neural_compressor/strategy/utils/tuning_space.py +++ b/neural_compressor/strategy/utils/tuning_space.py @@ -23,7 +23,7 @@ from copy import deepcopy import itertools from ...utils import logger -from .utility import OrderedDefaultDict +from .utility import OrderedDefaultDict, preprocess_user_cfg from .tuning_structs import OpTuningConfig from .constant import TUNING_ITEMS_LST @@ -217,6 +217,7 @@ def _merge_op_cfg(self, cur_op_cap, op_user_cfg, fw_op_cap): from .utility import extract_data_type, reverted_data_type fw_op_cap = deepcopy(fw_op_cap) new_op_cap = deepcopy(cur_op_cap) + op_user_cfg = preprocess_user_cfg(op_user_cfg) for att in ['activation', 'weight']: if op_user_cfg.get(att, None) is not None: user_dtype_lst = op_user_cfg[att]['dtype'] if op_user_cfg[att].get('dtype', None) is not None else [] @@ -471,7 +472,7 @@ def _parse_cap_helper(self, cap): # The dtype should be a string, need to align with fwk.yaml. 
self.ops_data_type[op_name_type][(quant_mode, att, _data_type, signed_flag)] = \ item_options[0] if isinstance(item_options, list) else item_options - if item_name not in ['dtype', 'quant_mode']: + if item_name not in ['quant_mode']: parsed_op_cap[quant_mode][att][_data_type][signed_flag][item_name] = item_options else: # Parse the data info for item with unique value. diff --git a/neural_compressor/strategy/utils/utility.py b/neural_compressor/strategy/utils/utility.py index b4c5e766153..c8ad40163a1 100644 --- a/neural_compressor/strategy/utils/utility.py +++ b/neural_compressor/strategy/utils/utility.py @@ -16,16 +16,69 @@ # limitations under the License. """Tuning utility.""" -import os -import pickle from collections import OrderedDict -from typing import List, Optional, Any +from copy import deepcopy +import enum +from typing import Dict -import prettytable +class QuantType(enum.IntEnum): + """Quantization type.""" + DYNAMIC = 0 + STATIC = 1 + QAT = 2 + WEIGHT_ONLY = 3 + AUTO = 4 -from neural_compressor.utils import logger -from neural_compressor.utils.utility import print_table, dump_table, OpEntry +class QuantOptions: + """Option Class for Quantization. + This class is used for configuring global variable related to quantization. + The global variable quant_options is created with this class. + + Args: + quant_type(int): Quantization type. Default value is 1. + """ + def __init__(self, quant_type=1): + """Init an QuantOptions object.""" + self._quant_type = quant_type + + @property + def quant_type(self): + """Get quant type.""" + return self._quant_type + + @quant_type.setter + def quant_type(self, quant_type): + """Set quant type. + + Args: + quant_type(int): Quantization type. Default value is 1. + """ + self._quant_type = quant_type + +quant_options = QuantOptions() + +def preprocess_user_cfg(op_user_cfg: Dict): + """Preprocess the op user config for weight only. + + Args: + op_user_cfg: The original user config. + + Example: + op_user_cfg = {'activation': {'bits': [4]}} + op_user_cfg_modified = {'activation': {'bits': [4], 'group_size': [32]}} + + Returns: + The modified config. + """ + op_user_cfg_modified = deepcopy(op_user_cfg) + if quant_options.quant_type == QuantType.WEIGHT_ONLY: + for att, att_cfg in op_user_cfg.items(): + if 'bits' not in att_cfg: + op_user_cfg_modified[att]['bits'] = [4] + if 'group_size' not in att_cfg: + op_user_cfg_modified[att]['group_size'] = [32] + return op_user_cfg_modified class OrderedDefaultDict(OrderedDict): """Ordered default dict.""" diff --git a/neural_compressor/utils/pytorch.py b/neural_compressor/utils/pytorch.py index fe39e8ca691..bb9da0f7b8b 100644 --- a/neural_compressor/utils/pytorch.py +++ b/neural_compressor/utils/pytorch.py @@ -212,6 +212,14 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs): try: weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), 'best_model.pt') + # for weight only quantized model. 
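+        # A weight-only checkpoint stores a fake-quantized state_dict in the original dtype
+        # (plus weight_config.json), so a plain load_state_dict() is sufficient here.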
+ weights_only_config_file = os.path.join( + os.path.abspath(os.path.expanduser(checkpoint_dir)),'weight_config.json') + if os.path.exists(weights_only_config_file): + model.load_state_dict(torch.load(weights_file)) + logger.info('Load weight_only quantized model') + return model + # ------------------------------- try: stat_dict = torch.jit.load(weights_file) logger.info("torch.jit.load is used to recovery the int8 model quantized by INC IPEX backend") diff --git a/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py new file mode 100644 index 00000000000..a6edf6a0dce --- /dev/null +++ b/test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py @@ -0,0 +1,142 @@ +import shutil +import torch +import unittest +from neural_compressor import quantization, PostTrainingQuantConfig + + +class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(30, 40) + self.fc2 = torch.nn.Linear(40, 30) + self.fc3 = torch.nn.Linear(30, 10) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + out = self.fc3(out) + return out + + +def eval_func(model): + # switch to evaluate mode + model.eval() + with torch.no_grad(): + input = torch.randn(3,30) + # compute output + output = model(input) + return 0.0 + + +class TestPytorchWeightOnlyAdaptor(unittest.TestCase): + approach = 'weight_only' + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_RTN_func(self): + # TODO + pass + + def test_RTN_quant(self): + input = torch.randn(3,30) + model = Model() + out1 = model(input) + + conf = PostTrainingQuantConfig( + approach='weight_only', + ) + q_model = quantization.fit(model, conf) + out2 = q_model(input) + self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) + self.assertFalse(torch.all(out1 == out2)) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 8, # 1-8 bits + 'group_size': -1, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'RTN', + }, + }, + }, + recipes={ + 'gptq_args':{'percdamp': 0.01}, + 'awq_args':{'alpha': 'auto', 'clip': True}, + }, + ) + q_model = quantization.fit(model, conf, eval_func=eval_func) + out2 = q_model(input) + self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) + self.assertFalse(torch.all(out1 == out2)) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # 1 - 1024 or higher + 'scheme': 'asym', + 'algorithm': 'RTN', + }, + }, + }, + recipes={ + 'gptq_args':{'percdamp': 0.01}, + 'awq_args':{'alpha': 'auto', 'clip': True}, + }, + ) + q_model = quantization.fit(model, conf, eval_func=eval_func) + out2 = q_model(input) + self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) + self.assertFalse(torch.all(out1 == out2)) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_name_dict={ + 'fc1':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, # 1 - 1024 or higher + 'scheme': 'sym', + 'algorithm': 'RTN', + }, + }, + 'fc2':{ # re.match + "weight": { + 'bits': 3, # 1-8 bits + 'group_size': 16, # 1 - 1024 or higher + 'scheme': 'asym', + 'algorithm': 'RTN', + }, + }, + 'fc3':{ # re.match + "weight": { + 'dtype': 'fp32', + }, + }, + }, + recipes={ + 'gptq_args':{'percdamp': 0.01}, + 'awq_args':{'alpha': 'auto', 'clip': True}, + }, + ) 
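+        # fc1/fc2 get 4-bit and 3-bit RTN-quantized weights respectively; fc3 stays fp32 per the config above.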
+ q_model = quantization.fit(model, conf, eval_func=eval_func) + out2 = q_model(input) + self.assertTrue(torch.all(torch.isclose(out1, out2, atol=5e-1))) + self.assertFalse(torch.all(out1 == out2)) + q_model.save('saved') + from neural_compressor.utils.pytorch import load + new_model = load('saved', model) + out1 = new_model(input) + self.assertTrue(torch.all(out1 == out2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/quantization/test_weight_only_quantization.py b/test/quantization/test_weight_only_quantization.py new file mode 100644 index 00000000000..549a8a7d112 --- /dev/null +++ b/test/quantization/test_weight_only_quantization.py @@ -0,0 +1,41 @@ +import unittest +import copy +import torch +from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize + + +class TestWeightOnlyQuant(unittest.TestCase): + @classmethod + def setUpClass(self): + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = torch.nn.Conv2d(3, 4, 2, 2) + self.act = torch.nn.ReLU6() + self.conv2 = torch.nn.Conv2d(4, 10, 3, 3) + + def forward(self, x): + out = self.conv1(x) + out = self.act(out) + out = self.conv2(out) + x + return out + + self.model = Model() + + @classmethod + def tearDownClass(self): + pass + + def test_conv(self): + fp32_model = copy.deepcopy(self.model) + model1 = rtn_quantize(fp32_model, num_bits=3, group_size=-1) + w_layers_config = { + # 'op_name': (bit, group_size, sheme) + 'conv1': (8, 128, 'sym'), + 'conv2': (4, 32, 'asym') + } + model2 = rtn_quantize(fp32_model, num_bits=3, group_size=-1, w_layers_config=w_layers_config) + + +if __name__ == "__main__": + unittest.main()
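
For quick experiments outside the full `quantization.fit` flow, the `rtn_quantize` helper added in `torch_utils/weight_only.py` can also be called directly, as the unit test above does. The sketch below mirrors that usage on a toy model; the layer shapes, bit widths, and group sizes are illustrative choices, not recommendations.

```python
import copy
import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

# Toy model; only Linear layers are quantized by rtn_quantize at the moment.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)

# Global setting: 4-bit, per-channel (group_size=-1), asymmetric rounding.
q_model = rtn_quantize(copy.deepcopy(model), num_bits=4, group_size=-1, scheme="asym")

# Per-layer overrides: {"layer_name": (num_bits, group_size, scheme)}; the layer
# names here ("0", "2") are the named_modules() keys of the Sequential container.
w_layers_config = {
    "0": (8, 32, "sym"),
    "2": (4, -1, "asym"),
}
q_model2 = rtn_quantize(copy.deepcopy(model), num_bits=4, group_size=-1,
                        w_layers_config=w_layers_config)

# Weights are fake-quantized in place: still fp32 tensors, now holding the
# dequantized values, so the module can be saved and loaded like a normal model.
print(torch.allclose(model[0].weight, q_model[0].weight))  # typically False
```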