From 77aeb456b39d61a363e77582a886ee9af18e6f75 Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Tue, 16 Jul 2024 23:50:46 -0700 Subject: [PATCH 1/4] add docstring for mx quant Signed-off-by: Mengni Wang --- .../torch/algorithms/mx_quant/__init__.py | 1 + .../torch/algorithms/mx_quant/mx.py | 9 +- .../torch/algorithms/mx_quant/utils.py | 131 +++++++++++------- 3 files changed, 88 insertions(+), 53 deletions(-) diff --git a/neural_compressor/torch/algorithms/mx_quant/__init__.py b/neural_compressor/torch/algorithms/mx_quant/__init__.py index e54bfa18052..b9824388148 100644 --- a/neural_compressor/torch/algorithms/mx_quant/__init__.py +++ b/neural_compressor/torch/algorithms/mx_quant/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. # pylint:disable=import-error +"""MX quantization.""" \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/mx_quant/mx.py b/neural_compressor/torch/algorithms/mx_quant/mx.py index 76af3511e20..208f6ad2698 100644 --- a/neural_compressor/torch/algorithms/mx_quant/mx.py +++ b/neural_compressor/torch/algorithms/mx_quant/mx.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +"""MX quantization.""" from collections import OrderedDict @@ -31,6 +31,8 @@ class MXLinear(torch.nn.Linear): + """Linear for MX data type.""" + def __init__( self, in_features, @@ -39,6 +41,7 @@ def __init__( mx_specs=None, name=None, ): + """Initialization function.""" self.mx_none = mx_specs is None self.name = name @@ -46,6 +49,7 @@ def __init__( super().__init__(in_features, out_features, bias) def apply_mx_specs(self): + """Apply MX data type to weight.""" if self.mx_specs is not None: if self.mx_specs.out_dtype != "float32": self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs) @@ -63,6 +67,7 @@ def apply_mx_specs(self): ) def forward(self, input): + """Forward function.""" if self.mx_none: return super().forward(input) @@ -93,6 +98,8 @@ def forward(self, input): class MXQuantizer(Quantizer): + """Quantizer of MX data type.""" + def __init__(self, quant_config: OrderedDict = {}): """Init a MXQuantizer object. diff --git a/neural_compressor/torch/algorithms/mx_quant/utils.py b/neural_compressor/torch/algorithms/mx_quant/utils.py index 2da59c6c700..dfd879901b8 100644 --- a/neural_compressor/torch/algorithms/mx_quant/utils.py +++ b/neural_compressor/torch/algorithms/mx_quant/utils.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +"""MX quantization utils.""" from enum import Enum, IntEnum @@ -28,6 +28,7 @@ class ElemFormat(Enum): + """Element format.""" int8 = 1 int4 = 2 int2 = 3 @@ -44,6 +45,7 @@ class ElemFormat(Enum): @staticmethod def from_str(s): + """Get element format with str.""" assert s is not None, "String elem_format == None" s = s.lower() if hasattr(ElemFormat, s): @@ -53,6 +55,7 @@ def from_str(s): @staticmethod def is_bf(s): + """Whether the format is brain floating-point format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -65,6 +68,7 @@ def is_bf(s): @staticmethod def is_fp(s): + """Whether the format is floating-point format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -77,6 +81,7 @@ def is_fp(s): @staticmethod def is_int(s): + """Whether the format is integer format.""" if isinstance(s, str): assert s is not None, "String elem_format == None" s = s.lower() @@ -89,12 +94,14 @@ def is_int(s): class RoundingMode(IntEnum): + """Rounding mode.""" nearest = 0 floor = 1 even = 2 @staticmethod def string_enums(): + """Rounding mode names.""" return [s.name for s in list(RoundingMode)] @@ -115,13 +122,18 @@ def _get_max_norm(ebits, mbits): def _get_format_params(fmt): - """Allowed formats: + """Get parameters of the format. + + Allowed formats: - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf - bfloatX/bfX: 9 <= X <= 32 - fp4, no NaN/Inf - fp6_e3m2/e2m3, no NaN/Inf - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior + + Args: + fmt (str od ElemFormat): format Returns: ebits: exponent bits @@ -198,17 +210,19 @@ def _safe_rshift(x, bits, exp): def _round_mantissa(A, bits, round, clamp=False): - """ - Rounds mantissa to nearest bits depending on the rounding method 'round' + """Rounds mantissa to nearest bits depending on the rounding method 'round'. 
+ Args: - A {PyTorch tensor} -- Input tensor - round {str} -- Rounding method - "floor" rounds to the floor - "nearest" rounds to ceil or floor, whichever is nearest + A (torch.Tensor): input tensor + bits (int): bit number of mantissa + round (str): rounding method + "floor" rounds to the floor + "nearest" rounds to ceil or floor, whichever is nearest + clamp (bool, optional): Whether to clip. Defaults to False. + Returns: - A {PyTorch tensor} -- Tensor with mantissas rounded + torch.Tensor: tensor with mantissas rounded """ - if round == "dither": rand_A = torch.rand_like(A, requires_grad=False) A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A) @@ -235,16 +249,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0): """Get shared exponents for the passed matrix A. Args: - A {PyTorch tensor} -- Input tensor - method {str} -- Exponent selection method. - "max" uses the max absolute value - "none" uses an exponent for each value (i.e., no sharing) - axes {list(int)} -- List of integers which specifies the axes across which - shared exponents are calculated. + A (torch.Tensor): Input tensor + method (str, optional): Exponent selection method. + "max" uses the max absolute value. + "none" uses an exponent for each value (i.e., no sharing) + Defaults to "max". + axes (list(int), optional): list of integers which specifies the axes across which + shared exponents are calculated. Defaults to None. + ebits (int, optional): bit number of the shared exponents. Defaults to 0. 
+ Returns: - shared_exp {PyTorch tensor} -- Tensor of shared exponents + shared_exp (torch.Tensor): Tensor of shared exponents """ - if method == "max": if axes is None: shared_exp = torch.max(torch.abs(A)) @@ -346,21 +362,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes): def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True): - """Core function used for element-wise quantization - Arguments: - A {PyTorch tensor} -- A tensor to be quantized - bits {int} -- Number of mantissa bits. Includes - sign bit and implicit one for floats - exp_bits {int} -- Number of exponent bits, 0 for ints - max_norm {float} -- Largest representable normal number - round {str} -- Rounding mode: (floor, nearest, even) - saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf) - that exceed max norm are clamped. - Must be True for correct MX conversion. - allow_denorm {bool} -- If False, flush denorm numbers in the - elem_format to zero. + """Core function used for element-wise quantization. + + Args: + A (torch.Tensor): tensor to be quantized + bits (int): number of mantissa bits. Includes sign bit and implicit one for floats + exp_bits (int): number of exponent bits, 0 for ints + max_norm (float): largest representable normal number + round (str, optional): rounding mode: (floor, nearest, even). Defaults to "nearest". + saturate_normals (bool, optional): whether to clip normal numbers that exceed max norm. + Must be True for correct MX conversion. Defaults to False. + allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True. 
+ Returns: - quantized tensor {PyTorch tensor} -- A tensor that has been quantized + torch.Tensor: tensor that has been quantized """ # Flush values < min_norm to zero if denorms are not allowed if not allow_denorm and exp_bits > 0: @@ -401,15 +416,19 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", satura def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True): - """Quantize values to IEEE fpX format. - - The format defines NaN/Inf - and subnorm numbers in the same way as FP32 and FP16. - Arguments: - exp_bits {int} -- number of bits used to store exponent - mantissa_bits {int} -- number of bits used to store mantissa, not - including sign or implicit 1 - round {str} -- Rounding mode, (floor, nearest, even) + """Quantize values to IEEE fpX format.. + + The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16. + + Args: + A (torch.Tensor): a tensor that needs to be quantized + exp_bits (int, optional): number of bits used to store exponent. Defaults to None. + mantissa_bits (int, optional): number of bits used to store mantissa, not including sign or implicit 1. Defaults to None. + round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest". + allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True. + + Returns: + torch.Tensor: tensor that has been quantized """ # Shortcut for no quantization if exp_bits is None or mantissa_bits is None: @@ -425,11 +444,16 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True): - """Quantize values to bfloatX format - Arguments: - bfloat {int} -- Total number of bits for bfloatX format, - Includes 1 sign, 8 exp bits, and variable - mantissa bits. Must be >= 9. + """Quantize values to bfloatX format. 
+ + Args: + A (torch.Tensor): a tensor that needs to be quantized + bfloat (int): total number of bits for bfloatX format. Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9. + round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest". + allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True. + + Returns: + torch.Tensor: tensor that has been quantized """ # Shortcut for no quantization if bfloat == 0 or bfloat == 32: @@ -443,12 +467,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True): def quantize_elemwise_op(A, mx_specs): - """A function used for element-wise quantization with mx_specs - Arguments: - A {PyTorch tensor} -- a tensor that needs to be quantized - mx_specs {dictionary} -- dictionary to specify mx_specs + """A function used for element-wise quantization with mx_specs. + + Args: + A (torch.Tensor): a tensor that needs to be quantized + mx_specs (dict): dictionary to specify mx_specs + Returns: - quantized value {PyTorch tensor} -- a tensor that has been quantized + torch.Tensor: tensor that has been quantized """ if mx_specs is None: return A @@ -530,7 +556,7 @@ def _quantize_mx( def quantize_mx_op( - A, + A: torch.Tensor, elem_format: str, round: str, block_size: int, @@ -538,6 +564,7 @@ def quantize_mx_op( axes=None, expand_and_reshape=False, ): + """Quantize tensor to MX data type.""" if elem_format is None: return A elif type(elem_format) is str: From 1523b61ee8b82356a4bbf486f9e500699acaa4d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 06:55:30 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/mx_quant/__init__.py | 2 +- neural_compressor/torch/algorithms/mx_quant/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/neural_compressor/torch/algorithms/mx_quant/__init__.py b/neural_compressor/torch/algorithms/mx_quant/__init__.py index b9824388148..d85854ffd8f 100644 --- a/neural_compressor/torch/algorithms/mx_quant/__init__.py +++ b/neural_compressor/torch/algorithms/mx_quant/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # pylint:disable=import-error -"""MX quantization.""" \ No newline at end of file +"""MX quantization.""" diff --git a/neural_compressor/torch/algorithms/mx_quant/utils.py b/neural_compressor/torch/algorithms/mx_quant/utils.py index dfd879901b8..5df513af417 100644 --- a/neural_compressor/torch/algorithms/mx_quant/utils.py +++ b/neural_compressor/torch/algorithms/mx_quant/utils.py @@ -29,6 +29,7 @@ class ElemFormat(Enum): """Element format.""" + int8 = 1 int4 = 2 int2 = 3 @@ -95,6 +96,7 @@ def is_int(s): class RoundingMode(IntEnum): """Rounding mode.""" + nearest = 0 floor = 1 even = 2 @@ -131,7 +133,7 @@ def _get_format_params(fmt): - fp4, no NaN/Inf - fp6_e3m2/e2m3, no NaN/Inf - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior - + Args: fmt (str od ElemFormat): format @@ -419,7 +421,7 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de """Quantize values to IEEE fpX format.. The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16. - + Args: A (torch.Tensor): a tensor that needs to be quantized exp_bits (int, optional): number of bits used to store exponent. Defaults to None. 
From d13fd07bf5ddfeb0ddcc56fab2e05771184ead9b Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Wed, 17 Jul 2024 01:48:36 -0700 Subject: [PATCH 3/4] fix CI Signed-off-by: Mengni Wang --- neural_compressor/torch/algorithms/mx_quant/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/mx_quant/utils.py b/neural_compressor/torch/algorithms/mx_quant/utils.py index 5df513af417..210e0255cc4 100644 --- a/neural_compressor/torch/algorithms/mx_quant/utils.py +++ b/neural_compressor/torch/algorithms/mx_quant/utils.py @@ -425,7 +425,8 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de Args: A (torch.Tensor): a tensor that needs to be quantized exp_bits (int, optional): number of bits used to store exponent. Defaults to None. - mantissa_bits (int, optional): number of bits used to store mantissa, not including sign or implicit 1. Defaults to None. + mantissa_bits (int, optional): number of bits used to store mantissa. + Not including sign or implicit 1. Defaults to None. round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest". allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True. @@ -450,7 +451,8 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True): Args: A (torch.Tensor): a tensor that needs to be quantized - bfloat (int): total number of bits for bfloatX format. Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9. + bfloat (int): total number of bits for bfloatX format. + Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9. round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest". allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True. 
From 00b52a8a1b1570e94c9314389568afff723952ad Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Fri, 19 Jul 2024 10:36:06 +0800 Subject: [PATCH 4/4] Update scan_path.txt --- .azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt index b524f1f61db..9b8fd10e8ed 100644 --- a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt +++ b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt @@ -15,3 +15,4 @@ /neural-compressor/neural_compressor/strategy /neural-compressor/neural_compressor/training.py /neural-compressor/neural_compressor/utils +/neural-compressor/neural_compressor/torch/algorithms/mx_quant