add docstring for mx quant (#1932)
Signed-off-by: Mengni Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: xinhe <[email protected]>
3 people authored Jul 23, 2024
1 parent 0c52e12 commit b787940
Showing 4 changed files with 93 additions and 53 deletions.
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
@@ -15,6 +15,7 @@
 /neural-compressor/neural_compressor/strategy
 /neural-compressor/neural_compressor/training.py
 /neural-compressor/neural_compressor/utils
+/neural_compressor/torch/algorithms/mx_quant
 /neural-compressor/neural_compressor/torch/algorithms/static_quant
 /neural-compressor/neural_compressor/torch/algorithms/smooth_quant
 /neural_compressor/torch/algorithms/pt2e_quant
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/mx_quant/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.

 # pylint:disable=import-error
+"""MX quantization."""
9 changes: 8 additions & 1 deletion neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""MX quantization."""

 from collections import OrderedDict

@@ -31,6 +31,8 @@


 class MXLinear(torch.nn.Linear):
+    """Linear for MX data type."""
+
     def __init__(
         self,
         in_features,
@@ -39,13 +41,15 @@ def __init__(
         mx_specs=None,
         name=None,
     ):
+        """Initialization function."""
         self.mx_none = mx_specs is None

         self.name = name
         self.mx_specs = mx_specs
         super().__init__(in_features, out_features, bias)

     def apply_mx_specs(self):
+        """Apply MX data type to weight."""
         if self.mx_specs is not None:
             if self.mx_specs.out_dtype != "float32":
                 self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs)
@@ -63,6 +67,7 @@ def apply_mx_specs(self):
             )

     def forward(self, input):
+        """Forward function."""
         if self.mx_none:
             return super().forward(input)

@@ -93,6 +98,8 @@ def forward(self, input):


 class MXQuantizer(Quantizer):
+    """Quantizer of MX data type."""
+
     def __init__(self, quant_config: OrderedDict = {}):
         """Init a MXQuantizer object.
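Taken together, the mx.py hunks describe a drop-in MX replacement for torch.nn.Linear. A minimal usage sketch, assuming the package layout shown above and treating the bias keyword as an assumption; mx_specs=None exercises the mx_none fallback path visible in forward:

```python
# Hypothetical sketch: swapping a float Linear for MXLinear.
import torch

from neural_compressor.torch.algorithms.mx_quant.mx import MXLinear

fp32_layer = torch.nn.Linear(64, 32)

mx_layer = MXLinear(
    fp32_layer.in_features,
    fp32_layer.out_features,
    bias=fp32_layer.bias is not None,  # assumed keyword, mirroring nn.Linear
    mx_specs=None,  # None -> mx_none is True, so forward() stays pure float
    name="model.linear",
)
mx_layer.weight.data = fp32_layer.weight.data.clone()
if fp32_layer.bias is not None:
    mx_layer.bias.data = fp32_layer.bias.data.clone()

out = mx_layer(torch.randn(4, 64))  # matches the fp32 layer in this setup
print(out.shape)  # torch.Size([4, 32])
```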
135 changes: 83 additions & 52 deletions neural_compressor/torch/algorithms/mx_quant/utils.py
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""MX quantization utils."""

 from enum import Enum, IntEnum

@@ -28,6 +28,8 @@


 class ElemFormat(Enum):
+    """Element format."""
+
     int8 = 1
     int4 = 2
     int2 = 3
@@ -44,6 +46,7 @@ class ElemFormat(Enum):

     @staticmethod
     def from_str(s):
+        """Get element format with str."""
         assert s is not None, "String elem_format == None"
         s = s.lower()
         if hasattr(ElemFormat, s):
@@ -53,6 +56,7 @@ def from_str(s):

     @staticmethod
     def is_bf(s):
+        """Whether the format is brain floating-point format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
@@ -65,6 +69,7 @@ def is_bf(s):

     @staticmethod
     def is_fp(s):
+        """Whether the format is floating-point format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
@@ -77,6 +82,7 @@ def is_fp(s):

     @staticmethod
     def is_int(s):
+        """Whether the format is integer format."""
         if isinstance(s, str):
             assert s is not None, "String elem_format == None"
             s = s.lower()
@@ -89,12 +95,15 @@ def is_int(s):


 class RoundingMode(IntEnum):
+    """Rounding mode."""
+
     nearest = 0
     floor = 1
     even = 2

     @staticmethod
     def string_enums():
+        """Rounding mode names."""
         return [s.name for s in list(RoundingMode)]


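As a quick orientation to the enum surface documented above, a small sketch; the fp8_e4m3 and bfloat16 member names are assumptions taken from the _get_format_params docstring below, while RoundingMode.string_enums() follows directly from the code shown:

```python
from neural_compressor.torch.algorithms.mx_quant.utils import ElemFormat, RoundingMode

fmt = ElemFormat.from_str("int8")    # resolve a string to an enum member
print(ElemFormat.is_int("int8"))     # True: integer format
print(ElemFormat.is_fp("fp8_e4m3"))  # assumed True: floating-point format
print(ElemFormat.is_bf("bfloat16"))  # assumed True: brain-float format
print(RoundingMode.string_enums())   # ['nearest', 'floor', 'even']
```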
@@ -115,14 +124,19 @@ def _get_max_norm(ebits, mbits):


 def _get_format_params(fmt):
-    """Allowed formats:
+    """Get parameters of the format.
+
+    Allowed formats:
     - intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation
     - floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf
     - bfloatX/bfX: 9 <= X <= 32
     - fp4, no NaN/Inf
     - fp6_e3m2/e2m3, no NaN/Inf
     - fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior
+
+    Args:
+        fmt (str or ElemFormat): format
     Returns:
         ebits: exponent bits
         mbits: mantissa bits: includes sign and implicit bits
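To make the bit-accounting convention concrete, a worked example of the counting rule in the docstring (an illustration only; the function's full return tuple is truncated here):

```python
# fp8_e4m3 under the convention above: 1 sign + 4 exponent + 3 explicit
# mantissa bits; mbits also counts the sign and the implicit leading one.
ebits = 4
explicit_mantissa = 3
mbits = explicit_mantissa + 2  # +1 sign, +1 implicit one -> 5
assert 1 + ebits + explicit_mantissa == 8  # 8 storage bits in total
```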
@@ -198,17 +212,19 @@ def _safe_rshift(x, bits, exp):


 def _round_mantissa(A, bits, round, clamp=False):
-    """
-    Rounds mantissa to nearest bits depending on the rounding method 'round'
+    """Rounds mantissa to nearest bits depending on the rounding method 'round'.
+
     Args:
-        A {PyTorch tensor} -- Input tensor
-        round {str} -- Rounding method
-            "floor" rounds to the floor
-            "nearest" rounds to ceil or floor, whichever is nearest
+        A (torch.Tensor): input tensor
+        bits (int): bit number of mantissa
+        round (str): rounding method
+            "floor" rounds to the floor
+            "nearest" rounds to ceil or floor, whichever is nearest
+        clamp (bool, optional): Whether do clip. Defaults to False.
     Returns:
-        A {PyTorch tensor} -- Tensor with mantissas rounded
+        torch.Tensor: tensor with mantissas rounded
     """

     if round == "dither":
         rand_A = torch.rand_like(A, requires_grad=False)
         A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
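For intuition, the "nearest" mode plausibly rounds half away from zero, sign-symmetrically, in the same spirit as the dither branch above (a standalone sketch, not the library's code path):

```python
import torch

# Replace the random dither with a constant 0.5 to get "nearest":
A = torch.tensor([1.25, -1.25, 2.5])
nearest = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
print(nearest)  # tensor([ 1., -1.,  3.])
```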
@@ -235,16 +251,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0):
     """Get shared exponents for the passed matrix A.

     Args:
-        A {PyTorch tensor} -- Input tensor
-        method {str} -- Exponent selection method.
-            "max" uses the max absolute value
-            "none" uses an exponent for each value (i.e., no sharing)
-        axes {list(int)} -- List of integers which specifies the axes across which
-            shared exponents are calculated.
+        A (torch.Tensor): Input tensor
+        method (str, optional): Exponent selection method.
+            "max" uses the max absolute value.
+            "none" uses an exponent for each value (i.e., no sharing)
+            Defaults to "max".
+        axes (list(int), optional): list of integers which specifies the axes across which
+            shared exponents are calculated. Defaults to None.
+        ebits (int, optional): bit number of the shared exponents. Defaults to 0.
     Returns:
-        shared_exp {PyTorch tensor} -- Tensor of shared exponents
+        shared_exp (torch.Tensor): Tensor of shared exponents
     """

     if method == "max":
         if axes is None:
             shared_exp = torch.max(torch.abs(A))
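Conceptually, the "max" method boils down to one exponent per reduction group, taken from the largest magnitude; a minimal sketch of that idea (the real function also handles axes and the ebits clamp):

```python
import torch

# One shared exponent for the whole tensor: floor(log2(max|A|)).
A = torch.tensor([0.05, -3.0, 1.5])
shared_exp = torch.floor(torch.log2(torch.max(torch.abs(A))))
print(shared_exp)  # tensor(1.) because floor(log2(3.0)) == 1
```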
@@ -346,21 +364,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):


 def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
-    """Core function used for element-wise quantization
-    Arguments:
-        A {PyTorch tensor} -- A tensor to be quantized
-        bits {int} -- Number of mantissa bits. Includes
-            sign bit and implicit one for floats
-        exp_bits {int} -- Number of exponent bits, 0 for ints
-        max_norm {float} -- Largest representable normal number
-        round {str} -- Rounding mode: (floor, nearest, even)
-        saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
-            that exceed max norm are clamped.
-            Must be True for correct MX conversion.
-        allow_denorm {bool} -- If False, flush denorm numbers in the
-            elem_format to zero.
+    """Core function used for element-wise quantization.
+
+    Args:
+        A (torch.Tensor): tensor to be quantized
+        bits (int): number of mantissa bits. Includes sign bit and implicit one for floats
+        exp_bits (int): number of exponent bits, 0 for ints
+        max_norm (float): largest representable normal number
+        round (str, optional): rounding mode: (floor, nearest, even). Defaults to "nearest".
+        saturate_normals (bool, optional): whether clip normal numbers that exceed max norm.
+            Must be True for correct MX conversion. Defaults to False.
+        allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True.
+
     Returns:
-        quantized tensor {PyTorch tensor} -- A tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
     """
     # Flush values < min_norm to zero if denorms are not allowed
     if not allow_denorm and exp_bits > 0:
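Since the full signature is visible above, a hedged call can be sketched; the max_norm value is an assumption (448 is the usual fp8_e4m3 max normal):

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import _quantize_elemwise_core

# Quantize to an fp8_e4m3-like grid: bits=5 (3 explicit mantissa bits
# plus sign and implicit one, per the docstring), exp_bits=4.
A = torch.randn(16)
q = _quantize_elemwise_core(A, bits=5, exp_bits=4, max_norm=448.0,
                            round="nearest", saturate_normals=True)
print((A - q).abs().max())  # small element-wise rounding error
```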
@@ -401,15 +418,20 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", satura


 def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
     """Quantize values to IEEE fpX format.
-    The format defines NaN/Inf
-    and subnorm numbers in the same way as FP32 and FP16.
-    Arguments:
-        exp_bits {int} -- number of bits used to store exponent
-        mantissa_bits {int} -- number of bits used to store mantissa, not
-            including sign or implicit 1
-        round {str} -- Rounding mode, (floor, nearest, even)
+
+    The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        exp_bits (int, optional): number of bits used to store exponent. Defaults to None.
+        mantissa_bits (int, optional): number of bits used to store mantissa.
+            Not including sign or implicit 1. Defaults to None.
+        round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
     """
     # Shortcut for no quantization
     if exp_bits is None or mantissa_bits is None:
@@ -425,11 +447,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_de


 def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
-    """Quantize values to bfloatX format
-    Arguments:
-        bfloat {int} -- Total number of bits for bfloatX format,
-            Includes 1 sign, 8 exp bits, and variable
-            mantissa bits. Must be >= 9.
+    """Quantize values to bfloatX format.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        bfloat (int): total number of bits for bfloatX format.
+            Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9.
+        round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
+        allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.
+
+    Returns:
+        torch.Tensor: tensor that has been quantized
     """
     # Shortcut for no quantization
     if bfloat == 0 or bfloat == 32:
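With bfloat=16 this should track a native bfloat16 cast up to tie-handling differences; a hedged check:

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import _quantize_bfloat

# bfloat16: 16 total bits = 1 sign + 8 exponent + 7 mantissa bits.
A = torch.randn(16)
q = _quantize_bfloat(A, bfloat=16, round="nearest")
# Near-agreement expected; exact ties may differ since "nearest" need
# not be round-to-nearest-even.
print((q - A.to(torch.bfloat16).float()).abs().max())
```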
@@ -443,12 +471,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):


 def quantize_elemwise_op(A, mx_specs):
-    """A function used for element-wise quantization with mx_specs
-    Arguments:
-        A {PyTorch tensor} -- a tensor that needs to be quantized
-        mx_specs {dictionary} -- dictionary to specify mx_specs
+    """A function used for element-wise quantization with mx_specs.
+
+    Args:
+        A (torch.Tensor): a tensor that needs to be quantized
+        mx_specs (dict): dictionary to specify mx_specs
+
     Returns:
-        quantized value {PyTorch tensor} -- a tensor that has been quantized
+        torch.Tensor: tensor that has been quantized
     """
     if mx_specs is None:
         return A
@@ -530,14 +560,15 @@ def _quantize_mx(


 def quantize_mx_op(
-    A,
+    A: torch.Tensor,
     elem_format: str,
     round: str,
     block_size: int,
     scale_bits=8,
     axes=None,
     expand_and_reshape=False,
 ):
+    """Quantize tensor to MX data type."""
     if elem_format is None:
         return A
     elif type(elem_format) is str:
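The signature above is complete, so an end-to-end call can be sketched; the element-format string is an assumption (any name accepted by ElemFormat.from_str should work):

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import quantize_mx_op

# Blocks of 32 along the last axis share one 8-bit scale; elements are
# quantized to the chosen MX element format.
A = torch.randn(4, 64)
q = quantize_mx_op(
    A,
    elem_format="fp8_e4m3",  # assumed format name
    round="nearest",
    block_size=32,
    scale_bits=8,
    axes=[-1],
)
print(q.shape)  # same shape as the input
```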
