From da4c92cdcc1a16df2643a87ab35b49b277c2fb5b Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Mon, 7 Aug 2023 10:15:16 +0800 Subject: [PATCH] Support weight only for ONNX (#1055) Signed-off-by: Mengni Wang --- docs/source/quantization_weight_only.md | 6 +- neural_compressor/adaptor/onnxrt.py | 209 +++++- neural_compressor/adaptor/onnxrt.yaml | 24 +- .../adaptor/ox_utils/weight_only.py | 648 ++++++++++++++++++ neural_compressor/model/onnx_model.py | 40 +- neural_compressor/strategy/strategy.py | 3 + .../test_weight_only_adaptor.py | 172 +++++ test/requirements.txt | 2 +- 8 files changed, 1086 insertions(+), 18 deletions(-) create mode 100644 neural_compressor/adaptor/ox_utils/weight_only.py create mode 100644 test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index 3c561837ab4..4140b51db27 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -24,11 +24,11 @@ There are many excellent works for weight only quantization to improve its accur ## Supported Framework Model Matrix -| Algorithms/Framework | PyTorch | ONNX | +| Algorithms/Framework | PyTorch | ONNX Runtime | |:--------------:|:----------:|:----------:| | RTN | ✔ | ✔ | -| AWQ | ✔ | stay tuned | -| GPTQ | ✔ | stay tuned | +| AWQ | ✔ | ✔ | +| GPTQ | ✔ | ✔ | | TEQ | ✔ | stay tuned | ## Examples diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 26a1ff8baaa..4f79ccc6ed1 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -940,7 +940,8 @@ def query_fw_capability(self, model): if precision == 'bf16' and \ (not self.use_bf16 or (not CpuInfo().bf16 and os.getenv('FORCE_BF16') != '1')): continue - + elif precision == 'weight_only_integer': + continue # get supported optype for target precision optypes = query.get_op_types_by_precision(precision) if \ query.get_op_types_by_precision(precision) != ['*'] else \ @@ -1484,6 +1485,201 @@ def _inference_model_on_batches(self, model, tune_cfg, dataloader, predictions.extend(session.run(None, ort_inputs)) return predictions +@adaptor_registry +class ONNXRT_WeightOnlyAdaptor(ONNXRUNTIMEAdaptor): + """The ONNXRT adaptor layer, do onnx-rt quantization, calibration, inspect layer tensors. + + Args: + framework_specific_info (dict): framework specific configuration for quantization. + """ + + def __init__(self, framework_specific_info): + super().__init__(framework_specific_info) + + @dump_elapsed_time("Pass quantize model") + def quantize(self, tune_cfg, model, data_loader, q_func=None): + """The function is used to do calibration and quanitization in post-training + quantization. + + Args: + tune_cfg (dict): quantization config. + model (object): model need to do quantization. + data_loader (object): calibration dataset. + q_func (optional): training function for quantization aware training mode, + unimplement yet for onnx. 
+
+        Returns:
+            (dict): quantized model
+        """
+        assert q_func is None, "quantization aware training is not supported on ONNXRUNTIME"
+        for precision in self.query_handler.get_precisions():
+            if precision == 'weight_only_integer':
+                self.quantizable_op_types += \
+                    self.query_handler.get_op_types_by_precision(precision=precision)
+        self.quantizable_ops = self._query_quantizable_ops(model.model)
+
+        quant_config = self._cfg_to_quantize_config(tune_cfg)
+        algos = set([item["weight"]["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
+        if "GPTQ" in algos:
+            from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
+
+            percdamp = self.recipes.get('gptq_args', {}).get('percdamp', 0.01)
+            blocksize = self.recipes.get('gptq_args', {}).get('blocksize', 128)
+            actorder = self.recipes.get('gptq_args', {}).get('actorder', False)
+            mse = self.recipes.get('gptq_args', {}).get('mse', False)
+            perchannel = self.recipes.get('gptq_args', {}).get('perchannel', True)
+            calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
+            model = gptq_quantize(model,
+                                  quant_config,
+                                  data_loader,
+                                  calib_sampling_size,
+                                  percdamp=percdamp,
+                                  blocksize=blocksize,
+                                  actorder=actorder,
+                                  mse=mse,
+                                  perchannel=perchannel)
+        if "AWQ" in algos:
+            from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize
+
+            auto_scale = self.recipes.get('awq_args', {}).get('auto_scale', True)
+            mse_range = self.recipes.get('awq_args', {}).get('mse_range', True)
+            n_blocks = self.recipes.get('awq_args', {}).get('n_blocks', 5)
+            calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
+            model = awq_quantize(model,
+                                 quant_config,
+                                 data_loader,
+                                 calib_sampling_size,
+                                 auto_scale,
+                                 mse_range,
+                                 n_blocks)
+        elif "RTN" in algos:
+            from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
+            model = rtn_quantize(model, quant_config)
+        model.q_config = copy.deepcopy(quant_config)
+        self._dump_model_op_stats(model, tune_cfg)
+        model.topological_sort()
+        return model
+
+    def _dump_model_op_stats(self, model, tune_cfg):
+        res = {}
+        # collect all dtype info and build empty results with existing op_type
+        dtype_set = set()
+        for op, config in tune_cfg['op'].items():
+            op_type = op[1]
+            if not config['weight']['dtype'] == 'fp32':
+                num_bits = config['weight']['bits']
+                group_size = config['weight']['group_size']
+                dtype_str = "A32W{}G{}".format(num_bits, group_size)
+                dtype_set.add(dtype_str)
+        dtype_set.add('FP32')
+        dtype_list = list(dtype_set)
+        dtype_list.sort()
+        for op, config in tune_cfg['op'].items():
+            op_type = op[1]
+            if op_type not in res.keys():
+                res[op_type] = {dtype: 0 for dtype in dtype_list}
+
+        # fill in results with op_type and dtype
+        for op, config in tune_cfg['op'].items():
+            if config['weight']['dtype'] == 'fp32':
+                res[op[1]]['FP32'] += 1  # op is the (node_name, op_type) key
+            else:
+                num_bits = config['weight']['bits']
+                group_size = config['weight']['group_size']
+                dtype_str = "A32W{}G{}".format(num_bits, group_size)
+                res[op[1]][dtype_str] += 1
+
+        # update stats format for dump
+ field_names = ["Op Type", "Total"] + field_names.extend(dtype_list) + output_data = [] + for op_type in res.keys(): + field_results = [op_type, sum(res[op_type].values())] + field_results.extend([res[op_type][dtype] for dtype in dtype_list]) + output_data.append(field_results) + + Statistics(output_data, + header='Mixed Precision Statistics', + field_names=field_names).print_stat() + self.optype_statistics = field_names, output_data + + def _cfg_to_quantize_config(self, tune_cfg): + quantize_config = {} + quantize_config['calib_iteration'] = tune_cfg['calib_iteration'] + + for _, op in enumerate(self.quantizable_ops): + if (op.name, op.op_type) not in tune_cfg['op']: + continue + if tune_cfg['op'][(op.name, op.op_type)]['weight']['dtype'] in \ + self.query_handler.get_fallback_list(): + quantize_config[op.name] = \ + tune_cfg['op'][(op.name, op.op_type)]['weight']['dtype'] + else: + quantize_config[op.name] = copy.deepcopy(tune_cfg['op'][(op.name, op.op_type)]) + + return quantize_config + + def query_fw_capability(self, model): + """The function is used to query framework capability. + TODO: will be replaced by framework query API + + Args: + model: onnx model + + Returns: + (dict): quantization capability + """ + # optype_wise and op_wise capability + self._pre_optimize(model) + + quantizable_optype = set([i.op_type for i in self.pre_optimized_model.nodes()]) + optype_wise = OrderedDict() + op_wise = OrderedDict() + for query in [self.query_handler, self.query_handler_ext]: + if query is None: + continue + precisions = query.get_precisions() + + for precision in precisions: + if precision != 'weight_only_integer': + continue + # get supported optype for target precision + optypes = query.get_op_types_by_precision(precision) if \ + query.get_op_types_by_precision(precision) != ['*'] else \ + optype_wise.keys() + + configs = query.get_quantization_capability()[precision] if \ + precision in query.get_quantization_capability() else \ + {'default': {'weight': {'dtype': precision}, 'activation': {'dtype': precision}}} + + for op in optypes: + if op not in quantizable_optype: + continue + if op not in configs: + if 'default' in configs: + op_capability = copy.deepcopy(configs['default']) + else: + continue + else: + op_capability = copy.deepcopy(configs[op]) + op_capability['activation']['quant_mode'] = 'weight_only' + if op not in optype_wise.keys(): + optype_wise[op] = [op_capability] + elif op_capability not in optype_wise[op]: + optype_wise[op].append(op_capability) + + for node in self.pre_optimized_model.nodes(): + if node.op_type in ['MatMul', 'Attention'] and model.get_initializer(node.input[1]) is None: + op_wise.update( + {(node.name, node.op_type): [{'weight': {'dtype': 'fp32'}, 'activation': {'dtype': 'fp32'}}]}) + continue + if node.op_type in optype_wise: + op_wise.update( + {(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])}) + + return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': {}, 'block_wise': []} + + @adaptor_registry class ONNXRT_QLinearOpsAdaptor(ONNXRUNTIMEAdaptor): """The ONNXRT adaptor layer, do onnx-rt quantization, calibration, inspect layer tensors. 
@@ -1594,6 +1790,7 @@ def _compare(version1, version2): # generate specified version config according to quantization approach and format config = {} + config['capabilities'] = {} for k, v in version_config.items(): if k == 'version': config['version'] = v @@ -1601,15 +1798,15 @@ def _compare(version1, version2): config['graph_optimization'] = v['graph_optimization'] else: if self.static and 'static' in v: - config['capabilities'] = {k: {node_op: node_config + config['capabilities'].update({k: {node_op: node_config for node_op, node_config in v['static'].items() if 'mode' in node_config and \ self.format.split('ops')[0].lower() in \ - [mode.lower() for mode in node_config['mode']]}} + [mode.lower() for mode in node_config['mode']]}}) elif self.dynamic and 'dynamic' in v: - config['capabilities'] = {k: v['dynamic']} - if 'capabilities' not in config: - config['capabilities'] = {} + config['capabilities'].update({k: v['dynamic']}) + elif k == 'weight_only_integer': + config['capabilities'].update({k: v}) # generate other config content including precisions and ops precisions = list(version_config.keys() - {'version', 'recipes'}) diff --git a/neural_compressor/adaptor/onnxrt.yaml b/neural_compressor/adaptor/onnxrt.yaml index 39a9103aa7d..44e2ee9a5d1 100644 --- a/neural_compressor/adaptor/onnxrt.yaml +++ b/neural_compressor/adaptor/onnxrt.yaml @@ -17,6 +17,21 @@ - version: name: '1.6.0' + weight_only_integer: &cap_weight_only { + 'MatMul': &cap_weight_only_matmul { + 'weight': { + 'dtype': ['int'], # no need to care uint + 'bits': [4, 3, 8], # [1-8] + 'group_size': [32, -1, 1, 16, 64, 128, 256, 512, 1024], # [1-inf] + 'scheme': ['sym', 'asym'], # sym, no ZP + 'algorithm': ['RTN', 'AWQ', 'GPTQ'] + }, + 'activation': { + 'dtype': ['fp32'] + } + }, + 'Attention': *cap_weight_only_matmul + } int8: &ref_1_6 { 'static': &ref_1_6_static { 'Conv': { @@ -109,6 +124,7 @@ - version: name: '1.7.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -148,6 +164,7 @@ - version: name: '1.8.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -215,6 +232,7 @@ - version: name: '1.9.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -289,6 +307,7 @@ - version: name: '1.10.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -343,6 +362,7 @@ - version: name: '1.11.0' + weight_only_integer: *cap_weight_only int8: &ref_1_11 { 'static': { 'FusedConv': { @@ -405,6 +425,7 @@ version: name: '1.12.0' int8: *ref_1_11 + weight_only_integer: *cap_weight_only recipes: <<: *default_optimization @@ -412,5 +433,6 @@ version: name: 'default' int8: *ref_1_6 + weight_only_integer: *cap_weight_only recipes: - <<: *default_optimization \ No newline at end of file + <<: *default_optimization diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py new file mode 100644 index 00000000000..d87dcf8032a --- /dev/null +++ b/neural_compressor/adaptor/ox_utils/weight_only.py @@ -0,0 +1,648 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""WeightOnly for onnxrt adaptor.""" + +import sys +import os +import math +import copy +import onnx +import logging +import numpy as np +from onnx import onnx_pb as onnx_proto +from neural_compressor.utils.utility import LazyImport +from neural_compressor.model.model import BaseModel +from neural_compressor.model.onnx_model import ONNXModel +from onnx import numpy_helper, helper + +ort = LazyImport("onnxruntime") +logger = logging.getLogger("neural_compressor") + +def qdq_tensor(data, config, ratio=1.): + """Quant and dequant tensor per group. + + Args: + data : input weight + config (dict): quantization config + ratio (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: qdq weight + """ + bit = config.get("bits", 8) + scheme = config.get("scheme", "asym") + if scheme == "sym": + maxq = 2 ** (bit - 1) - 1 if bit != 1 else 0 + minq = -2 ** (bit - 1) if bit != 1 else -1 + elif scheme == "asym": + maxq = 2 ** bit - 1 + minq = 0 + + rmin = np.min(data, axis=0, keepdims=True) * ratio + rmax = np.max(data, axis=0, keepdims=True) * ratio + if scheme == "sym": + max_range = np.maximum(np.abs(rmin), np.abs(rmax)) + scale = np.ones(rmax.shape, dtype="float32") + scale[max_range > 0] = np.array([float(i) / (maxq - minq) for i in \ + (max_range[max_range > 0] * 2.).flatten().tolist()], dtype="float32") + zero_point = np.zeros(scale.shape) + else: + scale = np.ones(rmax.shape, dtype="float32") + scale[rmin != rmax] = np.array([float(i) / (maxq - minq) for i in \ + (rmax - rmin)[rmin != rmax].flatten().tolist()], dtype="float32") + zero_point = ((np.zeros(scale.shape) - rmin) / scale).round() + + return scale * (np.clip((data / scale + zero_point).round(), minq, maxq) - zero_point) + +def rtn_quantize(model, tune_cfg, ratios={}): + """Quant the model with round to nearst method. + + Args: + model (ModelProto or ONNXModel): onnx model + tune_cfg (dict): quantization config + For example, + tune_cfg={ + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'RTN' + } + } + ratios (dict, optional): percentile of clip. Defaults to {}. 
+ + Returns: + model: fake quantized ONNXModel + """ + model = model if isinstance(model, BaseModel) else ONNXModel(model) + for node in model.nodes(): + if node.op_type in ["MatMul", "Attention"] and model.get_initializer(node.input[1]) is not None: + weight = numpy_helper.to_array( + model.get_initializer(node.input[1]), + base_dir=os.path.dirname(model.model_path)).copy() + dtype = weight.dtype + config = tune_cfg[node.name].get("weight", {}) + + org_w_shape = weight.shape # ic, oc + group_size = config.get("group_size", -1) if config.get("group_size", -1) != -1 else org_w_shape[0] + + if org_w_shape[0] % group_size == 0: + weight = weight.reshape(group_size, -1) + weight = qdq_tensor(weight, config, ratios.get(node.input[1], 1)) + weight = weight.reshape(org_w_shape) + else: + index = org_w_shape[0] // group_size * group_size + if index != 0: + part_weight = weight[:index, :].reshape(group_size, -1) + part_weight = qdq_tensor(part_weight, config, ratios.get(node.input[1], 1)) + weight[:index, :] = part_weight.reshape(index, -1) + weight[index:, :] = qdq_tensor(weight[index:, :], config, ratios.get(node.input[1], 1)) + model.set_initializer(node.input[1], weight.astype(dtype), raw=True) + return model + +def get_weight_scale(weight, group_size): + """Get the scale of weight.""" + org_shape = weight.shape + weight = np.reshape(weight, (group_size, -1)) if group_size != -1 else weight + scale = np.mean( + np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=0, keepdims=True), org_shape), + axis=1) + return scale + +def apply_awq_scale(model, tune_cfg, absorb_pairs, output_dicts): + """Apply scale for salient weight.""" + best_scales = {} + new_init_tensors = [] + new_added_mul_nodes = [] + replace_input = [] + updated_nodes = [] + + for parent, nodes in absorb_pairs.items(): + if any([node.input[0] not in output_dicts for node in nodes]): + logger.warning("Miss tensors of node {} during AWQ, skip it!".format(node.name)) + continue + inp = np.concatenate(output_dicts[nodes[0].input[0]], axis=0) + inp_scale = np.mean(np.reshape(np.abs(inp), (-1, inp[0].shape[-1])), axis=0) + weight = [] + org_out = [] + config = tune_cfg.get(nodes[0].name, {}) + + # search scale + best_error = float("inf") + best_ratio = -1 + best_scale = None + n_grid = 20 + + for ratio in range(n_grid): + ratio = ratio * 1 / n_grid + loss = 0 + for node in nodes: + weight = numpy_helper.to_array(model.get_initializer(node.input[1]), + os.path.dirname(model.model_path)) + w_scale = get_weight_scale(weight, config.get("weight", {}).get("group_size", -1)) + org_out = np.matmul(inp, weight) + scales = np.clip(np.power(inp_scale, ratio) / np.power(w_scale, (1 - ratio)), 1e-4, None) + scales = np.reshape(scales / np.sqrt(np.max(scales) * np.min(scales)), (-1, 1)) + + q_weight = qdq_tensor(weight * scales, config.get("weight", {})) / scales + out = np.matmul(inp, q_weight) + loss += np.mean(np.power((org_out - out), 2)) + + is_best = loss < best_error + if is_best: + best_error = loss + best_ratio = ratio + best_scale = scales + + for node in nodes: + tensor = numpy_helper.to_array(model.get_initializer(node.input[1]), + os.path.dirname(model.model_path)) + new_tensor = tensor * best_scale + model.set_initializer(node.input[1], new_tensor.astype(tensor.dtype), raw=True) + output_dicts[node.input[0]] = output_dicts[node.input[0]] / np.reshape(best_scale, (1, -1)) + + parent = model.get_node(parent) + if parent.name in updated_nodes: + continue + + if parent.op_type in ["LayerNormalization", "BatchNormalization", 
"InstanceNormalization"]: # pragma: no cover + for idx in [1, 2]: + tensor = numpy_helper.to_array(model.get_initializer(parent.input[idx]), + os.path.dirname(model.model_path)) + new_tensor = tensor / np.reshape(best_scale, (1, -1)) + model.set_initializer(parent.input[idx], new_tensor.astype(tensor.dtype), raw=True) + updated_nodes.append(parent.name) + output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) + + elif parent.op_type in ["SimplifiedLayerNormalization", "MatMul", "Gemm", "Mul"] and \ + not all([model.get_initializer(inp) is None for inp in parent.input]): + for inp in parent.input: + if model.get_initializer(inp) is not None: + tensor = numpy_helper.to_array(model.get_initializer(inp), + os.path.dirname(model.model_path)) + new_tensor = tensor / np.reshape(best_scale, (1, -1)) + model.set_initializer(inp, new_tensor.astype(tensor.dtype), raw=True) + updated_nodes.append(parent.name) + output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) + + elif parent.op_type in ["Conv", "FusedConv"]: # pragma: no cover + tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), + os.path.dirname(model.model_path)) + new_tensor = tensor / np.reshape(best_scale, (1, -1)) + model.set_initializer(parent.input[2], new_tensor.astype(tensor.dtype), raw=True) + updated_nodes.append(parent.name) + output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) + + else: # pragma: no cover + # insert mul + scale_tensor = helper.make_tensor( + name=parent.output[0] + "_weight_only_scale", + data_type=onnx_proto.TensorProto.FLOAT, + dims=best_scale.shape, + vals=(1. / best_scale).flatten().tolist()) + new_init_tensors.append(scale_tensor) + mul_output_name = parent.output[0] + "_weight_only_out" + mul_node = helper.make_node( + "Mul", + inputs=[nodes[0].input[0], scale_tensor.name], + outputs=[mul_output_name], + name=nodes[0].input[0] + "_weight_only_mul" + ) + new_added_mul_nodes.append(mul_node) + for node in nodes: + replace_input.append([node, node.input[0], mul_node.output[0]]) + updated_nodes.append(parent.name) + output_dicts[mul_node.output[0]] = output_dicts[mul_node.input[0]] / np.reshape(best_scale, (1, -1)) + + model.add_nodes(new_added_mul_nodes) + model.add_initializers(new_init_tensors) + for node, old_input_name, new_input_name in replace_input: + model.replace_node_input(node, old_input_name, new_input_name) + + return model, output_dicts + +def apply_awq_clip(model, tune_cfg, absorb_pairs, output_dicts): + """Apply clip for weight by checking mse.""" + ratios = {} + for parent, nodes in absorb_pairs.items(): + if any([node.input[0] not in output_dicts for node in nodes]): + logger.warning("Miss tensors of node {} during AWQ, skip it!".format(node.name)) + continue + + inp = np.concatenate(output_dicts[nodes[0].input[0]], axis=0) + + for node in nodes: + config = tune_cfg.get(node.name, {}) + org_weight = numpy_helper.to_array( + model.get_initializer(node.input[1]), + base_dir=os.path.dirname(model.model_path)) + org_w_shape = org_weight.shape # ic, oc + group_size = config.get("group_size", -1) if config.get("group_size", -1) != -1 else org_w_shape[0] + org_out = np.matmul(inp, org_weight) # n_token, oc + + best_error = float("inf") + best_ratio = 1 + for i_s in range(10): + ratio = 1 - i_s / 100 + weight = copy.deepcopy(org_weight) + if org_w_shape[0] % group_size == 0: + weight = weight.reshape(group_size, -1) + weight = qdq_tensor(weight, config, 
ratios.get(node.input[1], 1)) + weight = weight.reshape(org_w_shape) + else: + index = org_w_shape[0] // group_size * group_size + if index != 0: + part_weight = weight[:index, :].reshape(group_size, -1) + part_weight = qdq_tensor(part_weight, config, ratios.get(node.input[1], 1)) + weight[:index, :] = part_weight.reshape(index, -1) + weight[index:, :] = qdq_tensor(weight[index:, :], config, ratios.get(node.input[1], 1)) + + cur_out = np.matmul(inp, weight) + loss = np.mean(np.power((org_out - cur_out), 2)) + is_best = loss < best_error + if is_best: + best_error = loss + best_ratio = ratio + ratios[node.input[1]] = best_ratio + model = rtn_quantize(model, tune_cfg, ratios) + return model + +def prepare_inputs(model, n_samples, dataloader): + """Prepare inputs for weight only quantization. + + Args: + model (ModelProto or ONNXModel): onnx model + n_samples (int, optional): calibration sample number. + dataloader (object): dataloader for calibration. + + Returns: + inputs: prepared inputs. + so: session options + """ + from importlib.util import find_spec + from neural_compressor.adaptor.ox_utils.util import to_numpy + + so = ort.SessionOptions() + if sys.version_info < (3, 10) and find_spec('onnxruntime_extensions'): # pragma: no cover + from onnxruntime_extensions import get_library_path + so.register_custom_ops_library(get_library_path()) + if model.is_large_model: + onnx.save_model(model.model, + model.model_path + '_augment.onnx', + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False) + + session = ort.InferenceSession( + model.model.SerializeToString(), + so, + providers=ort.get_available_providers()) if not model.is_large_model else \ + ort.InferenceSession( + model.model_path + '_augment.onnx', + so, + providers=ort.get_available_providers()) + inputs_names = [i.name for i in session.get_inputs()] + del session + + inputs = [] + for i, data in enumerate(dataloader): + if ((i + 1) * dataloader.batch_size) > n_samples: + break + if len(inputs_names) != 1 or isinstance(data[0], dict): + assert len(data[0]) == len(inputs_names), "Input number mismatch, " \ + "require {} but get {}".format(len(inputs_names), len(data[0])) + + if isinstance(data[0], dict): + inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) + else: + inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0])])) + return inputs, so + +def awq_quantize(model, + tune_cfg, + dataloader, + n_samples=128, + auto_scale=True, + mse_range=True, + n_blocks=5 + ): + """Quant the model with Activation-aware Weight quantization(AWQ) method. + + Args: + model (ModelProto or ONNXModel): onnx model + tune_cfg (dict): quantization config + For example, + tune_cfg={ + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'AWQ' + } + } + dataloader (object): dataloader for calibration. + n_samples (int, optional): calibration sample number. + auto_scale (bool, optional): whether enable scale for salient weight. Defaults to True. + mse_range (bool, optional): whether enable clip for weight by checking mse. Defaults to True. + n_blocks (int, optional): split model into block number to avoid OOM. 
+
+    Returns:
+        model: fake quantized ONNXModel
+    """
+    model = model if isinstance(model, BaseModel) else ONNXModel(model)
+    output_dicts = {}
+
+    if auto_scale or mse_range:
+        absorb_pairs = model.get_absorb_pairs(["MatMul", "Attention"])
+
+        inputs, so = prepare_inputs(model, n_samples, dataloader)
+        del dataloader
+
+        org_output = copy.deepcopy(model.model.graph.output)
+        model.remove_tensors_from_outputs([i.name for i in org_output])
+        num_block = math.ceil(len(absorb_pairs) / n_blocks)
+        dump_pairs = {}
+        for idx, parent in enumerate(absorb_pairs):
+            if (idx + 1) % num_block == 0 or (idx + 1) == len(absorb_pairs):
+                dump_pairs[parent] = absorb_pairs[parent]
+                output_dicts = {}
+                dump_tensor = list(set([i.input[0] for nodes in dump_pairs.values() for i in nodes]))
+                model.add_tensors_to_outputs(dump_tensor)
+
+                if model.is_large_model:
+                    onnx.save_model(model.model,
+                                    model.model_path + '_augment.onnx',
+                                    save_as_external_data=True,
+                                    all_tensors_to_one_file=True,
+                                    convert_attribute=False)
+
+                session = ort.InferenceSession(
+                            model.model.SerializeToString(),
+                            so,
+                            providers=ort.get_available_providers()) if not model.is_large_model else \
+                          ort.InferenceSession(
+                            model.model_path + '_augment.onnx',
+                            so,
+                            providers=ort.get_available_providers())
+
+                for inp in inputs:
+                    for output_idx, output in enumerate(session.run(None, inp)):
+                        output_dicts.setdefault(dump_tensor[output_idx], []).append(output)
+
+                model.remove_tensors_from_outputs(dump_tensor)
+                if auto_scale:
+                    model, output_dicts = apply_awq_scale(model, tune_cfg, dump_pairs, output_dicts)
+                if mse_range:
+                    model = apply_awq_clip(model, tune_cfg, dump_pairs, output_dicts)
+                del output_dicts
+                dump_pairs = {}
+            else:
+                dump_pairs[parent] = absorb_pairs[parent]
+
+        model.model.graph.output.MergeFrom(org_output)
+    return model
+
+def gptq(Ws, Hs, config, blocksize=128, percdamp=.01, actorder=False, mse=False, perchannel=True):
+    """Fake quantize weight tensors with the GPTQ method.
+
+    Args:
+        Ws (list): list of weights.
+        Hs (list): list of Hessian matrices.
+        config (dict): quantization config.
+        blocksize (int, optional): blocksize to quantize weight.
+        percdamp (float, optional): percent of the average Hessian diagonal to use for dampening.
+        actorder (bool, optional): whether to rearrange the Hessian matrix considering the diagonal's values.
+        mse (bool, optional): whether to get scale and zero point with mse error.
+        perchannel (bool, optional): whether to quantize weight per-channel.
+ + Returns: + Qs: fake quantized weights + """ + Qs = [] + group_size = config.get("weight", {}).get("group_size", -1) + bits = config.get("weight", {}).get("bits", 8) + scheme = config.get("weight", {}).get("scheme", "asym") + maxq = 2 ** bits - 1 + grid=100 + maxshrink=.8 + norm=2.4 + + def find_params(weight): + org_shape = weight.shape + # find zp, scale + if not perchannel: + weight = np.expand_dims(weight.flatten(), axis=1) + tmp = np.zeros(weight.shape[1]) + xmin = np.minimum(np.min(weight, axis=0), tmp) + xmax = np.maximum(np.max(weight, axis=0), tmp) + if scheme == "sym": + xmax = np.maximum(np.abs(xmin), xmax) + tmp = xmin < 0 + if np.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + scale = (xmax - xmin) / maxq + if scheme == "sym": + zero = np.ones(scale.shape) * (maxq + 1) / 2 + else: + zero = np.round(-xmin / scale) + if mse: + best = np.ones([weight.shape[1]]) * float("inf") + for i in range(int(maxshrink * grid)): + p = 1 - i / grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / maxq + zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero + q = np.clip(np.round(weight / scale1) + zero1, 0, maxq) + q -= weight + q = np.power(np.abs(q), norm) + err = np.sum(q, 0) + tmp = err < best + if np.any(tmp): + best[tmp] = err[tmp] + scale[tmp] = scale1[tmp] + zero[tmp] = zero1[tmp] + if not perchannel: + tmp = org_shape[1] + scale = np.repeat(scale, tmp) + zero = np.repeat(zero, tmp) + shape = [-1] + [1] * (len(org_shape) - 1) + scale = np.reshape(scale, shape) + zero = np.reshape(zero, shape) + return scale, zero + + for W, H in zip(Ws, Hs): + dtype = W.dtype + shape = W.shape + scale, zp = find_params(W) + dead = np.diag(H) == 0 + H[dead, dead] = 1 + W[dead, :] = 0 # such channel makes no contribution to quantization computation + + # rearrange considering the diag's value + if actorder: + perm = np.argsort(np.diag(H))[::-1] + W = W[perm, :] + H = H[perm, :][:, perm] + Losses = np.zeros(W.shape) + Q = np.zeros(W.shape) + damp = percdamp * np.mean(np.diag(H)) + diag = np.arange(shape[0]) + H[diag, diag] += damp # add a average value of + H = np.linalg.cholesky(np.linalg.inv(H)).T + Hinv = H + for i1 in range(0, shape[0], blocksize): + i2 = min(i1 + blocksize, shape[0]) + count = i2 - i1 + + W1 = copy.deepcopy(W[i1:i2, :]) + Q1 = np.zeros(W1.shape) + Err1 = np.zeros(W1.shape) + Losses1 = np.zeros(W1.shape) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): # within a block, channel wise + w = W1[i, :] + d = Hinv1[i, i] + + if group_size != -1: + if (i1 + i) % group_size == 0: + scale, zp = find_params(W[(i1 + i):(i1 + i + group_size), :]) + + q = (scale * (np.clip(np.round(np.expand_dims(w, axis=1) / scale) + zp, 0, maxq) - zp)).flatten() + Q1[i, :] = q + Losses1[i, :] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0)) + Err1[i, :] = err1 + + Q[i1:i2, :] = Q1 + Losses[i1:i2, :] = Losses1 / 2 + + W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1) + + if actorder: + invperm = np.argsort(perm) + Q = Q[invperm, :] + + Qs.append(np.reshape(Q, W.shape).astype(dtype)) + del Ws + return Qs + +def gptq_quantize(model, + tune_cfg, + dataloader, + n_samples=128, + percdamp=.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True + ): + """Quant the model with Activation-aware Weight quantization(AWQ) method. 
+ + Args: + model (ModelProto or ONNXModel): onnx model + tune_cfg (dict): quantization config + For example, + tune_cfg={ + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'GPTQ' + } + } + dataloader (object): dataloader for calibration. + n_samples (int, optional): calibration sample number. + percdamp (float, optional): percent of the average Hessian diagonal to use for dampening. + blocksize (int, optional): blocksize to quantize weight. + actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): whether get scale and zero point with mse error. + perchannel (bool, optional): whether quantize weight per-channel. + + Returns: + model: fake quantized ONNXModel + """ + model = model if isinstance(model, BaseModel) else ONNXModel(model) + output_dicts = {} + absorb_pairs = model.get_absorb_pairs(["MatMul", "Attention"]) + + inputs, so = prepare_inputs(model, n_samples, dataloader) + del dataloader + org_output = copy.deepcopy(model.model.graph.output) + model.remove_tensors_from_outputs([i.name for i in org_output]) + for parent, nodes in absorb_pairs.items(): + dump_tensor = list(set([i.input[0] for i in nodes])) + model.add_tensors_to_outputs(dump_tensor) + + if model.is_large_model: + onnx.save_model(model.model, + model.model_path + '_augment.onnx', + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False) + + session = ort.InferenceSession( + model.model.SerializeToString(), + so, + providers=ort.get_available_providers()) if not model.is_large_model else \ + ort.InferenceSession( + model.model_path + '_augment.onnx', + so, + providers=ort.get_available_providers()) + + weights = [copy.deepcopy(numpy_helper.to_array(model.get_initializer(node.input[1]), + os.path.dirname(model.model_path))) for node in nodes] + Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] + nsamples = 0 + for inp in inputs: + output_dicts = {} + for output_idx, output in enumerate(session.run(None, inp)): + output_dicts.setdefault(dump_tensor[output_idx], []).append(output) + + inp = output_dicts[nodes[0].input[0]][0] + tmp = inp.shape[0] + if len(inp.shape) == 3: + inp = np.reshape(inp, (-1, inp.shape[-1])) + Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] + nsamples += tmp + inp = np.sqrt(2 / nsamples) * inp + Hs = [i + np.matmul(inp.T, inp) for i in Hs] + + model.remove_tensors_from_outputs(dump_tensor) + weights = gptq(weights, + Hs, + tune_cfg.get(nodes[0].name, {}), + blocksize=blocksize, + percdamp=percdamp, + actorder=actorder, + mse=mse, + perchannel=perchannel + ) + for name, weight in zip([i.input[1] for i in nodes], weights): + model.set_initializer(name, weight, raw=True) + model.model.graph.output.MergeFrom(org_output) + return model diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 00e86e8700b..608816b2e4f 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -144,11 +144,15 @@ def save(self, root): from onnx.external_data_helper import convert_model_to_external_data, \ load_external_data_for_model load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) - convert_model_to_external_data(self._model, - all_tensors_to_one_file=True, - location="int8_weights.pb", - convert_attribute=False) - onnx.save(self._model, root) + onnx.save_model(self._model, + root, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="int8_weights.pb", + 
size_threshold=1024, + convert_attribute=False) + else: + onnx.save(self._model, root) if self._config is not None: model_type = '' if not hasattr(self._config, 'model_type') else getattr(self._config, 'model_type') @@ -228,13 +232,14 @@ def remove_initializers(self, init_to_remove): for initializer in init_to_remove: self.remove_initializer(initializer) - def set_initializer(self, tensor, array): + def set_initializer(self, tensor, array, raw=False): """Update initializer.""" old_tensor = self.get_initializer(tensor) self.remove_initializer(old_tensor) dims = old_tensor.dims data_type = old_tensor.data_type - new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) + new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) if not raw \ + else onnx.helper.make_tensor(tensor, data_type, dims, array.tostring(), raw=raw) self.add_initializer(new_tensor) @property @@ -758,6 +763,27 @@ def match_parent( return None + def get_absorb_pairs(self, target_optype): + """Find absorbable nodes based on parent op_type and their own input status. + + Args: + target_optype (list): target absorbable optype. + + Returns: + absorb_pairs (dict): a dict of absorb pairs {parent: list of absorbable children}. + """ + absorbable_optypes = ["LayerNormalization", "BatchNormalization", "InstanceNormalization", "Conv", + "SimplifiedLayerNormalization", "MatMul", "Gemm", "Mul", "FusedConv"] + absorb_pairs = {} + for node in self.nodes(): + if node.op_type in target_optype and self.get_initializer(node.input[1]) is not None: + parent = self.get_parent(node, 0) + if parent is None or parent.op_type not in absorbable_optypes or \ + self.get_initializer(parent.input[1]) is None: + continue + absorb_pairs.setdefault(parent.name, []).append(node) + return absorb_pairs + def match_parent_path( self, node, diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py index 48edd6329e3..1cbf76a5212 100644 --- a/neural_compressor/strategy/strategy.py +++ b/neural_compressor/strategy/strategy.py @@ -1371,6 +1371,9 @@ def _set_framework_info(self, q_dataloader, q_func=None): framework_specific_info['use_bf16'] = True if framework_specific_info['backend'] == 'onnxrt_dnnl_ep' and self.config.device == 'cpu': framework_specific_info['use_bf16'] = True + if self.config.approach =='post_training_weight_only': + framework = 'onnxrt_weightonly' # use specific adaptor for weight_only approach + if framework == 'pytorch_ipex' or framework == 'pytorch' or framework == 'pytorch_fx': if self.config.backend == 'ipex': framework = 'pytorch_ipex' diff --git a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py new file mode 100644 index 00000000000..e94adabcd40 --- /dev/null +++ b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py @@ -0,0 +1,172 @@ +import os +import onnx +import shutil +import subprocess +import unittest +import numpy as np +import onnxruntime as ort +from transformers import AutoTokenizer +from neural_compressor import quantization, PostTrainingQuantConfig + +def Inference(model, data): + sess = ort.InferenceSession(model.SerializeToString(), providers=ort.get_all_providers()) + out = sess.run(None, data) + return out + +class DummyNLPDataloader(object): + def __init__(self, model_name): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.sequence_a = "intel-extension-for-transformers is based in SH" + self.sequence_b = "Where is 
intel-extension-for-transformers based? NYC or SH" + self.encoded_dict = self.tokenizer(self.sequence_a, self.sequence_b, return_tensors='pt') + self.encoded_dict['labels'] = 1 + self.batch_size = 1 + + def __iter__(self): + yield {'input_ids': self.encoded_dict['input_ids'].detach().cpu().numpy(), + 'attention_mask': self.encoded_dict['attention_mask'].detach().cpu().numpy()}, self.encoded_dict['labels'] + +class TestWeightOnlyAdaptor(unittest.TestCase): + + @classmethod + def setUpClass(self): + cmd = 'optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation gptj/' + p = subprocess.Popen(cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) # nosec + p.communicate() + + self.model = onnx.load('gptj/decoder_model.onnx') + self.dataloader = DummyNLPDataloader('hf-internal-testing/tiny-random-gptj') + + @classmethod + def tearDownClass(self): + shutil.rmtree("nc_workspace", ignore_errors=True) + shutil.rmtree("gptj", ignore_errors=True) + + def test_RTN_quant(self): + + conf = PostTrainingQuantConfig( + approach='weight_only', + ) + q_model = quantization.fit(self.model, conf) + for data, _ in self.dataloader: + q_out = Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 8, # 1-8 bits + 'group_size': -1, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'RTN', + }, + }, + }, + ) + q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + for data, _ in self.dataloader: + q_out = Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + def test_AWQ_quant(self): + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': -1, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'AWQ', + }, + }, + }, + recipes={ + 'awq_args':{'auto_scale': True, 'mse_range': True}, + }, + ) + q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + for data, _ in self.dataloader: + q_out = Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'AWQ', + }, + }, + }, + recipes={ + 'awq_args':{'auto_scale': False, 'mse_range': True}, + }, + ) + q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + for data, _ in self.dataloader: + q_out = Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': 32, + 'scheme': 'asym', + 'algorithm': 'AWQ', + }, + }, + }, + recipes={ + 'awq_args':{'auto_scale': True, 'mse_range': False}, + }, + ) + q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + for data, _ in self.dataloader: + q_out = 
Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + def test_GPTQ_quant(self): + + conf = PostTrainingQuantConfig( + approach='weight_only', + op_type_dict={ + '.*':{ # re.match + "weight": { + 'bits': 4, # 1-8 bits + 'group_size': -1, # -1 (per-channel) + 'scheme': 'sym', + 'algorithm': 'GPTQ', + }, + }, + }, + ) + q_model = quantization.fit(self.model, conf, calib_dataloader=self.dataloader) + for data, _ in self.dataloader: + q_out = Inference(q_model.model, data) + org_out = Inference(self.model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/requirements.txt b/test/requirements.txt index b9b11da2ca4..7719db52820 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -20,4 +20,4 @@ dynast==1.3.0 intel-extension-for-pytorch tf2onnx xgboost -optimum +optimum \ No newline at end of file
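
Usage sketch (illustrative, not part of the diff): the new ONNX Runtime weight-only path added by this patch is driven through PostTrainingQuantConfig(approach='weight_only'), exactly as the unit tests above exercise it. The model path and the DummyNLPDataloader below are placeholders taken from those tests; any MatMul-heavy ONNX model plus an INC-style calibration dataloader should work.

    import onnx
    from neural_compressor import quantization, PostTrainingQuantConfig

    # Placeholders: exported tiny-random-gptj decoder and the dataloader defined in the test above.
    model = onnx.load('gptj/decoder_model.onnx')
    dataloader = DummyNLPDataloader('hf-internal-testing/tiny-random-gptj')

    conf = PostTrainingQuantConfig(
        approach='weight_only',
        op_type_dict={
            '.*': {                      # re.match on op type
                'weight': {
                    'bits': 4,           # 1-8 bits
                    'group_size': 32,    # -1 means per-channel
                    'scheme': 'sym',
                    'algorithm': 'AWQ',  # or 'RTN' / 'GPTQ'
                },
            },
        },
        recipes={'awq_args': {'auto_scale': True, 'mse_range': True}},
    )
    q_model = quantization.fit(model, conf, calib_dataloader=dataloader)
    q_model.save('quantized_model.onnx')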