From 296c5d4f1138e5bf33584fb75cea0f6ca5080122 Mon Sep 17 00:00:00 2001 From: Yi Liu <106061964+yiliu30@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:08:05 +0800 Subject: [PATCH] Add docstring for PT2E and HQQ (#1937) Signed-off-by: yiliu30 --- .../scripts/codeScan/pydocstyle/scan_path.txt | 4 + .../torch/algorithms/pt2e_quant/__init__.py | 1 + .../torch/algorithms/pt2e_quant/core.py | 39 +++++++- .../pt2e_quant/half_precision_rewriter.py | 48 +++++++++- .../torch/algorithms/pt2e_quant/save_load.py | 17 ++++ .../torch/algorithms/pt2e_quant/utility.py | 22 +++++ .../algorithms/weight_only/hqq/__init__.py | 1 + .../algorithms/weight_only/hqq/bitpack.py | 89 ++++++++++++++++++- .../algorithms/weight_only/hqq/config.py | 17 +++- .../torch/algorithms/weight_only/hqq/core.py | 65 +++++++++++++- .../algorithms/weight_only/hqq/optimizer.py | 18 +++- .../algorithms/weight_only/hqq/qtensor.py | 50 ++++++++--- .../algorithms/weight_only/hqq/quantizer.py | 40 +++++++-- .../algorithms/weight_only/hqq/utility.py | 76 ---------------- neural_compressor/torch/export/__init__.py | 1 + neural_compressor/torch/export/pt2e_export.py | 27 +++++- .../torch/quantization/config.py | 53 +++++++++-- 17 files changed, 454 insertions(+), 114 deletions(-) delete mode 100644 neural_compressor/torch/algorithms/weight_only/hqq/utility.py diff --git a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt index b524f1f61db..1fb4c3ceda7 100644 --- a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt +++ b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt @@ -15,3 +15,7 @@ /neural-compressor/neural_compressor/strategy /neural-compressor/neural_compressor/training.py /neural-compressor/neural_compressor/utils +/neural_compressor/torch/algorithms/pt2e_quant +/neural_compressor/torch/export +/neural_compressor/common +/neural_compressor/torch/algorithms/weight_only/hqq \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py index b3c530ce2fd..27ef2e0d8d0 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/__init__.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""The PT2E-related modules.""" from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer diff --git a/neural_compressor/torch/algorithms/pt2e_quant/core.py b/neural_compressor/torch/algorithms/pt2e_quant/core.py index a1b4d1f65b6..4707295cd32 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/core.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/core.py @@ -14,7 +14,7 @@ # Some code snippets are taken from the X86InductorQuantizer tutorial. # https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html - +"""The quantizer using PT2E path.""" from typing import Any @@ -30,13 +30,24 @@ class W8A8PT2EQuantizer(Quantizer): + """The W8A8 quantizer using PT2E.""" + is_dynamic = False def __init__(self, quant_config=None): + """Initialize the quantizer.""" super().__init__(quant_config) @staticmethod def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer: + """Updates the quantizer based on the given quantization configuration. + + Args: + quant_config (dict): The quantization configuration. Defaults to None. + + Returns: + X86InductorQuantizer: The updated quantizer object. + """ if not quant_config: quantizer = X86InductorQuantizer() quantizer.set_global( @@ -47,9 +58,18 @@ def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuan return quantizer def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule: - """Prepare the model for calibration. + """Prepares the model for calibration. Create the `quantizer` according to the `quant_config`, and insert the observers accordingly. + + Args: + model (GraphModule): The model to be prepared for calibration. + example_inputs (tuple, optional): Example inputs to be used for calibration. Defaults to None. + inplace (bool, optional): Whether to modify the model in-place or return a new prepared model. + Defaults to True. + + Returns: + GraphModule: The prepared model. """ quant_config = self.quant_config assert model._exported, "The model should be exported before preparing it for calibration." @@ -58,7 +78,14 @@ def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, return prepared_model def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule: - """Convert the calibrated model into qdq mode.""" + """Convert the calibrated model into qdq mode. + + Args: + model (GraphModule): The prepared model. + + Returns: + GraphModule: The converted quantized model. + """ fold_quantize = kwargs.get("fold_quantize", False) converted_model = convert_pt2e(model, fold_quantize=fold_quantize) logger.warning("Converted the model in qdq mode, please compile it to accelerate inference.") @@ -67,6 +94,12 @@ def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule: return converted_model def half_precision_transformation(self, model, config): + """Applies half-precision transformation to the given model in-place. + + Args: + model: The model to apply the transformation to. + config: The configuration for the transformation. + """ half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config) logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set)) hp_rewriter.transformation(model, half_precision_node_set) diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py index bd1865e674c..9f767684054 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Rewrite the FP32 operators to FP16 or BF16 operators.""" from dataclasses import dataclass from functools import partial @@ -34,6 +35,14 @@ @dataclass class PatternPair: + """Represents a pair of patterns used for search and replacement in a graph. + + Attributes: + fn (TorchFuncType): The function type associated with the pattern pair. + search_pattern (torch.fx.GraphModule): The search pattern to be matched in the graph. + replace_pattern (torch.fx.GraphModule): The replacement pattern to be used when a match is found. + """ + fn: TorchFuncType search_pattern: torch.fx.GraphModule replace_pattern: torch.fx.GraphModule @@ -101,6 +110,15 @@ def _register_pattern_pair(dtype: torch.dtype) -> None: def get_filter_fn(node_list, fn): + """Filter function to check if a node with the target operator is in the given `node_list`. + + Args: + node_list (list): List of nodes to check against. + fn (str): Target operator. + + Returns: + bool: True if the node with the target operator is in the `node_list`, False otherwise. + """ target_op = FN_ATEN_OPS_MAPPING[fn] def is_target_node_in_candidate_list(match, original_graph, pattern_graph): @@ -119,6 +137,16 @@ def is_target_node_in_candidate_list(match, original_graph, pattern_graph): def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list): + """Applies a single pattern pair to a given GraphModule. + + Args: + gm (torch.fx.GraphModule): The GraphModule to apply the pattern pair to. + pattern_pair (PatternPair): The pattern pair containing the search and replace patterns. + node_list: The list of nodes to filter for pattern matching. + + Returns: + List[Match]: A list of Match objects representing the matches found after applying the pattern pair. + """ filter_fn = get_filter_fn(node_list, pattern_pair.fn) match_and_replacements = subgraph_rewriter.replace_pattern_with_filters( gm=gm, @@ -133,6 +161,14 @@ def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPai def get_unquantized_node_set(gm: torch.fx.GraphModule): + """Retrieves the set of unquantized nodes from a given GraphModule. + + Args: + gm (torch.fx.GraphModule): The GraphModule to retrieve unquantized nodes from. + + Returns: + set: A set containing the unquantized nodes. + """ unquantized_node_set = set() for node in gm.graph.nodes: if meta := getattr(node, "meta"): @@ -180,7 +216,17 @@ def _parse_node_candidate_set_from_user_config(config, gm): def get_half_precision_node_set(gm, config): - """Intersection between `unquantized_node_set` and `node_set_from_user_config`""" + """Retrieves a set of nodes from the given graph model (gm) that are candidates for conversion to half precision. + + The result is the intersection between `unquantized_node_set` and `node_set_from_user_config`. + + Args: + gm (GraphModel): The graph model to search for nodes. + config (dict): User configuration for node candidate set. + + Returns: + set: A set of nodes that are candidates for conversion to half precision. + """ # TODO: implement it, current return all unquantized_node_set node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm) diff --git a/neural_compressor/torch/algorithms/pt2e_quant/save_load.py b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py index 606c31f41c2..7e2700e94cf 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/save_load.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/save_load.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Save and load the quantized model.""" + import json import os @@ -22,6 +24,13 @@ def save(model, example_inputs, output_dir="./saved_results"): + """Save the quantized model and its configuration. + + Args: + model (torch.nn.Module): The quantized model to be saved. + example_inputs (torch.Tensor or tuple of torch.Tensor): Example inputs used for tracing the model. + output_dir (str, optional): The directory where the saved results will be stored. Defaults to "./saved_results". + """ os.makedirs(output_dir, exist_ok=True) qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) @@ -37,6 +46,14 @@ def save(model, example_inputs, output_dir="./saved_results"): def load(output_dir="./saved_results"): + """Load a quantized model from the specified output directory. + + Args: + output_dir (str): The directory where the quantized model is saved. Defaults to "./saved_results". + + Returns: + torch.nn.Module: The loaded quantized model. + """ qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) loaded_quantized_ep = torch.export.load(qmodel_file_path) return loaded_quantized_ep.module() diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py index e4efd62271e..ecf14ec02a7 100644 --- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py +++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Utility functions for PT2E quantization.""" from typing import Dict @@ -24,6 +25,18 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec: + """Create a quantization specification based on the given configuration. + + Args: + dtype (str): The desired data type for quantization. Valid options are "int8" and "uint8". + sym (bool): Whether to use symmetric quantization or not. + granularity (str): The granularity of quantization. Valid options are "per_channel" and "per_tensor". + algo (str): The algorithm to use for quantization. Valid options are "placeholder", "minmax", and "kl". + is_dynamic (bool, optional): Whether to use dynamic quantization or not. Defaults to False. + + Returns: + QuantizationSpec: The created quantization specification. + """ dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8} select_dtype = dtype_mapping[dtype] min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)} @@ -76,6 +89,15 @@ def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> Quant def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer: + """Creates an instance of X86InductorQuantizer based on the given configuration. + + Args: + config: The configuration object containing the quantization settings. + is_dynamic: A boolean indicating whether dynamic quantization is enabled. + + Returns: + An instance of X86InductorQuantizer initialized with the provided configuration. + """ quantizer = xiq.X86InductorQuantizer() # set global global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic) diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py b/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py index b11b6095066..19be7c0ded4 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""HQQ-related modules.""" from .quantizer import HQQuantizer from .config import HQQModuleConfig, QTensorConfig diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py index 5500201a4ee..75be966caa1 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py @@ -19,31 +19,57 @@ # Notice: Copied from from https://github.com/mobiusml/hqq # Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023 ##################################################### +"""Bit packing logic for HQQ.""" + import numpy as np import torch -from .utility import is_divisible - __all__ = ["Packer"] # Bit packing logic. format: pack/unpack_nBits_target- class BitPack: + """Packing and unpacking tensors into different bit representations.""" + # 8-bit ################################################ @staticmethod def pack_8bit_u8(W_q): + """Packs the given tensor into 8-bit unsigned integers. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The packed tensor. + """ return W_q.to(torch.uint8) @staticmethod def unpack_8bit_u8(W_q): + """Unpacks the given 8-bit tensor into 8-bit unsigned integer tensor. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The unpacked tensor. + """ return W_q # 4-bit ################################################ @staticmethod def pack_4bit_u8(W_q): # uint8 > uint8/2 + """Packs the given 4-bit tensor into 8-bit unsigned integers. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = W_q.to(torch.uint8) _step = int(len(W_q) / 2) return (W_q[:_step] << 4) | W_q[_step:] @@ -51,6 +77,14 @@ def pack_4bit_u8(W_q): # uint8 > uint8/2 # A bit faster than the _cat version @staticmethod def unpack_4bit_u8(W_q): # uint8/2 > uint8 + """Unpacks the given 4-bit tensor into 8-bit unsigned integers. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The unpacked tensor. + """ _step = W_q.shape[0] tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b11110000) >> 4 @@ -61,6 +95,14 @@ def unpack_4bit_u8(W_q): # uint8/2 > uint8 ################################################ @staticmethod def pack_2bit_u8(W_q): # uint8 > uint8/4 + """Packs the given 2-bit tensor into 8-bit unsigned integers. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = W_q.to(torch.uint8) _step = int(len(W_q) / 4) return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :] @@ -68,6 +110,14 @@ def pack_2bit_u8(W_q): # uint8 > uint8/4 # A bit faster than the _cat version @staticmethod def unpack_2bit_u8(W_q): + """Unpacks the tensor packed by `pack_2bit_u8`. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The unpacked tensor. + """ _step = W_q.shape[0] tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b11000000) >> 6 @@ -80,6 +130,14 @@ def unpack_2bit_u8(W_q): ################################################ @staticmethod def pack_3bit_32(W_q_in): + """Packs the given 3-bit tensor into 32-bit signed integers. + + Args: + W_q_in (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The packed tensor. + """ W_q = torch.zeros( [int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32 ) @@ -102,6 +160,14 @@ def pack_3bit_32(W_q_in): # A bit faster than _cat version @staticmethod def unpack_3bit_32(W_q): + """Unpacks the tensor packed by `pack_3bit_32`. + + Args: + W_q (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The unpacked tensor. + """ _step = W_q.shape[0] tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device) tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27 @@ -118,7 +184,8 @@ def unpack_3bit_32(W_q): class Packer: - # TODO: Refine the packer + """Pack/unpack functions collection.""" + bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"} pack_fn_mapping = { @@ -137,8 +204,24 @@ class Packer: @staticmethod def get_pack_fn(nbits: int): + """Get the pack function for the specified number of bits. + + Args: + nbits (int): The number of bits. + + Returns: + function: The pack function for the specified number of bits. + """ return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]] @staticmethod def get_unpack_fn(nbits: int): + """Get the unpack function for the specified number of bits. + + Args: + nbits (int): The number of bits. + + Returns: + function: The unpack function for the specified number of bits. + """ return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]] diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/config.py b/neural_compressor/torch/algorithms/weight_only/hqq/config.py index a0ee29a22d7..b1460713018 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/config.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/config.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Configuration for HQQ.""" import os from collections import namedtuple @@ -33,6 +33,8 @@ class HQQGlobalOptions: + """Global options for HQQ.""" + use_half = os.getenv("HQQ_NOT_USE_HALF", "0") == "0" @@ -41,6 +43,8 @@ class HQQGlobalOptions: @dataclass class QTensorConfig: + """Configuration class for quantized tensors.""" + nbits: int channel_wise: bool = True group_size: int = 128 @@ -49,6 +53,7 @@ class QTensorConfig: pack: bool = True def __repr__(self) -> str: + """Return a string representation of the QTensorConfig.""" return ( f"QTensorConfig(nbits={self.nbits}, channel_wise={self.channel_wise}, " f"group_size={self.group_size}, optimize={self.optimize}, " @@ -67,15 +72,25 @@ class HQQModuleConfig( ["weight", "scale", "zero"], ) ): + """Configuration class for HQQModule. + + Args: + weight (Any): The weight quantization configuration. + scale (Any): The scale quantization configuration. + zero (Any): The zero quantization configuration. + """ + def __new__( cls, weight=default_weight_quant_config, scale=default_scale_quant_config, zero=default_zero_quant_config, ): + """Create a new HQQModuleConfig.""" return super().__new__(cls, weight, scale, zero) def __repr__(self) -> str: + """Return a string representation of the HQQModuleConfig.""" return ( f"HQQModuleConfig(\n" f" weight={self.weight},\n" f" scale={self.scale},\n" f" zero={self.zero}\n)" ) diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/core.py b/neural_compressor/torch/algorithms/weight_only/hqq/core.py index 041e173671d..38289954c55 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/core.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/core.py @@ -18,11 +18,14 @@ # NOTICE: the original `Quantizer` has been modified to `HQQTensorHandle` # and `QTensor` to decouple the data structure and the quantization logic. +"""The HQQ modules.""" + from typing import Any, Dict, Mapping, Tuple import torch +from neural_compressor.common.utils import dump_elapsed_time from neural_compressor.torch.utils import logger from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator @@ -30,7 +33,6 @@ from .config import HQQModuleConfig, QTensorConfig, default_hqq_module_config, hqq_global_option from .optimizer import optimize_weights_proximal from .qtensor import QTensor, QTensorMetaInfo -from .utility import dump_elapsed_time, is_divisible __all__ = [ "HQQTensorHandle", @@ -39,6 +41,8 @@ class HQQTensorHandle: + """HQQ Tensor Handle to quantize and dequantize the tensor.""" + # Refactored the code from https://github.com/mobiusml/hqq. # Store meta-data (we invert the scale for dequantization) @@ -47,6 +51,15 @@ class HQQTensorHandle: @classmethod def quantize(cls, float_tensor, tensor_quant_config: QTensorConfig = None): + """Quantizes a given float tensor using the specified tensor quantization configuration. + + Args: + float_tensor (torch.Tensor): The float tensor to be quantized. + tensor_quant_config (QTensorConfig, optional): The tensor quantization configuration. Defaults to None. + + Returns: + torch.Tensor: The quantized tensor. + """ q_weight, q_tensor_meta = cls._quantize( tensor=float_tensor, tensor_quant_config=tensor_quant_config, @@ -56,7 +69,14 @@ def quantize(cls, float_tensor, tensor_quant_config: QTensorConfig = None): @classmethod def dequantize(cls, q_weight: "QTensor") -> torch.Tensor: - # Dequantized the Qtensor into float tensor + """Dequantizes the QTensor into a float tensor. + + Args: + q_weight (QTensor): The quantized weight tensor. + + Returns: + torch.Tensor: The dequantized float tensor. + """ meta = q_weight.meta_info.to_dict() meta["zero"] = q_weight.zero meta["scale"] = q_weight.scale @@ -88,7 +108,7 @@ def _quantize(cls, tensor, tensor_quant_config: QTensorConfig = None): assert nbits in cls.SUPPORTED_BITS, "nbits=" + str(nbits) + " not supported." assert axis in [0, 1], "axis should be either 0 or 1, but got {}".format(axis) if group_size is not None: - assert is_divisible(tensor.numel(), group_size), ( + assert tensor.numel() % group_size == 0, ( "group_size should be divisible by the total tensor dimensions. shape: " + str(tensor.shape) + ", group_size: " @@ -176,6 +196,7 @@ def _dequantize(cls, W_q, meta): class HQQLinear(torch.nn.Linear): + """HQQ Linear module.""" def __init__( self, @@ -186,6 +207,7 @@ def __init__( device=None, dtype=None, ) -> None: + """Init a HQQ linear.""" super().__init__(in_features, out_features, bias, device, dtype) self.q_weight = q_weight self.quantized = q_weight is not None @@ -196,6 +218,17 @@ def quantize_weight( W: torch.Tensor, quant_config: HQQModuleConfig = default_hqq_module_config, ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Quantizes the weight using HQQ. + + Args: + W (torch.Tensor): The weight tensor to be quantized. + quant_config (HQQModuleConfig, optional): The quantization configuration. + Defaults to default_hqq_module_config. + + Returns: + Tuple[torch.Tensor, Dict[str, Any]]: A tuple containing the quantized weight tensor + and a dictionary of additional information. + """ weight_quant_config, scale_quant_config, zero_quant_config = ( quant_config.weight, quant_config.scale, @@ -227,6 +260,7 @@ def quantize_weight( self.quantized = True def dequantize_weight(self): + """Dequantize the weight tensor.""" assert self.quantized, "model was not quantized" # TODO: move below logic into `HQQTensorHandle` if self.q_weight.is_scale_quantized(): @@ -241,6 +275,7 @@ def dequantize_weight(self): return W_qdq def forward(self, input: torch.Tensor) -> torch.Tensor: + """Forward pass of the HQQ linear module.""" out = torch.matmul(input, self.dequantize_weight().t()) if self.bias is not None: out += self.bias @@ -252,6 +287,16 @@ def from_float( float_module: torch.nn.Linear, quant_config: HQQModuleConfig = default_hqq_module_config, ): + """Create a new HQQModule instance from a floating-point linear. + + Args: + float_module (torch.nn.Linear): The floating-point module to convert. + quant_config (HQQModuleConfig, optional): The quantization configuration. + Defaults to default_hqq_module_config. + + Returns: + HQQModule: The converted HQQModule instance. + """ # Create the new module with a toy size to ensure initialization is fast fake_in_features, fake_out_features = 8, 8 new_mod = cls( @@ -260,7 +305,7 @@ def from_float( bias=float_module.bias is not None, ) new_mod.requires_grad_ = False - # Construct the q weight frpm float weight + # Construct the q weight from float weight new_mod.quantize_weight(float_module.weight, quant_config=quant_config) # Update the linear module attributes new_mod.in_features = float_module.in_features @@ -280,6 +325,18 @@ def from_float( return new_mod def state_dict(self, *args, **kwargs): # nn.Module override compatible + """Returns a dictionary containing the state of the module. + + The state dictionary contains the weights of the `q_weight` attribute. + If the `bias` attribute is not None, it is also included in the state dictionary. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: A dictionary containing the state of the module. + """ state_dict = self.q_weight.to_state_dict() if self.bias is not None: state_dict["bias"] = self.bias diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py index e471e6c017a..6614e28cebf 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Optimization logic of HQQ.""" import numpy as np import torch @@ -35,6 +35,22 @@ def optimize_weights_proximal_legacy( opt_params={"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}, verbose=False, ): + """Quantize the scale/zero of quantized tensor using the HQQ. + + Args: + tensor (torch.Tensor): The input tensor to optimize. + scale (torch.Tensor): The scaling factor for quantization. + zero (torch.Tensor): The zero-point for quantization. + min_max (tuple): The minimum and maximum values for quantization. + axis (int, optional): The axis along which to compute the mean for zero-point calculation. Defaults to 0. + device (str, optional): The device to use for computation. Defaults to "cuda". + opt_params (dict, optional): Optimization parameters. + Defaults to {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}. + verbose (bool, optional): Whether to print verbose output. Defaults to False. + + Returns: + tuple: A tuple containing the optimized scale and zero-point tensors. + """ lp_norm, beta, kappa, iters = ( opt_params["lp_norm"], opt_params["beta"], diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py index f1fbd5bce3a..3d250bdf220 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/qtensor.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""QTensor for HQQ.""" from dataclasses import asdict, dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -25,6 +26,16 @@ @dataclass class QTensorMetaInfo: + """Represents the meta information of a quantized tensor. + + Attributes: + nbits (int): The number of bits used for quantization. + group_size (int): The size of the quantization group. + shape (Tuple): The shape of the tensor. + axis (int): The axis along which the tensor is quantized. + packing (bool): Indicates whether the tensor is packed. + """ + nbits: int group_size: int shape: Tuple @@ -32,34 +43,45 @@ class QTensorMetaInfo: packing: bool def to_dict(self): + """Converts the QTensorMetaInfo object to a dictionary. + + Returns: + dict: A dictionary representation of the QTensorMetaInfo object. + """ return asdict(self) class QTensor: - val: torch.Tensor - scale: Union[torch.Tensor, "QTensor"] = None - zero: Union[torch.Tensor, "QTensor"] = None - meta_info: QTensorMetaInfo = None - """ - val: torch.Tensor - scale: + """Represents a quantized tensor. + + Example: val: torch.Tensor - scale: torch.Tensor - zero: torch.Tensor - zero: - torch.Tensor + scale: + val: torch.Tensor + scale: torch.Tensor + zero: torch.Tensor + zero: + torch.Tensor """ + val: torch.Tensor + scale: Union[None, torch.Tensor, "QTensor"] = None + zero: Union[None, torch.Tensor, "QTensor"] = None + meta_info: Optional[QTensorMetaInfo] = None + def __init__(self, val, scale=None, zero=None, meta_info=None): + """Init a QTensor object.""" self.val = val self.scale = scale self.zero = zero self.meta_info = meta_info def is_scale_quantized(self) -> bool: + """Check if the scale is quantized.""" return isinstance(self.scale, QTensor) def is_zero_quantized(self) -> bool: + """Check if the zero is quantized.""" return isinstance(self.zero, QTensor) def _get_scale_repr(self) -> str: @@ -89,6 +111,7 @@ def _get_zero_repr(self) -> str: return self.zero.__repr__() + "\n" def __repr__(self) -> str: + """Return the string representation of the QTensor object.""" # TODO: refine it later return ( f"QTensor(\n" @@ -101,12 +124,14 @@ def __repr__(self) -> str: ) def to(self, *args, **kwargs): + """Move the QTensor object to a new device or new dtype.""" self.val = self.val.to(*args, **kwargs) self.scale = self.scale.to(*args, **kwargs) self.zero = self.zero.to(*args, **kwargs) return self def half(self): + """Convert the QTensor object to half precision.""" # TODO: refine it later if self.val.dtype == torch.float32: self.val = self.val.half() @@ -117,6 +142,7 @@ def half(self): return self def to_state_dict(self): + """Convert the QTensor object to a state dictionary for serialization.""" state = {} state["val"] = self.val state["meta_info"] = self.meta_info.to_dict() diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py index 43b1dda1b4a..26de60ede23 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""HQQ Quantizer.""" + from typing import Callable, List, Optional, Tuple @@ -35,8 +37,7 @@ def _replace_with_custom_fn_if_matches_filter( cur_fqn: str = "", config_mapping: Optional[ConfigMappingType] = None, ) -> None: - """For each `child` in `model`, replaces it with `replacement_fn(child)` - if `filter_fn(child)` is `True`""" + """Recursively replaces modules in `model` with `replacement_fn` if `filter_fn` is `True`.""" name_to_child = dict(model.named_children()) for name, child in name_to_child.items(): if cur_fqn == "": @@ -64,21 +65,52 @@ def _replace_with_custom_fn_if_matches_filter( def patch_hqq_moduile(mod, config): + """Patch the given module with the HQQLinear module. + + Args: + mod (torch.nn.Module): The module to be patched. + config (dict): Configuration parameters for the HQQLinear module. + + Returns: + torch.nn.Module: The patched module with HQQLinear. + """ new_mod = HQQLinear.from_float(mod, config) return new_mod def filter_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> bool: + """Filter function used to determine if a module should be quantized. + + Args: + mod (torch.nn.Module): The module to be checked. + name (str): The name of the module. + config_mapping (ConfigMappingType): The configuration mapping. + + Returns: + bool: True if the module should be quantized, False otherwise. + """ return isinstance(mod, torch.nn.Linear) and name in config_mapping def replacement_fn(mod: torch.nn.Module, name: str, config_mapping: ConfigMappingType) -> torch.nn.Module: + """Replaces a Linear with HQQLinear if the module is in the config mapping. + + Args: + mod (torch.nn.Module): The original module to be replaced. + name (str): The name of the module to be replaced. + config_mapping (ConfigMappingType): A mapping of module names to their corresponding configurations. + + Returns: + torch.nn.Module: The patched module. + """ config = config_mapping.get(name, None) logger.debug("Replace module %s", name) return patch_hqq_moduile(mod, config) class HQQuantizer(Quantizer): + """HQQ Quantizer.""" + def __init__(self, quant_config: ConfigMappingType) -> None: """Init a HQQuantizer object. @@ -114,10 +146,6 @@ def convert(self, model: torch.nn.Module, *args, **kwargs) -> Optional[torch.nn. ) return model - def save(self, model, path): - # TODO: to implement it in the next PR - pass - def _convert_hqq_module_config(self, config) -> HQQModuleConfig: # TODO: (Yi) Please note that the configuration defined by INC should be separated from the algorithm. # * 3.x API use `bits` for woq while HQQ internal API use `nbits`, we should change it in algorithm_entry.py diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/utility.py b/neural_compressor/torch/algorithms/weight_only/hqq/utility.py deleted file mode 100644 index 9c9b3700cf6..00000000000 --- a/neural_compressor/torch/algorithms/weight_only/hqq/utility.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import gc -import time - -import numpy as np -import psutil -import torch - -from neural_compressor.torch.utils import logger - -__all__ = [ - "is_divisible", - "dump_elapsed_time", -] - - -def is_divisible(val1, val2): - return int(val2 * np.ceil(val1 / val2)) == val1 - - -def see_cuda_memory_usage(message, force=False): # pragma: no cover - # Copied from https://github.com/microsoft/DeepSpeed - # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports - gc.collect() - - # logger.info message except when distributed but not rank 0 - logger.info(message) - logger.info( - f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ - Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ - CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024),2)} GB \ - Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024))} GB " - ) - vm_stats = psutil.virtual_memory() - used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2) - logger.info(f"CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%") - - # get the peak memory to report correct data, so reset the counter for the next call - torch.cuda.reset_peak_memory_stats() - - -def dump_elapsed_time(customized_msg=""): - """Get the elapsed time for decorated functions. - - Args: - customized_msg (string, optional): The parameter passed to decorator. Defaults to None. - """ - - def f(func): - def fi(*args, **kwargs): - start = time.time() - res = func(*args, **kwargs) - end = time.time() - logger.info( - "%s elapsed time: %s ms" - % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) - ) - return res - - return fi - - return f diff --git a/neural_compressor/torch/export/__init__.py b/neural_compressor/torch/export/__init__.py index e3e4775e986..7c69d8f289f 100644 --- a/neural_compressor/torch/export/__init__.py +++ b/neural_compressor/torch/export/__init__.py @@ -11,5 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Export module for quantization.""" from neural_compressor.torch.export.pt2e_export import export_model_for_pt2e_quant, export diff --git a/neural_compressor/torch/export/pt2e_export.py b/neural_compressor/torch/export/pt2e_export.py index 579e816894f..af232cc2ad4 100644 --- a/neural_compressor/torch/export/pt2e_export.py +++ b/neural_compressor/torch/export/pt2e_export.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Export model for quantization.""" from typing import Any, Dict, Optional, Tuple, Union @@ -29,7 +30,20 @@ def export_model_for_pt2e_quant( example_inputs: Tuple[Any], dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ) -> Optional[GraphModule]: - """Export the eager model into model with Aten IR.""" + """Exports a eager model for PT2E quantization. + + Args: + model (torch.nn.Module): The PyTorch model to be exported. + example_inputs (Tuple[Any]): Example inputs to the model. + dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any]]], optional): + Dynamic shapes for the model inputs. Defaults to None. + + Returns: + Optional[GraphModule]: The exported model as a GraphModule. + + Raises: + AssertionError: If `example_inputs` is not a tuple. + """ assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}" # Set the model to eval mode model = model.eval() @@ -66,6 +80,17 @@ def export( example_inputs: Tuple[Any], dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ) -> Optional[GraphModule]: + """Unified export function for quantization. + + Args: + model (torch.nn.Module): The model to be exported. + example_inputs (Tuple[Any]): Example inputs to the model. + dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any]]], optional): + Dynamic shapes for the model. Defaults to None. + + Returns: + Optional[GraphModule]: The exported model for quantization. + """ if not is_ipex_imported(): return export_model_for_pt2e_quant(model, example_inputs, dynamic_shapes) else: diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 2c43f1e59c1..4df82260551 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1290,9 +1290,12 @@ def get_default_sq_config() -> SmoothQuantConfig: ######################## HQQ Config ############################### @register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ) class HQQConfig(TorchBaseConfig): - # Half-Quadratic Quantization (HQQ), more details: - # Blog: https://mobiusml.github.io/hqq_blog/ - # Code: https://github.com/mobiusml/hqq + """Configuration class for Half-Quadratic Quantization (HQQ). + + HQQ is a quantization algorithm that reduces the precision of weights and activations in neural networks. + For more details, refer to the blog: https://mobiusml.github.io/hqq_blog/ + and the code: https://github.com/mobiusml/hqq + """ name = HQQ params_list = [ @@ -1301,7 +1304,6 @@ class HQQConfig(TorchBaseConfig): "quant_zero", "quant_scale", "scale_quant_group_size", - # quant_lm_head "quant_lm_head", ] supported_configs: List[OperatorConfig] = [] @@ -1314,10 +1316,22 @@ def __init__( quant_zero: bool = True, quant_scale: bool = False, scale_quant_group_size: int = 128, - # quant lm_head quant_lm_head: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): + """Initialize HQQConfig. + + Args: + dtype (str): Data type for quantization. Default is "int". + bits (int): Number of bits for quantization. Default is 4. + group_size (int): Group size for quantization. Default is 64. + quant_zero (bool): Whether to quantize zero values. Default is True. + quant_scale (bool): Whether to quantize scale values. Default is False. + scale_quant_group_size (int): Group size for scale quantization. Default is 128. + quant_lm_head (bool): Whether to quantize the language model head. Default is False. + white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types. + Default is DEFAULT_WHITE_LIST. + """ super().__init__(white_list=white_list) self.dtype = dtype self.bits = bits @@ -1330,7 +1344,11 @@ def __init__( @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: - # TODO: to be refined + """Register supported configurations for HQQ. + + Returns: + List[OperatorConfig]: List of supported operator configurations. + """ supported_configs = [] linear_hqq_config = HQQConfig() operators = list(WOQ_WHITE_LIST) @@ -1339,6 +1357,14 @@ def register_supported_configs(cls) -> List[OperatorConfig]: @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: + """Get information about the model. + + Args: + model (torch.nn.Module): The model. + + Returns: + List[Tuple[str, Callable]]: List of tuples containing the name and type of each module in the model. + """ filter_result = [] for op_name, module in model.named_modules(): if isinstance(module, WOQ_WHITE_LIST): @@ -1349,6 +1375,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + """Convert the configuration to a mapping. + + Args: + config_list (List[BaseConfig]): List of base configurations. Default is None. + model_info (List[Tuple[str, str]]): List of tuples containing the name and type of each module in the model. + Default is None. + + Returns: + OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: The configuration mapping. + """ if not self.quant_lm_head: self.set_local(LM_HEAD_NAMES, HQQConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) @@ -1356,6 +1392,11 @@ def to_config_mapping( @classmethod def get_config_set_for_tuning(cls) -> Union[None, "HQQConfig", List["HQQConfig"]]: + """Get the configuration set for tuning. + + Returns: + Union[None, "HQQConfig", List["HQQConfig"]]: The configuration set for tuning. + """ return HQQConfig(bits=[4, 8])