diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 10f173d981b..b317a01e5fc 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -35,6 +35,7 @@ from neural_compressor.adaptor.ox_utils.util import ONNXRT_BACKENDS, PROVIDERS, to_numpy from neural_compressor.adaptor.query import QueryBackendCapability from neural_compressor.data.dataloaders.base_dataloader import BaseDataLoader +from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.utils.utility import GLOBAL_STATE, MODE, CpuInfo, LazyImport, Statistics, dump_elapsed_time onnx = LazyImport("onnx") @@ -267,8 +268,6 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): ): # pragma: no cover from onnx import version_converter - from neural_compressor.model.onnx_model import ONNXModel - try: model = self._rename_node(ONNXModel(version_converter.convert_version(model.model, 15))) except: @@ -308,18 +307,146 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): iterations = tune_cfg.get("calib_iteration", 1) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) - if not self.dynamic: - calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations) - quantize_params = self._get_quantize_params(tmp_model, data_loader, quantize_config, calib_iterations) + + if self.recipes.get("layer_wise_quant", False) and not self.dynamic: + # layer-wise quantization + # details refer to docs/source/quantization_weight_only.md#layer-wise-quantization + _model_to_split = copy.deepcopy(tmp_model) + + split_nodes = _model_to_split.find_split_nodes() + logger.info( + "Will split model into {} parts to do layer-wise quantization".format( + len([node.name for node in split_nodes]) + 1 + ) + ) + logger.debug( + "Will split model with these nodes for layer-wise quantization: {}".format( + [node.name for node in split_nodes] + ) + ) + + split_idx = 1 + model_to_split = [_model_to_split] + dataloader_for_split_model = [data_loader] + quantize_params = {} + quantized_model_merged = None + + while len(model_to_split) != 0: + split_model = model_to_split.pop(0) + split_node = split_nodes.pop(0) + save_both_split_models = True if len(split_nodes) == 0 else False + shape_infer = True if split_idx == 1 else False + + # split model with given split_node + split_model_part_1, split_model_part_2 = split_model.split_model_with_node( + split_node.name, tmp_model.model_path, shape_infer, save_both_split_models + ) + if not save_both_split_models: + # append split_model_part_2 to do next split + model_to_split.append(split_model_part_2) + + logger.info("Quantize split model {}".format(split_idx)) + # get quantize params of split model + split_quantize_params, dataloder_for_next_split_model = self._get_split_model_quantize_params( + split_model_part_1, dataloader_for_split_model, quantize_config, calib_sampling_size, iterations + ) + dataloader_for_split_model.append(dataloder_for_next_split_model) + quantize_params.update(split_quantize_params) + + # quantize split model + quantized_model_merged = self._quantize_split_model( + split_model_part_1, quantize_config, split_quantize_params, quantized_model_merged + ) + + split_idx += 1 + + # if this is the last split, then quantize the last split model + if save_both_split_models: + logger.info("Quantize split model {}".format(split_idx)) + # get quantize params of split model + split_quantize_params, dataloder_for_next_split_model = self._get_split_model_quantize_params( + 
split_model_part_2, dataloader_for_split_model, quantize_config, calib_sampling_size, iterations + ) + quantize_params.update(split_quantize_params) + + # quantize split model + quantized_model_merged = self._quantize_split_model( + split_model_part_2, quantize_config, split_quantize_params, quantized_model_merged + ) + quantized_model_merged.re_org_output(tmp_model.output()) # re-org output as the origin output + + self.quantize_params = quantize_params + tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params) + tmp_model.model = quantized_model_merged.model + self.quantize_config = quantize_config # update so other methods can know current configs + self._dump_model_op_stats(tmp_model) + tmp_model.topological_sort() + tmp_model.check_is_large_model() + return tmp_model + else: - quantize_params = None - self.quantize_params = quantize_params + if not self.dynamic: + calib_iterations = self._reset_calib_iter(data_loader, calib_sampling_size, iterations) + quantize_params, _ = self._get_quantize_params( + tmp_model, data_loader, quantize_config, calib_iterations + ) + else: + quantize_params = None + self.quantize_params = quantize_params + + from neural_compressor import options + from neural_compressor.adaptor.ox_utils.quantizer import Quantizer + quantizer = Quantizer( + tmp_model, + quantize_config, + format, + self.static, + quantize_params, + self.quantizable_op_types, + self.query_handler.get_fallback_list(), + self.reduce_range, + options.onnxrt.qdq_setting.AddQDQPairToWeight + if "add_qdq_pair_to_weight" not in self.recipes + else self.recipes.get("add_qdq_pair_to_weight", False), + options.onnxrt.qdq_setting.OpTypesToExcludeOutputQuantizatioin + if "optypes_to_exclude_output_quant" not in self.recipes + else self.recipes.get("optypes_to_exclude_output_quant", []), + options.onnxrt.qdq_setting.DedicatedQDQPair + if "dedicated_qdq_pair" not in self.recipes + else self.recipes.get("dedicated_qdq_pair", False), + self.backend, + ) + quantizer.quantize_model() + tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, quantize_params) + tmp_model.model = quantizer.model.model + self.quantize_config = quantize_config # update so other methods can know current configs + self._dump_model_op_stats(tmp_model) + tmp_model.topological_sort() + return tmp_model + + def _get_split_model_quantize_params( + self, split_model, split_dataloader, quantize_config, calib_sampling_size, iterations + ): + """Get quantize params for current split model and get dataloader for next split model.""" + dataloader = split_dataloader.pop(0) + calib_iterations = self._reset_calib_iter(dataloader, calib_sampling_size, iterations) + split_quantize_params, dataloder_for_next_split_model = self._get_quantize_params( + split_model, + dataloader, + quantize_config, + calib_iterations, + split_model_input_names=split_model.input(), + ) + return split_quantize_params, dataloder_for_next_split_model + + def _quantize_split_model(self, split_model, quantize_config, quantize_params, quantized_model_merged): + """Quantize split model, and merge the quantized models to generate final model.""" from neural_compressor import options from neural_compressor.adaptor.ox_utils.quantizer import Quantizer quantizer = Quantizer( - tmp_model, + split_model, quantize_config, format, self.static, @@ -339,12 +466,16 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): self.backend, ) quantizer.quantize_model() - tmp_model.q_config = self._generate_qconfig(model.model, tune_cfg, 
quantize_params) - tmp_model.model = quantizer.model.model - self.quantize_config = quantize_config # update so other methods can know current configs - self._dump_model_op_stats(tmp_model) - tmp_model.topological_sort() - return tmp_model + split_model.model = quantizer.model.model + split_model.topological_sort() + + if quantized_model_merged is None: + quantized_model_merged = quantizer.model + quantized_model_merged.write_external_data_to_new_location(overwrite=True) + else: + quantized_model_merged.merge_split_models(quantizer.model) + + return quantized_model_merged def _check_backend_available(self, backend): """Check backend is available or not.""" @@ -570,7 +701,7 @@ def _dump_model_op_stats(self, model): Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat() self.optype_statistics = field_names, output_data - def _get_quantize_params(self, model, data_loader, quantize_config, iterations): + def _get_quantize_params(self, model, data_loader, quantize_config, iterations, **kwargs): from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment from neural_compressor.model.onnx_model import ONNXModel @@ -588,10 +719,12 @@ def _get_quantize_params(self, model, data_loader, quantize_config, iterations): iterations=list(range(0, iterations)), backend=self.backend, reduce_range=self.reduce_range, + **kwargs, ) self.min_max = augment.dump_minmax(quantize_config) quantize_params = augment.dump_calibration(quantize_config, min_max=self.min_max) - return quantize_params + dataloder_for_next_split_model = augment.dataloder_for_next_split_model + return quantize_params, dataloder_for_next_split_model def inspect_tensor( self, @@ -606,7 +739,6 @@ def inspect_tensor( ): """The function is used by tune strategy class for dumping tensor info.""" from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment - from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.utils.utility import dump_data_to_local if not isinstance(model, ONNXModel): @@ -763,6 +895,9 @@ def _pre_optimize(self, model, level=1): } if not isinstance(self.query_handler.get_graph_optimization(), list): level = self.query_handler.get_graph_optimization() + elif self.recipes.get("layer_wise_quant"): + level = "ENABLE_BASIC" + logger.info("Force set graph optimization level to 'ENABLE_BASIC' for layer-wise quantization") elif options.onnxrt.graph_optimization.level is not None: level = options.onnxrt.graph_optimization.level elif self.recipes.get("graph_optimization_level", None) is not None: @@ -778,10 +913,23 @@ def _pre_optimize(self, model, level=1): ) sess_options.graph_optimization_level = optimization_levels[level] sess_options.optimized_model_filepath = os.path.join(self.work_space, "Optimized_model.onnx") + if model.is_large_model and self.recipes.get("layer_wise_quant", False): + # save the model and external data for layer-wise quantization + external_data_filename = os.path.basename(sess_options.optimized_model_filepath) + "_data" + external_data_file_threshold = 1024 + sess_options.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", external_data_filename + ) + sess_options.add_session_config_entry( + "session.optimized_model_external_initializers_min_size_in_bytes", str(external_data_file_threshold) + ) + logger.info("Saving optimized model for layer-wise quantization. 
This may take a while...") + if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover from onnxruntime_extensions import get_library_path sess_options.register_custom_ops_library(get_library_path()) + if not model.is_large_model: sess = ort.InferenceSession( model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"] @@ -792,13 +940,14 @@ def _pre_optimize(self, model, level=1): else: # pragma: no cover logger.warning("Please use model path instead of onnx model object to quantize") del sess - tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False) - if model.is_large_model: # pragma: no cover + # load external data if model is large and not layer wise quantization + if model.is_large_model and not self.recipes.get("layer_wise_quant", False): # pragma: no cover from onnx.external_data_helper import load_external_data_for_model load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0]) + model.model_path = sess_options.optimized_model_filepath model.model = ( self._replace_gemm_with_matmul(tmp_model).model @@ -903,8 +1052,6 @@ def _replace_gemm_with_matmul(model): new_nodes = [] from onnx import numpy_helper - from neural_compressor.model.onnx_model import ONNXModel - if not isinstance(model, ONNXModel): model = ONNXModel(model) diff --git a/neural_compressor/adaptor/ox_utils/calibration.py b/neural_compressor/adaptor/ox_utils/calibration.py index c97cef7638c..2e15acd0d1e 100644 --- a/neural_compressor/adaptor/ox_utils/calibration.py +++ b/neural_compressor/adaptor/ox_utils/calibration.py @@ -63,6 +63,7 @@ def __init__( iterations=[], backend="CPUExecutionProvider", reduce_range=False, + **kwargs, ): """Initialization. @@ -94,6 +95,16 @@ def __init__( self.ort_version = Version(onnxruntime.__version__) self.reduce_range = reduce_range + self.layer_wise = True if len(kwargs.get("split_model_input_names", [])) != 0 else False + if self.layer_wise: + self.split_model_input_names = kwargs.get("split_model_input_names", []) + self._dataloder_for_next_split_model = None + + @property + def dataloder_for_next_split_model(self): + """Return dataloader for next split model for layer-wise quantization.""" + return self._dataloder_for_next_split_model + def augment_graph(self, activation_only=False, weight_only=False): """Augment_graph. 
@@ -245,12 +256,13 @@ def get_intermediate_outputs(self, q_config=None): len_inputs = len(session.get_inputs()) inputs_names = [session.get_inputs()[i].name for i in range(len_inputs)] + len_outputs = len(session.get_outputs()) + outputs_names = [session.get_outputs()[i].name for i in range(len_outputs)] node_output_names = [ output.name if output.name not in self.dequantized_output else self.dequantized_output[output.name] for output in session.get_outputs() ] - augment_model_wrapper = ( ONNXModel(self.augmented_model) if not self.model_wrapper.is_large_model @@ -271,6 +283,7 @@ def get_intermediate_outputs(self, q_config=None): output_dicts = {} intermediate_tensor = {} name_to_calibrator = {} + ort_inputs_for_next_split_model = [] for idx, (inputs, labels) in enumerate(self.dataloader): ort_inputs = {} @@ -281,7 +294,9 @@ def get_intermediate_outputs(self, q_config=None): else: ort_inputs.update({inputs_names[0]: to_numpy(inputs)}) else: - assert len_inputs == len(inputs), "number of input tensors must align with graph inputs" + if not self.layer_wise: + # for layer-wise calibration + assert len_inputs == len(inputs), "number of input tensors must align with graph inputs" if isinstance(inputs, dict): for name, input in inputs.items(): @@ -289,7 +304,15 @@ def get_intermediate_outputs(self, q_config=None): else: ort_inputs = dict(zip(inputs_names, [to_numpy(i) for i in inputs])) - def _collect_data(): + def _collect_data(ort_inputs): + if self.layer_wise: + # for layer-wise calibration + ort_inputs = { + input_name: input_tensor + for input_name, input_tensor in ort_inputs.items() + if input_name in self.split_model_input_names + } + for output_idx, output in enumerate(session.run(None, ort_inputs)): if q_config is not None and output.size != 0: node_name = name_to_node[node_output_names[output_idx]] @@ -321,13 +344,18 @@ def _collect_data(): elif q_config is None: output_dicts.setdefault(node_output_names[output_idx], []).append(output) + if self.layer_wise: + # for layer-wise calibration + ort_inputs.update({outputs_names[output_idx]: output}) + ort_inputs_for_next_split_model.append((ort_inputs, labels)) + if self.iterations != []: if idx > max(self.iterations): break if idx in self.iterations: - _collect_data() + _collect_data(ort_inputs) else: - _collect_data() + _collect_data(ort_inputs) # for kl and percentile method, collect calibration range after all tensors are collected. merged_dict = intermediate_tensor @@ -344,6 +372,9 @@ def _collect_data(): output_dicts.setdefault(output_name, []).append(list(calibrator.calib_range)) calibrator.clear() del calibrator + + self._dataloder_for_next_split_model = ort_inputs_for_next_split_model + return list(output_dicts.keys()), output_dicts def _dequantize(self, tensor, scale_tensor, zo_tensor): diff --git a/neural_compressor/model/onnx_model.py b/neural_compressor/model/onnx_model.py index 5186c4ca9b5..31ea1a1bbc5 100644 --- a/neural_compressor/model/onnx_model.py +++ b/neural_compressor/model/onnx_model.py @@ -43,9 +43,11 @@ def __init__(self, model, **kwargs): """ self._model = model if not isinstance(model, str) else onnx.load(model) self._model_path = None if not isinstance(model, str) else model - self._is_large_model = self.check_large_model() - if self._is_large_model and self._model_path is None: + + self.check_is_large_model() + if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False): logger.warning("Model size > 2GB. 
Please use model path instead of onnx model object to quantize") + self._config = None if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): from transformers import PretrainedConfig @@ -61,25 +63,28 @@ def __init__(self, model, **kwargs): self._get_graph_info() self._q_config = None - def check_large_model(self): + def check_is_large_model(self): """Check model > 2GB.""" init_size = 0 for init in self._model.graph.initializer: # if initializer has external data location, return True if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: - return True + self._is_large_model = True + return # if raise error of initializer size > 2GB, return True try: init_bytes = init.SerializeToString() init_size += sys.getsizeof(init_bytes) except Exception as e: if "exceeds maximum protobuf size of 2GB" in str(e): - return True + self._is_large_model = True + return else: # pragma: no cover raise e if init_size > MAXIMUM_PROTOBUF: - return True - return False + self._is_large_model = True + return + self._is_large_model = False @property def is_large_model(self): @@ -158,7 +163,7 @@ def save(self, root): if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]): raise ValueError('"root" directory does not exists.') if self.is_large_model: # pragma: no cover - from onnx.external_data_helper import convert_model_to_external_data, load_external_data_for_model + from onnx.external_data_helper import load_external_data_for_model load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) onnx.save_model( @@ -434,8 +439,10 @@ def _searcher(tensor_name): def save_model_to_file(self, output_path, use_external_data_format=False): """Save model to external data, which is needed for model size > 2GB.""" + from onnx.external_data_helper import convert_model_to_external_data + if use_external_data_format: - onnx.external_data_helper.convert_model_to_external_data( + convert_model_to_external_data( self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data" ) onnx.save_model(self._model, output_path) @@ -619,6 +626,82 @@ def get_nodes_chain(self, start, stop, result_chain=[]): return result_chain + def find_split_node_for_layer_wise_quantization(self): + """Find split node for layer wise quantization.""" + # find split nodes of decoder blocks + # embed -> decoder.0 -(split_node)-> ... 
-(split_node)-> decoder.n -(split_node)-> norm -> head
+        # after split: embed -> decoder.0,
+        #              decoder.1,
+        #              decoder.2,
+        #              ...,
+        #              decoder.n,
+        #              norm -> head
+        start_nodes = []
+        for node in self._model.graph.node:
+            start_node, qkv_nodes_list = None, None
+            if node.op_type == "SkipLayerNormalization":
+                start_node = node
+                qkv_nodes_list = [
+                    self.match_parent_path(
+                        start_node,
+                        ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                        [None, 0, 0, 0, 0],
+                    ),
+                    self.match_parent_path(
+                        start_node,
+                        ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                        [1, 1, 0, 0, 0],
+                    ),
+                ]
+            if node.op_type == "Add":
+                start_node = node
+                qkv_nodes_list = [
+                    # match base attention structure
+                    self.match_parent_path(
+                        start_node,
+                        ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                        [0, None, 0, 0, 0],
+                    ),
+                    self.match_parent_path(
+                        start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
+                    ),
+                    # match gpt attention no past structure
+                    self.match_parent_path(
+                        start_node,
+                        ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                        [None, 0, 0, 0, 0, 0],
+                        output_name_to_node=self.output_name_to_node,
+                        return_indice=[],
+                    ),
+                    # match bart attention structure
+                    self.match_parent_path(
+                        start_node,
+                        ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                        [0, None, 0, 0, 0, 0],
+                    ),
+                    self.match_parent_path(
+                        start_node,
+                        ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+                        [1, None, 0, 0, 0, 0],
+                    ),
+                    self.match_parent_path(
+                        start_node,
+                        ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"],
+                        [None, 0, None, 0, None, 0],
+                    ),
+                    self.match_parent_path(
+                        start_node,
+                        ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"],
+                        [None, 0, None, 0, 0],
+                    ),
+                ]
+            if not start_node:
+                continue
+            if not any(qkv_nodes_list):
+                continue
+            start_nodes.append(start_node)
+        return start_nodes
+
     def find_qkv_in_attention(self, find_all=False):
         """Find qkv MatMul in Attention.

@@ -680,7 +763,6 @@ def find_qkv_in_attention(self, find_all=False):
                         [1, None, 0, 0, 0, 0],
                     ),
                 ]
-
             if not start_node:
                 continue
             if not any(qkv_nodes_list):
@@ -894,3 +976,264 @@ def is_smoothquant_model(self):
             if "_smooth_scale" in init.name:
                 return True
         return False
+
+    def find_split_nodes(self):
+        """Find split nodes for layer-wise quantization."""
+        split_nodes = self.find_split_node_for_layer_wise_quantization()
+        return split_nodes
+
+    def split_model_with_node(
+        self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True
+    ):
+        """Split model into two parts at a given node.
+
+        Args:
+            split_node_name (str): name of the node where the model is split at.
+            path_of_model_to_split (str): path of the model to be split.
+            shape_infer (bool): do shape inference. Default is True.
+            save_both_split_models (bool): whether to save both split models.
+                False means only the first split model is saved.
+                True means both split models are saved.
+                Default is True.
+
+        Returns:
+            tuple: the first split model, the second split model
+        """
+        # original model: ... -> node_1 -> split_node -> node_2 -> ...
+        # split model 1 : ... -> node_1 -> split_node
+        # split model 2 : node_2 -> ...
+
+        split_model_part_1 = onnx.ModelProto()
+        split_model_part_1.CopyFrom(self._model)
+        split_model_part_1.graph.ClearField("node")
+
+        split_model_part_2 = onnx.ModelProto()
+        split_model_part_2.CopyFrom(self._model)
+        split_model_part_2.graph.ClearField("node")
+
+        split_node_output = None
+        part_idx = 1
+        for node in self._model.graph.node:
+            if part_idx == 1:
+                split_model_part_1.graph.node.append(node)
+            elif part_idx == 2:
+                split_model_part_2.graph.node.append(node)
+
+            if node.name == split_node_name:
+                split_node_output = node.output
+                part_idx = 2
+
+        assert len(split_node_output) == 1, (
+            "Only support splitting at a node with 1 output tensor, while "
+            "current split node {} has {} output tensors".format(split_node_name, len(split_node_output))
+        )
+        split_tensor_name = split_node_output[0]
+
+        # infer shape of the model to be split
+        if shape_infer:
+            try:
+                # need ort.GraphOptimizationLevel <= ORT_ENABLE_BASIC
+                import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer
+
+                self._model = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(self._model, auto_merge=True)
+            except Exception as e:  # pragma: no cover
+                logger.error("Shape inference failed for layer-wise quantization")
+                if "Incomplete symbolic shape inference" in str(e):
+                    logger.warning("Please set graph optimization level to 'ENABLE_BASIC' for layer-wise quantization.")
+                raise e
+
+        split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name)
+        split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape)
+
+        split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)
+        split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)
+
+        # remove unused input & output
+        split_model_part_1._remove_unused_input_output()
+        split_model_part_2._remove_unused_input_output()
+
+        split_model_part_1.model.graph.output.append(split_tensor)
+        split_model_part_2.model.graph.input.append(split_tensor)
+
+        insert_output_for_model_1 = []
+        insert_input_for_model_2 = []
+        for output in split_model_part_1.output_name_to_node.keys():
+            if output in split_model_part_2.input_name_to_nodes.keys():
+                output_type, output_shape = self._get_output_type_shape_by_tensor_name(output)
+                output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape)
+                if output_tensor not in split_model_part_1.model.graph.output:
+                    insert_output_for_model_1.append(output_tensor)
+                if output_tensor not in split_model_part_2.model.graph.input:
+                    insert_input_for_model_2.append(output_tensor)
+
+        # insert model 1 output
+        for output in insert_output_for_model_1:
+            split_model_part_1.model.graph.output.append(output)
+
+        # insert model 2 input
+        for input in insert_input_for_model_2:
+            split_model_part_2.model.graph.input.append(input)
+
+        # remove unused init
+        split_model_part_1.remove_unused_init()
+        split_model_part_2.remove_unused_init()
+
+        split_model_part_1.update()
+        split_model_part_2.update()
+
+        dir_of_model_to_split = os.path.dirname(path_of_model_to_split)
+
+        split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split)
+        split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx")
+        split_model_part_1.model_path = split_model_part_1_path
+        split_model_part_1._save_split_model(split_model_part_1_path)
+        split_model_part_1.check_is_large_model()
+        logger.debug("Save split model part 1 to {} for layer-wise quantization".format(split_model_part_1_path))
+
+        if save_both_split_models:
+            split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split)
+            split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx")
+            split_model_part_2.model_path = split_model_part_2_path
+            split_model_part_2._save_split_model(split_model_part_2_path)
+            split_model_part_2.check_is_large_model()
+            logger.debug("Save split model part 2 to {} for layer-wise quantization".format(split_model_part_2_path))
+            return split_model_part_1, split_model_part_2
+        else:
+            return split_model_part_1, split_model_part_2
+
+    def _save_split_model(self, save_path):
+        """Save split model as external data for layer-wise quantization.
+
+        Args:
+            save_path (str): the path to save the split model
+        """
+        if os.path.exists(save_path + "_data"):
+            os.remove(save_path + "_data")
+        onnx.save_model(
+            self._model,
+            save_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=save_path.split("/")[-1] + "_data",
+            size_threshold=1024,
+            convert_attribute=False,
+        )
+
+    def _get_output_type_shape_by_tensor_name(self, tensor_name):
+        """Get output type and shape with a tensor name.
+
+        Args:
+            tensor_name (str): name of a tensor
+
+        Returns:
+            tuple: output type and shape
+        """
+        elem_type = onnx.TensorProto.FLOAT
+        shape = None
+        for output in self._model.graph.value_info:
+            if output.name == tensor_name:
+                elem_type = output.type.tensor_type.elem_type
+                shape = [
+                    dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim
+                ]
+                break
+        return elem_type, shape
+
+    def _remove_unused_input_output(self):
+        """Remove unused input & output for split model."""
+        remove_outputs = []
+        remove_inputs = []
+        for output in self._model.graph.output:
+            if output.name not in self.output_name_to_node.keys():
+                remove_outputs.append(output)
+
+        for input in self._model.graph.input:
+            if input.name not in self.input_name_to_nodes.keys():
+                remove_inputs.append(input)
+
+        for output in remove_outputs:
+            self._model.graph.output.remove(output)
+        for input in remove_inputs:
+            self._model.graph.input.remove(input)
+
+    def remove_unused_init(self):
+        """Remove unused initializers."""
+        remove_inits = []
+        for init in self._model.graph.initializer:
+            if init.name not in self.input_name_to_nodes.keys():
+                remove_inits.append(init)
+        self.remove_initializers(remove_inits)
+
+    def load_model_initializer_by_tensor(self, data_path=None):
+        """Load model initializer by tensor.
+
+        Args:
+            data_path (str, optional): the directory of saved initializer. Defaults to None.
+        """
+        from onnx.external_data_helper import load_external_data_for_tensor
+
+        if data_path is None:
+            data_path = os.path.dirname(self._model_path)
+        for init in self._model.graph.initializer:
+            if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                load_external_data_for_tensor(init, data_path)
+
+    def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False):
+        """Write external data of merged quantized model to new location to save memory.
+
+        Args:
+            external_data_location (str, optional): external data location of merged quantized model.
+                Defaults to "external.data".
+            overwrite (bool, optional): if True, remove existing external data. Defaults to False.
+        """
+        from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors
+
+        if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)):
+            os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location))
+        self.load_model_initializer_by_tensor()
+        convert_model_to_external_data(self._model, location=external_data_location)
+        # TODO: if an initializer is already saved, skip writing it again
+        write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path))
+
+    def merge_split_models(self, to_merge_model):
+        """Merge two split models into the final model."""
+        to_merge_model.write_external_data_to_new_location()
+        self.add_nodes([node for node in to_merge_model.nodes()])
+        self.add_initializers([init for init in to_merge_model.initializer()])
+        self.update()
+
+        # add new output
+        for output in to_merge_model.graph().output:
+            if output.name not in self.output():
+                self._model.graph.output.append(output)
+
+        # remove unused output
+        remove_output = []
+        for output in self._model.graph.output:
+            if output.name in to_merge_model.input():
+                remove_output.append(output)
+        for output in remove_output:
+            self._model.graph.output.remove(output)
+
+        # add new input
+        for input in to_merge_model.graph().input:
+            if (
+                input.name not in self.input()
+                and input.name not in self.output()
+                and input.name not in self.output_name_to_node.keys()
+            ):
+                self._model.graph.input.append(input)
+
+    def re_org_output(self, origin_output):
+        """Re-organize the output of the merged model for layer-wise quantization."""
+        outputs = {}
+        tmp_remove = []
+        for output in self._model.graph.output:
+            outputs[output.name] = output
+            tmp_remove.append(output)
+
+        for output in tmp_remove:
+            self._model.graph.output.remove(output)
+
+        for out_name in origin_output:
+            self._model.graph.output.append(outputs[out_name])
diff --git a/test/adaptor/onnxrt_adaptor/test_layer_wise.py b/test/adaptor/onnxrt_adaptor/test_layer_wise.py
new file mode 100644
index 00000000000..1b7cd01d33d
--- /dev/null
+++ b/test/adaptor/onnxrt_adaptor/test_layer_wise.py
@@ -0,0 +1,79 @@
+import os
+import shutil
+import subprocess
+import unittest
+
+import onnx
+import onnxruntime as ort
+from transformers import AutoTokenizer
+
+from neural_compressor import PostTrainingQuantConfig, quantization
+from neural_compressor.utils.constant import FP32
+
+
+def Inference(model_path, data):
+    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
+    out = sess.run(None, data)
+    return out
+
+
+class DummyNLPDataloader(object):
+    def __init__(self, model_name):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.sequence_a = "intel-extension-for-transformers is based in SH"
+        self.sequence_b = "Where is intel-extension-for-transformers based? NYC or SH"
+        self.encoded_dict = self.tokenizer(self.sequence_a, self.sequence_b, return_tensors="pt")
+        self.encoded_dict["labels"] = 1
+        self.batch_size = 1
+
+    def __iter__(self):
+        yield {
+            "input_ids": self.encoded_dict["input_ids"].detach().cpu().numpy(),
+            "attention_mask": self.encoded_dict["attention_mask"].detach().cpu().numpy(),
+        }, self.encoded_dict["labels"]
+
+
+class TestLayerWiseQuant(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        cmd = "optimum-cli export onnx --model yujiepan/llama-2-tiny-3layers-random --task text-generation --legacy tiny-llama/"
+        p = subprocess.Popen(
+            cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
+        )  # nosec
+        p.communicate()
+
+        self.model = onnx.load("tiny-llama/decoder_model.onnx")
+        self.dataloader = DummyNLPDataloader("yujiepan/llama-2-tiny-3layers-random")
+
+    @classmethod
+    def tearDownClass(self):
+        shutil.rmtree("nc_workspace", ignore_errors=True)
+        shutil.rmtree("tiny-llama", ignore_errors=True)
+
+    def test_layer_wise_W8A8_quant(self):
+        # layer-wise quantization
+        layerwise_quantized_model_path = "tiny-llama/layerwise_quantized_decoder_model.onnx"
+        config = PostTrainingQuantConfig(
+            calibration_sampling_size=[1], recipes={"layer_wise_quant": True}, op_type_dict={"^((?!(MatMul)).)*$": FP32}
+        )
+        q_model = quantization.fit("tiny-llama/decoder_model.onnx", config, calib_dataloader=self.dataloader)
+        q_model.save(layerwise_quantized_model_path)
+
+        # not layer-wise quantization
+        quantized_model_path = "tiny-llama/quantized_decoder_model.onnx"
+        config = PostTrainingQuantConfig(
+            calibration_sampling_size=[1],
+            recipes={"layer_wise_quant": False, "graph_optimization_level": "ENABLE_BASIC"},
+            op_type_dict={"^((?!(MatMul)).)*$": FP32},
+        )
+        q_model = quantization.fit("tiny-llama/decoder_model.onnx", config, calib_dataloader=self.dataloader)
+        q_model.save(quantized_model_path)
+
+        for data, _ in self.dataloader:
+            layerwise_q_out = Inference(layerwise_quantized_model_path, data)
+            q_out = Inference(quantized_model_path, data)
+            self.assertTrue((layerwise_q_out[0] == q_out[0]).all())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/model/test_onnx_model.py b/test/model/test_onnx_model.py
index cb4e1a6c704..f384ac8f77a 100644
--- a/test/model/test_onnx_model.py
+++ b/test/model/test_onnx_model.py
@@ -433,14 +433,17 @@ def forward(self, x):
         torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
         model = ONNXModel(model)  # pass ModelProto
-        self.assertTrue(model.check_large_model())
+        model.check_is_large_model()
+        self.assertTrue(model.is_large_model)

         model = ONNXModel("model.onnx")  # pass string
-        self.assertTrue(model.check_large_model())
+        model.check_is_large_model()
+        self.assertTrue(model.is_large_model)

         model = onnx.load("model.onnx", load_external_data=False)  # not load init
         model = ONNXModel(model)
-        self.assertTrue(model.check_large_model())
+        model.check_is_large_model()
+        self.assertTrue(model.is_large_model)

         # model < 2GB
         model = Net(10, 10 * 10)
@@ -449,13 +452,16 @@ def forward(self, x):
         torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
         model = onnx.load("model.onnx")
         model = ONNXModel(model)  # pass ModelProto
-        self.assertFalse(model.check_large_model())
+        model.check_is_large_model()
+        self.assertFalse(model.is_large_model)

         model = ONNXModel("model.onnx")  # pass string
-        self.assertFalse(model.check_large_model())
+        model.check_is_large_model()
+        self.assertFalse(model.is_large_model)

         model = ONNXModel("model.onnx", load_external_data_for_model=False)  # not load init
-        self.assertFalse(model.check_large_model())
+        model.check_is_large_model()
+        self.assertFalse(model.is_large_model)


 if __name__ == "__main__":
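
A minimal usage sketch of the new "layer_wise_quant" recipe, distilled from test_layer_wise.py above. The model path and the calibration dataloader are placeholders, not part of this patch.

    from neural_compressor import PostTrainingQuantConfig, quantization
    from neural_compressor.utils.constant import FP32

    # enable layer-wise post-training static quantization; the adaptor splits the model at the
    # decoder-block boundaries found by ONNXModel.find_split_nodes() and quantizes it part by part
    config = PostTrainingQuantConfig(
        calibration_sampling_size=[1],
        recipes={"layer_wise_quant": True},
        op_type_dict={"^((?!(MatMul)).)*$": FP32},  # keep everything except MatMul in FP32, as in the test
    )

    # "decoder_model.onnx" and calib_dataloader stand in for the user's own model path and dataloader;
    # passing a model path (rather than a ModelProto) lets external data be handled for models > 2GB
    q_model = quantization.fit("decoder_model.onnx", config, calib_dataloader=calib_dataloader)
    q_model.save("quantized_decoder_model.onnx")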