add docstring for mx quant #1932

Merged
merged 8 commits on Jul 23, 2024
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
@@ -15,6 +15,7 @@
/neural-compressor/neural_compressor/strategy
/neural-compressor/neural_compressor/training.py
/neural-compressor/neural_compressor/utils
/neural_compressor/torch/algorithms/mx_quant
/neural-compressor/neural_compressor/torch/algorithms/static_quant
/neural-compressor/neural_compressor/torch/algorithms/smooth_quant
/neural_compressor/torch/algorithms/pt2e_quant
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/mx_quant/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

# pylint:disable=import-error
"""MX quantization."""
9 changes: 8 additions & 1 deletion neural_compressor/torch/algorithms/mx_quant/mx.py
@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MX quantization."""

from collections import OrderedDict

@@ -31,6 +31,8 @@


class MXLinear(torch.nn.Linear):
"""Linear for MX data type."""

def __init__(
self,
in_features,
@@ -39,13 +41,15 @@ def __init__(
mx_specs=None,
name=None,
):
"""Initialization function."""
self.mx_none = mx_specs is None

self.name = name
self.mx_specs = mx_specs
super().__init__(in_features, out_features, bias)

def apply_mx_specs(self):
"""Apply MX data type to weight."""
if self.mx_specs is not None:
if self.mx_specs.out_dtype != "float32":
self.weight.data = quantize_elemwise_op(self.weight.data, mx_specs=self.mx_specs)
@@ -63,6 +67,7 @@ def apply_mx_specs(self):
)

def forward(self, input):
"""Forward function."""
if self.mx_none:
return super().forward(input)

@@ -93,6 +98,8 @@ def forward(self, input):


class MXQuantizer(Quantizer):
"""Quantizer of MX data type."""

def __init__(self, quant_config: OrderedDict = {}):
"""Init a MXQuantizer object.

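To make the `MXLinear` behavior above concrete, here is a minimal usage sketch. Parts of `__init__` are collapsed in this diff, so the keyword layout is an assumption; with `mx_specs=None` the visible guard makes `forward` fall back to the stock `torch.nn.Linear` path.

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.mx import MXLinear

# With mx_specs=None, __init__ sets mx_none=True and forward() falls back
# to torch.nn.Linear.forward, per the guard visible in the diff above.
layer = MXLinear(16, 8, bias=True, mx_specs=None, name="fc1")
x = torch.randn(4, 16)
y = layer(x)  # identical to a plain nn.Linear in this configuration
print(y.shape)  # torch.Size([4, 8])
```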
135 changes: 83 additions & 52 deletions neural_compressor/torch/algorithms/mx_quant/utils.py
@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MX quantization utils."""

from enum import Enum, IntEnum

@@ -28,6 +28,8 @@


class ElemFormat(Enum):
"""Element format."""

int8 = 1
int4 = 2
int2 = 3
@@ -44,6 +46,7 @@ class ElemFormat(Enum):

@staticmethod
def from_str(s):
"""Get element format with str."""
assert s is not None, "String elem_format == None"
s = s.lower()
if hasattr(ElemFormat, s):
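A quick sketch of the lookup `from_str` performs. The return statement is collapsed above, so the expected result below is an assumption based on the visible `hasattr` check.

```python
from neural_compressor.torch.algorithms.mx_quant.utils import ElemFormat

# Lookup is case-insensitive: the body lowercases s before the hasattr check.
fmt = ElemFormat.from_str("FP8_E4M3")  # expected to resolve to ElemFormat.fp8_e4m3
```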
@@ -53,6 +56,7 @@ def from_str(s):

@staticmethod
def is_bf(s):
"""Whether the format is brain floating-point format."""
if isinstance(s, str):
assert s is not None, "String elem_format == None"
s = s.lower()
@@ -65,6 +69,7 @@ def is_bf(s):

@staticmethod
def is_fp(s):
"""Whether the format is floating-point format."""
if isinstance(s, str):
assert s is not None, "String elem_format == None"
s = s.lower()
@@ -77,6 +82,7 @@ def is_fp(s):

@staticmethod
def is_int(s):
"""Whether the format is integer format."""
if isinstance(s, str):
assert s is not None, "String elem_format == None"
s = s.lower()
@@ -89,12 +95,15 @@ def is_int(s):


class RoundingMode(IntEnum):
"""Rounding mode."""

nearest = 0
floor = 1
even = 2

@staticmethod
def string_enums():
"""Rounding mode names."""
return [s.name for s in list(RoundingMode)]


@@ -115,14 +124,19 @@ def _get_max_norm(ebits, mbits):


def _get_format_params(fmt):
"""Allowed formats:
"""Get parameters of the format.

Allowed formats:
- intX: 2 <= X <= 32, assume sign-magnitude, 1.xxx representation
- floatX/fpX: 16 <= X <= 28, assume top exp is used for NaN/Inf
- bfloatX/bfX: 9 <= X <= 32
- fp4, no NaN/Inf
- fp6_e3m2/e2m3, no NaN/Inf
- fp8_e4m3/e5m2, e5m2 normal NaN/Inf, e4m3 special behavior

Args:
fmt (str or ElemFormat): format

Returns:
ebits: exponent bits
mbits: mantissa bits: includes sign and implicit bits
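A hedged reading of the format grammar above: since mbits counts the sign and implicit bits, fp8_e4m3 should report ebits=4 and mbits=5. The rest of the return tuple is collapsed in this diff, so the indexing below is an assumption.

```python
from neural_compressor.torch.algorithms.mx_quant.utils import _get_format_params

params = _get_format_params("fp8_e4m3")
# Per the docstring, mbits includes sign and implicit bits:
# 3 stored mantissa bits + sign + implicit one -> mbits == 5, ebits == 4.
ebits, mbits = params[0], params[1]
```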
@@ -198,17 +212,19 @@ def _safe_rshift(x, bits, exp):


def _round_mantissa(A, bits, round, clamp=False):
"""
Rounds mantissa to nearest bits depending on the rounding method 'round'
"""Rounds mantissa to nearest bits depending on the rounding method 'round'.

Args:
A {PyTorch tensor} -- Input tensor
round {str} -- Rounding method
"floor" rounds to the floor
"nearest" rounds to ceil or floor, whichever is nearest
A (torch.Tensor): input tensor
bits (int): bit number of mantissa
round (str): rounding method
"floor" rounds to the floor
"nearest" rounds to ceil or floor, whichever is nearest
clamp (bool, optional): whether to clip. Defaults to False.

Returns:
A {PyTorch tensor} -- Tensor with mantissas rounded
torch.Tensor: tensor with mantissas rounded
"""

if round == "dither":
rand_A = torch.rand_like(A, requires_grad=False)
A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
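The "dither" branch shown above is stochastic rounding: adding uniform noise in [0, 1) before flooring rounds each magnitude up with probability equal to its fractional part. A standalone replay of just that visible logic:

```python
import torch

A = torch.tensor([0.3, 0.7, -1.5])
rand_A = torch.rand_like(A, requires_grad=False)
# In expectation this preserves A: |0.3| floors to 0 with probability 0.7
# and to 1 with probability 0.3; the sign is reattached afterwards.
dithered = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
```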
@@ -235,16 +251,18 @@ def _shared_exponents(A, method="max", axes=None, ebits=0):
"""Get shared exponents for the passed matrix A.

Args:
A {PyTorch tensor} -- Input tensor
method {str} -- Exponent selection method.
"max" uses the max absolute value
"none" uses an exponent for each value (i.e., no sharing)
axes {list(int)} -- List of integers which specifies the axes across which
shared exponents are calculated.
A (torch.Tensor): Input tensor
method (str, optional): Exponent selection method.
"max" uses the max absolute value.
"none" uses an exponent for each value (i.e., no sharing)
Defaults to "max".
axes (list(int), optional): list of integers which specifies the axes across which
shared exponents are calculated. Defaults to None.
ebits (int, optional): bit number of the shared exponents. Defaults to 0.

Returns:
shared_exp {PyTorch tensor} -- Tensor of shared exponents
shared_exp (torch.Tensor): Tensor of shared exponents
"""

if method == "max":
if axes is None:
shared_exp = torch.max(torch.abs(A))
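The "max" method above seeds the shared exponent with the largest magnitude in the block. In MX-style scaling that value is then typically reduced to floor(log2(max|A|)); the log2/floor step is collapsed in this diff, so this replay of it is an assumption:

```python
import torch

A = torch.tensor([0.5, -3.0, 1.2])
# max |A| is 3.0, so the shared exponent would be floor(log2(3.0)) == 1.
shared_exp = torch.floor(torch.log2(torch.max(torch.abs(A))))
```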
@@ -346,21 +364,20 @@ def _undo_reshape_to_blocks(A, padded_shape, orig_shape, axes):


def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):
"""Core function used for element-wise quantization
Arguments:
A {PyTorch tensor} -- A tensor to be quantized
bits {int} -- Number of mantissa bits. Includes
sign bit and implicit one for floats
exp_bits {int} -- Number of exponent bits, 0 for ints
max_norm {float} -- Largest representable normal number
round {str} -- Rounding mode: (floor, nearest, even)
saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
that exceed max norm are clamped.
Must be True for correct MX conversion.
allow_denorm {bool} -- If False, flush denorm numbers in the
elem_format to zero.
"""Core function used for element-wise quantization.

Args:
A (torch.Tensor): tensor to be quantized
bits (int): number of mantissa bits. Includes sign bit and implicit one for floats
exp_bits (int): number of exponent bits, 0 for ints
max_norm (float): largest representable normal number
round (str, optional): rounding mode: (floor, nearest, even). Defaults to "nearest".
saturate_normals (bool, optional): whether to clip normal numbers that exceed max norm.
Must be True for correct MX conversion. Defaults to False.
allow_denorm (bool, optional): if False, flush denorm numbers in the elem_format to zero. Defaults to True.

Returns:
quantized tensor {PyTorch tensor} -- A tensor that has been quantized
torch.Tensor: tensor that has been quantized
"""
# Flush values < min_norm to zero if denorms are not allowed
if not allow_denorm and exp_bits > 0:
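A hedged call sketch for the core quantizer, combining it with `_get_max_norm` from earlier in this file. The fp8_e5m2-style parameters (ebits=5, mbits=4 counting sign and implicit bits) are illustrative assumptions, not values prescribed by this diff.

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import (
    _get_max_norm,
    _quantize_elemwise_core,
)

x = torch.tensor([0.1, 0.9, 70000.0])
max_norm = _get_max_norm(5, 4)  # fp8_e5m2-style: 5 exp bits, 2 stored mantissa bits
# saturate_normals=True clamps 70000.0 (beyond fp8_e5m2's largest normal) to max_norm.
q = _quantize_elemwise_core(x, bits=4, exp_bits=5, max_norm=max_norm,
                            round="nearest", saturate_normals=True)
```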
@@ -401,15 +418,20 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round="nearest", saturate_normals=False, allow_denorm=True):


def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):
"""Quantize values to IEEE fpX format.

The format defines NaN/Inf
and subnorm numbers in the same way as FP32 and FP16.
Arguments:
exp_bits {int} -- number of bits used to store exponent
mantissa_bits {int} -- number of bits used to store mantissa, not
including sign or implicit 1
round {str} -- Rounding mode, (floor, nearest, even)
"""Quantize values to IEEE fpX format..

The format defines NaN/Inf and subnorm numbers in the same way as FP32 and FP16.

Args:
A (torch.Tensor): a tensor that needs to be quantized
exp_bits (int, optional): number of bits used to store exponent. Defaults to None.
mantissa_bits (int, optional): number of bits used to store mantissa.
Not including sign or implicit 1. Defaults to None.
round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.

Returns:
torch.Tensor: tensor that has been quantized
"""
# Shortcut for no quantization
if exp_bits is None or mantissa_bits is None:
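A short sketch using the documented arguments: per the docstring, `mantissa_bits` excludes the sign and implicit 1, so `exp_bits=5, mantissa_bits=2` emulates an fp8_e5m2-style grid, while leaving either argument as None takes the no-quantization shortcut shown above.

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import _quantize_fp

x = torch.randn(8)
x_fp8 = _quantize_fp(x, exp_bits=5, mantissa_bits=2, round="nearest")
x_same = _quantize_fp(x)  # exp_bits/mantissa_bits left as None -> returned unchanged
```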
@@ -425,11 +447,17 @@ def _quantize_fp(A, exp_bits=None, mantissa_bits=None, round="nearest", allow_denorm=True):


def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):
"""Quantize values to bfloatX format
Arguments:
bfloat {int} -- Total number of bits for bfloatX format,
Includes 1 sign, 8 exp bits, and variable
mantissa bits. Must be >= 9.
"""Quantize values to bfloatX format.

Args:
A (torch.Tensor): a tensor that needs to be quantized
bfloat (int): total number of bits for bfloatX format.
Includes 1 sign, 8 exp bits, and variable mantissa bits. Must be >= 9.
round (str, optional): rounding mode, (floor, nearest, even). Defaults to "nearest".
allow_denorm (bool, optional): allow denorm numbers to exist. Defaults to True.

Returns:
torch.Tensor: tensor that has been quantized
"""
# Shortcut for no quantization
if bfloat == 0 or bfloat == 32:
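Per the docstring, `bfloat` counts all bits (1 sign + 8 exponent + variable mantissa), so `bfloat=16` should reproduce standard bfloat16 rounding, while `bfloat=0` or `32` takes the shortcut above:

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import _quantize_bfloat

x = torch.randn(8)
x_bf16 = _quantize_bfloat(x, bfloat=16, round="nearest")  # 1 + 8 + 7 bits
x_same = _quantize_bfloat(x, bfloat=0)  # shortcut: no quantization
```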
@@ -443,12 +471,14 @@ def _quantize_bfloat(A, bfloat, round="nearest", allow_denorm=True):


def quantize_elemwise_op(A, mx_specs):
"""A function used for element-wise quantization with mx_specs
Arguments:
A {PyTorch tensor} -- a tensor that needs to be quantized
mx_specs {dictionary} -- dictionary to specify mx_specs
"""A function used for element-wise quantization with mx_specs.

Args:
A (torch.Tensor): a tensor that needs to be quantized
mx_specs (dict): dictionary to specify mx_specs

Returns:
quantized value {PyTorch tensor} -- a tensor that has been quantized
torch.Tensor: tensor that has been quantized
"""
if mx_specs is None:
return A
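As the guard above shows, `mx_specs=None` makes the op an identity. The keys a real mx_specs carries are not visible in this diff, so only the None path is demonstrated:

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import quantize_elemwise_op

x = torch.randn(8)
assert torch.equal(quantize_elemwise_op(x, mx_specs=None), x)  # identity when unset
```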
@@ -530,14 +560,15 @@ def _quantize_mx(


def quantize_mx_op(
A,
A: torch.Tensor,
elem_format: str,
round: str,
block_size: int,
scale_bits=8,
axes=None,
expand_and_reshape=False,
):
"""Quantize tensor to MX data type."""
if elem_format is None:
return A
elif type(elem_format) is str:
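Finally, a sketch of the top-level MX op using the documented signature; the element format, 32-element block size, and 8 scale bits below are illustrative choices, not values mandated by this diff:

```python
import torch

from neural_compressor.torch.algorithms.mx_quant.utils import quantize_mx_op

x = torch.randn(2, 64)
x_mx = quantize_mx_op(
    x,
    elem_format="fp8_e4m3",  # None is a no-op per the guard above; strings are parsed internally
    round="nearest",
    block_size=32,
    scale_bits=8,
    axes=[-1],  # share one scale per 32-element block along the last dim
)
```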