From 4991a0d7d6422437f7651b2541eb7b0885352a20 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 7 Mar 2023 14:32:42 -0800 Subject: [PATCH 01/16] Adding weight normalization-based integer quantizer --- src/brevitas/core/quant/int.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index e58b52381..a33b2583b 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -227,3 +227,46 @@ def forward(self, y = y * scale y = self.delay_wrapper(x, y) return y, scale, zero_point, output_bit_width + +class WeightNormIntQuant(brevitas.jit.ScriptModule): + """Weight normalization-based integer quantized for accumulator- + aware quantization based on `Quantized Neural Networks for Low-Precision + Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, + A. Pappalardo, and J. Petri-Koenig.""" + def __init__(self, + int_quant: Module, + scaling_impl: Module, + zero_point_impl: Module, + bit_width_impl: Module, + stats_reduce_dim: int, + scaling_shape: Tuple[int, ...], + scaling_min_val: float, + scaling_stats_input_view_shape_impl: Module): + super(WeightNormIntQuant, self).__init__() + self.int_quant = int_quant + self.scaling_impl = scaling_impl + self.zero_point_impl = zero_point_impl + self.msb_clamp_bit_width_impl = bit_width_impl + + self.tensor_reduce_dim = stats_reduce_dim + self.tensor_view_impl = scaling_stats_input_view_shape_impl + self.tensor_shape = scaling_shape + self.tensor_min_val = scaling_min_val + + @brevitas.jit.script_method + def normalize(self, inp: Tensor) -> Tensor: + denom: Tensor = self.tensor_view_impl(inp) + denom = denom.norm(p=1, dim=self.tensor_reduce_dim, keepdim=True) + denom = denom.view(self.tensor_shape) + return inp / denom.clamp_min(self.tensor_min_val) + + @brevitas.jit.script_method + def forward(self, x: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + # apply weight normalization-based quantization formulation + w = self.normalize(x) + bit_width = self.msb_clamp_bit_width_impl() + # use input bit_width and sign to get norm and scales + g, scale = self.scaling_impl(input_bit_width, input_is_signed) + zero_point = self.zero_point_impl(g * w, scale, bit_width) + y = self.int_quant(scale, zero_point, bit_width, g * w) + return y, scale, zero_point, bit_width \ No newline at end of file From 7b3f6912dbbceacc42e0c8c964ca3d1b58d1c65f Mon Sep 17 00:00:00 2001 From: icolbert Date: Fri, 10 Mar 2023 20:18:58 -0800 Subject: [PATCH 02/16] Pre-commit fixes --- src/brevitas/core/quant/int.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index a33b2583b..cbdc69d16 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -228,6 +228,7 @@ def forward(self, y = self.delay_wrapper(x, y) return y, scale, zero_point, output_bit_width + class WeightNormIntQuant(brevitas.jit.ScriptModule): """Weight normalization-based integer quantized for accumulator- aware quantization based on `Quantized Neural Networks for Low-Precision @@ -269,4 +270,4 @@ def forward(self, x: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> g, scale = self.scaling_impl(input_bit_width, input_is_signed) zero_point = self.zero_point_impl(g * w, scale, bit_width) y = self.int_quant(scale, zero_point, bit_width, g * w) - return y, scale, zero_point, bit_width \ No newline at end of file + return y, scale, 
zero_point, bit_width From d84b52b9a8586ed1e81d8944ba8f71dd464e35c5 Mon Sep 17 00:00:00 2001 From: icolbert Date: Fri, 17 Mar 2023 08:07:39 -0700 Subject: [PATCH 03/16] Adding variable for p-norm --- src/brevitas/core/quant/int.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cbdc69d16..8a7a9d572 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -230,7 +230,7 @@ def forward(self, class WeightNormIntQuant(brevitas.jit.ScriptModule): - """Weight normalization-based integer quantized for accumulator- + """Weight normalization-based integer quantizer for accumulator- aware quantization based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig.""" @@ -254,10 +254,14 @@ def __init__(self, self.tensor_shape = scaling_shape self.tensor_min_val = scaling_min_val + # TODO - expose this to be user-specified + self.tensor_p = 1 + @brevitas.jit.script_method def normalize(self, inp: Tensor) -> Tensor: + """Normalize weights by p-norm""" denom: Tensor = self.tensor_view_impl(inp) - denom = denom.norm(p=1, dim=self.tensor_reduce_dim, keepdim=True) + denom = denom.norm(p=self.tensor_p, dim=self.tensor_reduce_dim, keepdim=True) denom = denom.view(self.tensor_shape) return inp / denom.clamp_min(self.tensor_min_val) From b4ad793914a04020c38d82e1ed513b3576d87d18 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 12:12:38 -0700 Subject: [PATCH 04/16] Adding SingleArgStatelessBuffer --- src/brevitas/core/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/brevitas/core/utils.py b/src/brevitas/core/utils.py index 44aef87e8..0ddac751e 100644 --- a/src/brevitas/core/utils.py +++ b/src/brevitas/core/utils.py @@ -62,3 +62,14 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): destination=destination, prefix=prefix, keep_vars=keep_vars) del output_dict[prefix + VALUE_ATTR_NAME] return output_dict + + +class SingleArgStatelessBuffer(brevitas.jit.ScriptModule): + + def __init__(self, value: torch.Tensor): + super(SingleArgStatelessBuffer, self).__init__() + self.const = StatelessBuffer(torch.tensor(value)) + + @brevitas.jit.script_method + def forward(self, placeholder): + return self.const() From 119bcb3054dc16221df9e06d165ac5d8eea9f4c1 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 12:14:51 -0700 Subject: [PATCH 05/16] Adding L1Norm as scaling_stats_impl for weight normalization --- src/brevitas/core/stats/stats_op.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 0b5c2d180..1d122f96d 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -349,3 +349,18 @@ def forward(self, x: Tensor): min_divergence_idx = torch.argmin(divergence) opt_threshold = thresholds[min_divergence_idx] return opt_threshold + + +class L1Norm(brevitas.jit.ScriptModule): + __constants__ = ['stats_reduce_dim'] + + def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: + super(L1Norm, self).__init__() + self.stats_reduce_dim = stats_reduce_dim + + @brevitas.jit.script_method + def forward(self, x: Tensor): + if self.stats_reduce_dim is None: + raise NotImplementedError("L1 normalization is not supported per-tensor yet.") + else: + return x.norm(p=1, dim=self.stats_reduce_dim, keepdim=True) From 
27da0482a8cfc5964469406430ae14b8553bb13d Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 12:19:13 -0700 Subject: [PATCH 06/16] Adding L2Norm for normalize_stats_impl --- src/brevitas/core/stats/stats_op.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 1d122f96d..cd57f3e49 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -352,6 +352,8 @@ def forward(self, x: Tensor): class L1Norm(brevitas.jit.ScriptModule): + """ScriptModule implementation to collect per-channel L1 normalization stats + for weight normalization-based quantization.""" __constants__ = ['stats_reduce_dim'] def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: @@ -361,6 +363,25 @@ def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: @brevitas.jit.script_method def forward(self, x: Tensor): if self.stats_reduce_dim is None: + # Need to be able to return the max per-channel L1 norm as a scalar raise NotImplementedError("L1 normalization is not supported per-tensor yet.") else: return x.norm(p=1, dim=self.stats_reduce_dim, keepdim=True) + + +class L2Norm(brevitas.jit.ScriptModule): + """ScriptModule implementation to collect per-channel L2 normalization stats + for weight normalization-based quantization.""" + __constants__ = ['stats_reduce_dim'] + + def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: + super(L1Norm, self).__init__() + self.stats_reduce_dim = stats_reduce_dim + + @brevitas.jit.script_method + def forward(self, x: Tensor): + if self.stats_reduce_dim is None: + # Need to be able to return the max per-channel L2 norm as a scalar + raise NotImplementedError("L2 normalization is not supported per-tensor yet.") + else: + return x.norm(p=2, dim=self.stats_reduce_dim, keepdim=True) From 8251396f24e62a1b7deb4c7decb6a9210faace0b Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 12:35:32 -0700 Subject: [PATCH 07/16] Adding ParameterPreScalingWeightNorm --- src/brevitas/core/scaling/pre_scaling.py | 107 +++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 src/brevitas/core/scaling/pre_scaling.py diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py new file mode 100644 index 000000000..f93ba3a7e --- /dev/null +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -0,0 +1,107 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn import Parameter + +import brevitas +import brevitas.config as config +from brevitas.core.restrict_val import _RestrictClampValue +from brevitas.core.stats import SCALAR_SHAPE +from brevitas.core.stats.stats_wrapper import _Stats +from brevitas.function import abs_binary_sign_grad + + +class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule): + """ + ScriptModule implementation of learned pre-clipping scaling factor to support weight + normalization-based quantization as proposed in `Quantized Neural Networks for Low- + Precision Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, + and J. Petri-Koenig. 
+ + The module parameterizes the pre-clipping scaling factor (i.e., `pre_scale`) of the + decoupled tensor quantizer (i.e., `DecoupledRescalingIntQuant`) by combining the + calculuated weight norm stats (i.e., `d_w`) with both the parameterized weight norm + vector (i.e., `g`) and the post-clipping scaling factor (i.e., `post_scale`). The + arithmetic is outlined below. + + The formulation for weight normalization-based quantization is given below: + `y = clip(round( (g / s) * (w / norm(w)) )) * s` + which we re-write as: + `y = clip(round(w / pre_scale)) * post_scale` + where `pre_scale = s * norm(w) / g` and `post_scale = s`. + + Here, `pre_scale` refers to the pre-clipping scaling factor and `post_scale` refers to + the post-clipping scaling factor. + + Args: + scaling_impl (Module): post-clipping scaling factor. + normalize_stats_impl (Module): calculate statistics for normalizing weight parameter. + scaling_stats_input_view_shape_impl (Module): transforming scaling to a new shape. + tracked_parameter_list (List[torch.nn.Parameter]): list of tracked weight parameters + for tensor quantizer. + pre_scaling_shape (Tuple[int]): shape of pre-clipping scaling factor. Default: None. + restrict_pre_scaling_impl (Module): restrict pre_scaling_init according to some + criteria. Default: None. + pre_scaling_min_val (float): force a lower-bound on scaling_init. Default: None. + + Returns: + Tensor: scaling factor wrapped in a float torch.Tensor. + """ + def __init__( + self, + scaling_impl: Module, + normalize_stats_impl: Module, + scaling_stats_input_view_shape_impl: Module, + tracked_parameter_list: List[torch.nn.Parameter], + pre_scaling_shape: Optional[Tuple[int, ...]] = None, + restrict_pre_scaling_impl: Optional[Module] = None, + pre_scaling_min_val: Optional[float] = None) -> None: + super(ParameterPreScalingWeightNorm, self).__init__() + + self.stats = _Stats(normalize_stats_impl, pre_scaling_shape) + self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl + self.scaling_impl = scaling_impl # this is the post-clipping scaling factor + + if len(tracked_parameter_list) > 1: + raise NotImplementedError( + "Error: ParameterPreScalingWeightNorm does not support multiple tracked quantizers." 
+ ) + assert len(tracked_parameter_list) == 1 + + # Initialize the weight norm parameter vector from the tracked parameter itself + param = tracked_parameter_list[0] + param = self.stats_input_view_shape_impl(param) + pre_scaling_init = self.stats(param) + if restrict_pre_scaling_impl is not None: + pre_scaling_init = restrict_pre_scaling_impl.restrict_init_tensor(pre_scaling_init) + if pre_scaling_init.shape == SCALAR_SHAPE and pre_scaling_shape is not None: + pre_scaling_init = torch.full(pre_scaling_shape, pre_scaling_init) + self.value = Parameter(pre_scaling_init) + self.restrict_clamp_scaling = _RestrictClampValue(pre_scaling_min_val, restrict_pre_scaling_impl) + + @brevitas.jit.script_method + def forward(self, weights: Tensor) -> Tensor: + """Takes weights as input and returns the pre-clipping scaling factor""" + weights = self.stats_input_view_shape_impl(weights) + d_w = self.stats(weights) # denominator for weight normalization + g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g + s = self.scaling_impl(weights) # s + value = (s * d_w) / g + return value + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + value_key = prefix + 'value' + retrocomp_value_key = prefix + 'learned_value' + if retrocomp_value_key in state_dict: # retrocompatibility + state_dict[value_key] = state_dict.pop(retrocomp_value_key) + super(ParameterPreScalingWeightNorm, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + if config.IGNORE_MISSING_KEYS and value_key in missing_keys: + missing_keys.remove(value_key) From f740c097d9733ca258b9f116364b19decee5675e Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:23:51 -0700 Subject: [PATCH 08/16] Adding modules to top-level imports --- src/brevitas/core/scaling/__init__.py | 1 + src/brevitas/core/stats/__init__.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/brevitas/core/scaling/__init__.py b/src/brevitas/core/scaling/__init__.py index 5472c8d30..d7870a2af 100644 --- a/src/brevitas/core/scaling/__init__.py +++ b/src/brevitas/core/scaling/__init__.py @@ -6,6 +6,7 @@ from .int_scaling import IntScaling from .int_scaling import PowerOfTwoIntScaling +from .pre_scaling import ParameterPreScalingWeightNorm from .runtime import RuntimeStatsScaling from .runtime import StatsFromParameterScaling from .standalone import ConstScaling diff --git a/src/brevitas/core/stats/__init__.py b/src/brevitas/core/stats/__init__.py index 929ae5e06..df63eb1cb 100644 --- a/src/brevitas/core/stats/__init__.py +++ b/src/brevitas/core/stats/__init__.py @@ -12,6 +12,8 @@ from .stats_op import AbsMaxL2 from .stats_op import AbsMinMax from .stats_op import AbsPercentile +from .stats_op import L1Norm +from .stats_op import L2Norm from .stats_op import MeanLearnedSigmaStd from .stats_op import MeanSigmaStd from .stats_op import NegativeMinOrZero From 96045e767b3eda69aec789ed35539f425a081103 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:25:18 -0700 Subject: [PATCH 09/16] Removing WeightNormIntQuant --- src/brevitas/core/quant/int.py | 48 ---------------------------------- 1 file changed, 48 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index 8a7a9d572..e58b52381 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -227,51 +227,3 @@ def forward(self, y = y * scale y = self.delay_wrapper(x, y) return y, scale, 
zero_point, output_bit_width - - -class WeightNormIntQuant(brevitas.jit.ScriptModule): - """Weight normalization-based integer quantizer for accumulator- - aware quantization based on `Quantized Neural Networks for Low-Precision - Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, - A. Pappalardo, and J. Petri-Koenig.""" - def __init__(self, - int_quant: Module, - scaling_impl: Module, - zero_point_impl: Module, - bit_width_impl: Module, - stats_reduce_dim: int, - scaling_shape: Tuple[int, ...], - scaling_min_val: float, - scaling_stats_input_view_shape_impl: Module): - super(WeightNormIntQuant, self).__init__() - self.int_quant = int_quant - self.scaling_impl = scaling_impl - self.zero_point_impl = zero_point_impl - self.msb_clamp_bit_width_impl = bit_width_impl - - self.tensor_reduce_dim = stats_reduce_dim - self.tensor_view_impl = scaling_stats_input_view_shape_impl - self.tensor_shape = scaling_shape - self.tensor_min_val = scaling_min_val - - # TODO - expose this to be user-specified - self.tensor_p = 1 - - @brevitas.jit.script_method - def normalize(self, inp: Tensor) -> Tensor: - """Normalize weights by p-norm""" - denom: Tensor = self.tensor_view_impl(inp) - denom = denom.norm(p=self.tensor_p, dim=self.tensor_reduce_dim, keepdim=True) - denom = denom.view(self.tensor_shape) - return inp / denom.clamp_min(self.tensor_min_val) - - @brevitas.jit.script_method - def forward(self, x: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - # apply weight normalization-based quantization formulation - w = self.normalize(x) - bit_width = self.msb_clamp_bit_width_impl() - # use input bit_width and sign to get norm and scales - g, scale = self.scaling_impl(input_bit_width, input_is_signed) - zero_point = self.zero_point_impl(g * w, scale, bit_width) - y = self.int_quant(scale, zero_point, bit_width, g * w) - return y, scale, zero_point, bit_width From 70ecfa025ca6ddc1f29af75dc7fb36f3ea54bf5c Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:35:23 -0700 Subject: [PATCH 10/16] Updating list of modules in pre_scaling.py --- src/brevitas/core/scaling/pre_scaling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index f93ba3a7e..9429e1485 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -16,6 +16,9 @@ from brevitas.core.stats.stats_wrapper import _Stats from brevitas.function import abs_binary_sign_grad +__all__ = [ + "ParameterPreScalingWeightNorm", +] class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule): """ From f82db6df29a85db159db3818ea4395d6a28ccd08 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:37:46 -0700 Subject: [PATCH 11/16] Adding WeightNormPerChannelFloatDecoupled --- src/brevitas/quant/base.py | 54 +++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index db246989a..233f1b65a 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -7,21 +7,26 @@ from brevitas.core.bit_width import BitWidthConst from brevitas.core.function_wrapper import OverOutputChannelView +from brevitas.core.function_wrapper import TensorClamp from brevitas.core.function_wrapper import TensorClampSte from brevitas.core.function_wrapper.ops_ste import CeilSte from brevitas.core.quant import ClampedBinaryQuant from brevitas.core.quant.int import 
DecoupledRescalingIntQuant from brevitas.core.quant.int_base import DecoupledIntQuant from brevitas.core.restrict_val import FloatRestrictValue +from brevitas.core.restrict_val import LogFloatRestrictValue from brevitas.core.scaling import IntScaling +from brevitas.core.scaling import ParameterPreScalingWeightNorm from brevitas.core.scaling import ParameterScaling from brevitas.core.scaling import SCALAR_SHAPE from brevitas.core.scaling import SCALING_STATS_REDUCE_DIM from brevitas.core.scaling import StatsFromParameterScaling from brevitas.core.stats import AbsMax from brevitas.core.stats import AbsMaxL2 +from brevitas.core.stats import L2Norm from brevitas.core.stats import NegativeMinOrZero from brevitas.core.stats import NegativePercentileOrZero +from brevitas.core.utils import SingleArgStatelessBuffer from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint from brevitas.core.zero_point import StatsFromParameterZeroPoint from brevitas.core.zero_point import ZeroZeroPoint @@ -56,7 +61,8 @@ 'IntTrunc', 'SignedBinaryClampedConst', 'WeightPerTensorFloatDecoupledL2Param', - 'WeightPerChannelFloatDecoupled' + 'WeightPerChannelFloatDecoupled', + 'WeightNormPerChannelFloatDecoupled', ] @@ -291,3 +297,49 @@ def scaling_init(scaling_init_impl): scaling_stats_input_view_shape_impl = OverOutputChannelView stats_reduce_dim = SCALING_STATS_REDUCE_DIM scaling_per_output_channel = True + + +class WeightNormPerChannelFloatDecoupled( + SolveWeightScalingStatsInputDimsFromModule, + SolveWeightScalingPerOutputChannelShapeFromModule, + SolveParameterScalingShape): + """Experimental narrow per-channel weight normalization-based signed integer quantizer + based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed + Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig. + + The formulation for weight normalization-based quantization is given below: + `y = clip(round( (g / s) * (w / norm(w)) )) * s` + + The default quantizer uses the decoupled rescaling integer quantization arithmetic + where the weight normalization calculation and parameterization are combined with the + scaling factor to become the pre-clipping scaling factor (i.e., `pre_scale`) and the + scaling factor is the post-clipping scaling factor (i.e., `post_scale`). for further + details on the arithmetic, see `ParameterPreScalingWeightNorm`. For further details + on the weight normalization-based quantization technique, see the referenced paper.""" + + @value + def scaling_init(scaling_init_impl): + return scaling_init_impl() + + proxy_class = DecoupledWeightQuantProxyFromInjector + tensor_quant = DecoupledRescalingIntQuant + decoupled_int_quant = DecoupledIntQuant + tensor_clamp_impl = TensorClamp + scaling_impl = ParameterScaling + restrict_scaling_impl = FloatRestrictValue + scaling_stats_impl = AbsMax + scaling_init_impl = ParameterFromStatsScalingInit + parameter_stats_scaling_init_impl = StatsFromParameterScaling + pre_scaling_impl = ParameterPreScalingWeightNorm + restrict_pre_scaling_impl = LogFloatRestrictValue + normalize_stats_impl = L2Norm + pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape + int_scaling_impl = SingleArgStatelessBuffer(1.) 
+ zero_point_impl = ZeroZeroPoint + pre_zero_point_impl = ZeroZeroPoint + bit_width_impl = BitWidthConst + narrow_range = True + signed = True + scaling_stats_input_view_shape_impl = OverOutputChannelView + stats_reduce_dim = SCALING_STATS_REDUCE_DIM + scaling_per_output_channel = True From 3276d79dba477357d7262f2f458b734354ca9e5e Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:55:29 -0700 Subject: [PATCH 12/16] Adding Int8WeightNormL2PerChannelFixedPoint injector --- src/brevitas/quant/fixed_point.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/brevitas/quant/fixed_point.py b/src/brevitas/quant/fixed_point.py index d3953023f..bb3d7893e 100644 --- a/src/brevitas/quant/fixed_point.py +++ b/src/brevitas/quant/fixed_point.py @@ -103,3 +103,20 @@ class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Para restrict_scaling_impl = PowerOfTwoRestrictValue int_scaling_impl = PowerOfTwoIntScaling restrict_value_float_to_int_impl = CeilSte + + +class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled): + """ + Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors + and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation + with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig + (https://arxiv.org/abs/2301.13376). The quantizer learns scaling factors in the float domain and + learns vector parameter g in the log domain with the half-way rounding function. Suitable for + retraining from floating-point depthwise separable weights. + + Examples: + >>> from brevitas.nn import QuantConv2d + >>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8WeightNormL2PerChannelFixedPoint) + >>> conv.quant_weight() + """ + bit_width = 8 From 672914725eeb714a18042d2c681c98d9d4e0f612 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 14:58:28 -0700 Subject: [PATCH 13/16] Fixing L2Norm initialization --- src/brevitas/core/stats/stats_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index cd57f3e49..f31cf871f 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -375,7 +375,7 @@ class L2Norm(brevitas.jit.ScriptModule): __constants__ = ['stats_reduce_dim'] def __init__(self, stats_reduce_dim: Optional[int] = None) -> None: - super(L1Norm, self).__init__() + super(L2Norm, self).__init__() self.stats_reduce_dim = stats_reduce_dim @brevitas.jit.script_method From fbbf543e8b8d3de17ba6d8156a106f4869bb0e12 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 15:05:08 -0700 Subject: [PATCH 14/16] Typo fix --- src/brevitas/quant/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 233f1b65a..9222e17f1 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -313,7 +313,7 @@ class WeightNormPerChannelFloatDecoupled( The default quantizer uses the decoupled rescaling integer quantization arithmetic where the weight normalization calculation and parameterization are combined with the scaling factor to become the pre-clipping scaling factor (i.e., `pre_scale`) and the - scaling factor is the post-clipping scaling factor (i.e., `post_scale`). for further + scaling factor is the post-clipping scaling factor (i.e., `post_scale`). 
For further details on the arithmetic, see `ParameterPreScalingWeightNorm`. For further details on the weight normalization-based quantization technique, see the referenced paper.""" From 86098584adb9ad9127f0325300fe1e84e807dfa4 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 21 Mar 2023 15:09:52 -0700 Subject: [PATCH 15/16] Pre-commit fixes --- src/brevitas/quant/fixed_point.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/quant/fixed_point.py b/src/brevitas/quant/fixed_point.py index bb3d7893e..a29a299f3 100644 --- a/src/brevitas/quant/fixed_point.py +++ b/src/brevitas/quant/fixed_point.py @@ -108,8 +108,8 @@ class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Para class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled): """ Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors - and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation - with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig + and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation + with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig (https://arxiv.org/abs/2301.13376). The quantizer learns scaling factors in the float domain and learns vector parameter g in the log domain with the half-way rounding function. Suitable for retraining from floating-point depthwise separable weights. From 6aefeef4e426803665bc8cecfc289c04519f7b1d Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 22 Mar 2023 11:51:43 -0700 Subject: [PATCH 16/16] Adding quant_decoupled to WBIOL weight quantizer tests --- tests/brevitas/nn/nn_quantizers_fixture.py | 39 ++++++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index be93d3693..cbe14920a 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -16,6 +16,7 @@ from brevitas.nn import QuantLinear from brevitas.nn.quant_rnn import QuantLSTM from brevitas.nn.quant_rnn import QuantRNN +from brevitas.quant.fixed_point import Int8WeightNormL2PerChannelFixedPoint from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling from brevitas.quant.scaled_int import Int8WeightPerTensorFloat @@ -31,10 +32,18 @@ FEATURES = 5 KERNEL_SIZE = 3 -WEIGHT_QUANTIZER = { +LSTM_WEIGHT_QUANTIZER = { 'None': None, 'quant_sym': Int8WeightPerTensorFloat, - 'quant_asym': ShiftedUint8WeightPerTensorFloat,} + 'quant_asym': ShiftedUint8WeightPerTensorFloat, +} + +WBIOL_WEIGHT_QUANTIZER = { + 'None': None, + 'quant_sym': Int8WeightPerTensorFloat, + 'quant_asym': ShiftedUint8WeightPerTensorFloat, + 'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint, +} IO_QUANTIZER = { 'None': None, @@ -44,25 +53,33 @@ SIGNED_ACT_QUANTIZER = { 'None': None, 'quant_sym': Int8ActPerTensorFloat, - 'quant_asym': ShiftedUint8ActPerTensorFloat,} + 'quant_asym': ShiftedUint8ActPerTensorFloat, +} UNSIGNED_ACT_QUANTIZER = { 'None': None, - 'quant_sym': Uint8ActPerTensorFloat,} + 'quant_sym': Uint8ActPerTensorFloat, +} BIAS_QUANTIZER = { 'None': None, 'quant_external': Int16Bias, - 'quant_internal': Int8BiasPerTensorFloatInternalScaling} + 'quant_internal': Int8BiasPerTensorFloatInternalScaling, +} QUANT_WBIOL_IMPL = [ - QuantLinear, QuantConv1d, 
QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d] + QuantLinear, + QuantConv1d, + QuantConv2d, + QuantConvTranspose1d, + QuantConvTranspose2d, +] @pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]]) @pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()]) @pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()]) -@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()]) +@pytest_cases.parametrize('weight_quantizer', WBIOL_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WBIOL_WEIGHT_QUANTIZER.items()]) @pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]]) @pytest_cases.parametrize('module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL]) @pytest_cases.parametrize('is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]]) @@ -121,7 +138,7 @@ def forward(self, x): @pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()]) @pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]]) @pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()]) -@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()]) +@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()]) @pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]]) @pytest_cases.parametrize('bidirectional', [True, False, 'shared_input_hidden'], ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']]) @pytest_cases.parametrize('cifg', [True, False]) @@ -187,7 +204,7 @@ def forward(self, x): @pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()]) @pytest_cases.parametrize('signed_act_quantizer', SIGNED_ACT_QUANTIZER.items(), ids=[f'signed_act${c}' for c, _ in SIGNED_ACT_QUANTIZER.items()]) @pytest_cases.parametrize('unsigned_act_quantizer', UNSIGNED_ACT_QUANTIZER.items(), ids=[f'unsigned_act${c}' for c, _ in UNSIGNED_ACT_QUANTIZER.items()]) -@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()]) +@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()]) @pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]]) @pytest_cases.parametrize('bidirectional', [True, False, 'shared_input_hidden'], ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']]) @pytest_cases.parametrize('cifg', [True, False]) @@ -253,7 +270,7 @@ def forward(self, x): @pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()]) @pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]]) @pytest_cases.parametrize('bias_quantizer', 
BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()]) -@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()]) +@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()]) @pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]]) @pytest_cases.parametrize('bidirectional', [True, False], ids=[f'bidirectional${f}' for f in [True, False]]) @pytest_cases.parametrize('num_layers', [1, 2], ids=[f'num_layers${f}' for f in [1, 2]]) @@ -307,7 +324,7 @@ def forward(self, x): @pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]]) @pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()]) @pytest_cases.parametrize('signed_act_quantizer', SIGNED_ACT_QUANTIZER.items(), ids=[f'signed_act${c}' for c, _ in SIGNED_ACT_QUANTIZER.items()]) -@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()]) +@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()]) @pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]]) @pytest_cases.parametrize('bidirectional', [True, False], ids=[f'bidirectional${f}' for f in [True, False]]) @pytest_cases.parametrize('num_layers', [1, 2], ids=[f'num_layers${f}' for f in [1, 2]])
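
Note (not part of the patch series): the docstrings added in PATCH 07 and PATCH 11 describe the weight normalization-based quantization arithmetic, y = clip(round((g / s) * (w / norm(w)))) * s, rewritten as y = clip(round(w / pre_scale)) * post_scale with pre_scale = s * norm(w) / g and post_scale = s. The following is a minimal numerical sketch of that formulation only; the helper name weight_norm_quant_sketch, the toy shapes, and the particular choice of s are illustrative assumptions, not the Brevitas API.

    # Sketch of the formulation from the ParameterPreScalingWeightNorm /
    # WeightNormPerChannelFloatDecoupled docstrings. Assumptions (illustrative
    # only): per-output-channel L2 norms, narrow signed 8-bit range [-127, 127],
    # and toy values for g and s; none of these names are Brevitas API.
    import torch

    def weight_norm_quant_sketch(w: torch.Tensor, g: torch.Tensor,
                                 s: torch.Tensor, bit_width: int = 8) -> torch.Tensor:
        # per-output-channel L2 norm, mirroring the L2Norm stats module above
        d_w = w.norm(p=2, dim=1, keepdim=True)
        pre_scale = s * d_w / g        # pre-clipping scaling factor
        post_scale = s                 # post-clipping scaling factor
        n = 2 ** (bit_width - 1) - 1   # narrow signed range: [-127, 127] for 8 bits
        q = torch.clamp(torch.round(w / pre_scale), -n, n)
        return q * post_scale

    w = torch.randn(4, 9)                 # e.g. four output channels of flattened 3x3 kernels
    g = w.norm(p=2, dim=1, keepdim=True)  # g initialized from the weight norms, as in PATCH 07
    s = w.abs().amax(dim=1, keepdim=True) / 127.  # toy per-channel scale (AbsMax-style)
    print(weight_norm_quant_sketch(w, g, s).shape)  # torch.Size([4, 9])

When g equals norm(w), pre_scale reduces to s and the expression collapses to ordinary scaled integer quantization, which is why the series can reuse the decoupled rescaling integer quantizer with only the pre-clipping scale swapped out.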