Feat (quant): Add weight norm based integer quantizer (#559)
* Adding weight normalization-based integer quantizer

* Pre-commit fixes

* Adding variable for p-norm

* Adding SingleArgStatelessBuffer

* Adding L1Norm as scaling_stats_impl for weight normalization

* Adding L2Norm for normalize_stats_impl

* Adding ParameterPreScalingWeightNorm

* Adding modules to top-level imports

* Removing WeightNormIntQuant

* Updating list of modules in pre_scaling.py

* Adding WeightNormPerChannelFloatDecoupled

* Adding Int8WeightNormL2PerChannelFixedPoint injector

* Fixing L2Norm initialization

* Typo fix

* Pre-commit fixes

* Adding quant_decoupled to WBIOL weight quantizer tests
i-colbert authored Mar 24, 2023
1 parent 10edc24 commit 735b183
Showing 8 changed files with 258 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -6,6 +6,7 @@

from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import ParameterPreScalingWeightNorm
from .runtime import RuntimeStatsScaling
from .runtime import StatsFromParameterScaling
from .standalone import ConstScaling
110 changes: 110 additions & 0 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -0,0 +1,110 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause


from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter

import brevitas
import brevitas.config as config
from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad

__all__ = [
"ParameterPreScalingWeightNorm",
]

class ParameterPreScalingWeightNorm(brevitas.jit.ScriptModule):
"""
ScriptModule implementation of learned pre-clipping scaling factor to support weight
normalization-based quantization as proposed in `Quantized Neural Networks for Low-
Precision Accumulation with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo,
and J. Petri-Koenig.
The module parameterizes the pre-clipping scaling factor (i.e., `pre_scale`) of the
decoupled tensor quantizer (i.e., `DecoupledRescalingIntQuant`) by combining the
calculated weight norm stats (i.e., `d_w`) with both the parameterized weight norm
vector (i.e., `g`) and the post-clipping scaling factor (i.e., `post_scale`). The
arithmetic is outlined below.
The formulation for weight normalization-based quantization is given below:
`y = clip(round( (g / s) * (w / norm(w)) )) * s`
which we re-write as:
`y = clip(round(w / pre_scale)) * post_scale`
where `pre_scale = s * norm(w) / g` and `post_scale = s`.
Here, `pre_scale` refers to the pre-clipping scaling factor and `post_scale` refers to
the post-clipping scaling factor.
Args:
scaling_impl (Module): post-clipping scaling factor.
normalize_stats_impl (Module): calculate statistics for normalizing weight parameter.
scaling_stats_input_view_shape_impl (Module): views the tracked weight parameter in the
shape expected by the stats module.
tracked_parameter_list (List[torch.nn.Parameter]): list of tracked weight parameters
for tensor quantizer.
pre_scaling_shape (Tuple[int]): shape of pre-clipping scaling factor. Default: None.
restrict_pre_scaling_impl (Module): restrict pre_scaling_init according to some
criteria. Default: None.
pre_scaling_min_val (float): force a lower-bound on scaling_init. Default: None.
Returns:
Tensor: scaling factor wrapped in a float torch.Tensor.
"""
def __init__(
self,
scaling_impl: Module,
normalize_stats_impl: Module,
scaling_stats_input_view_shape_impl: Module,
tracked_parameter_list: List[torch.nn.Parameter],
pre_scaling_shape: Optional[Tuple[int, ...]] = None,
restrict_pre_scaling_impl: Optional[Module] = None,
pre_scaling_min_val: Optional[float] = None) -> None:
super(ParameterPreScalingWeightNorm, self).__init__()

self.stats = _Stats(normalize_stats_impl, pre_scaling_shape)
self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
self.scaling_impl = scaling_impl # this is the post-clipping scaling factor

if len(tracked_parameter_list) > 1:
raise NotImplementedError(
"Error: ParameterPreScalingWeightNorm does not support multiple tracked quantizers."
)
assert len(tracked_parameter_list) == 1

# Initialize the weight norm parameter vector from the tracked parameter itself
param = tracked_parameter_list[0]
param = self.stats_input_view_shape_impl(param)
pre_scaling_init = self.stats(param)
if restrict_pre_scaling_impl is not None:
pre_scaling_init = restrict_pre_scaling_impl.restrict_init_tensor(pre_scaling_init)
if pre_scaling_init.shape == SCALAR_SHAPE and pre_scaling_shape is not None:
pre_scaling_init = torch.full(pre_scaling_shape, pre_scaling_init)
self.value = Parameter(pre_scaling_init)
self.restrict_clamp_scaling = _RestrictClampValue(pre_scaling_min_val, restrict_pre_scaling_impl)

@brevitas.jit.script_method
def forward(self, weights: Tensor) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
s = self.scaling_impl(weights) # s
value = (s * d_w) / g
return value

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
value_key = prefix + 'value'
retrocomp_value_key = prefix + 'learned_value'
if retrocomp_value_key in state_dict: # retrocompatibility
state_dict[value_key] = state_dict.pop(retrocomp_value_key)
super(ParameterPreScalingWeightNorm, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
missing_keys.remove(value_key)
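
To see why the rewrite `pre_scale = s * norm(w) / g` matches the original weight-normalization formulation, here is a standalone sketch (not part of the commit; the per-channel L2 norm and the narrow 8-bit clipping range [-127, 127] are illustrative assumptions) that checks both forms numerically:

import torch

torch.manual_seed(0)
w = torch.randn(4, 8)                            # weight tensor, one row per output channel
d_w = w.norm(p=2, dim=1, keepdim=True)           # weight norm stats ||w||_2, as L2Norm below computes
g = d_w.clone()                                  # learned norm parameter, initialized to ||w||_2
s = w.abs().amax(dim=1, keepdim=True) / 127.     # post-clipping scaling factor (illustrative init)

pre_scale = s * d_w / g                          # the value returned by ParameterPreScalingWeightNorm.forward
post_scale = s

y_ref = torch.clamp(torch.round((g / s) * (w / d_w)), -127, 127) * s   # y = clip(round((g/s) * (w/norm(w)))) * s
y = torch.clamp(torch.round(w / pre_scale), -127, 127) * post_scale    # y = clip(round(w/pre_scale)) * post_scale
assert torch.allclose(y_ref, y)                  # agreement up to floating-point rounding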
2 changes: 2 additions & 0 deletions src/brevitas/core/stats/__init__.py
@@ -12,6 +12,8 @@
from .stats_op import AbsMaxL2
from .stats_op import AbsMinMax
from .stats_op import AbsPercentile
from .stats_op import L1Norm
from .stats_op import L2Norm
from .stats_op import MeanLearnedSigmaStd
from .stats_op import MeanSigmaStd
from .stats_op import NegativeMinOrZero
36 changes: 36 additions & 0 deletions src/brevitas/core/stats/stats_op.py
@@ -349,3 +349,39 @@ def forward(self, x: Tensor):
min_divergence_idx = torch.argmin(divergence)
opt_threshold = thresholds[min_divergence_idx]
return opt_threshold


class L1Norm(brevitas.jit.ScriptModule):
"""ScriptModule implementation to collect per-channel L1 normalization stats
for weight normalization-based quantization."""
__constants__ = ['stats_reduce_dim']

def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
super(L1Norm, self).__init__()
self.stats_reduce_dim = stats_reduce_dim

@brevitas.jit.script_method
def forward(self, x: Tensor):
if self.stats_reduce_dim is None:
# Need to be able to return the max per-channel L1 norm as a scalar
raise NotImplementedError("L1 normalization is not supported per-tensor yet.")
else:
return x.norm(p=1, dim=self.stats_reduce_dim, keepdim=True)


class L2Norm(brevitas.jit.ScriptModule):
"""ScriptModule implementation to collect per-channel L2 normalization stats
for weight normalization-based quantization."""
__constants__ = ['stats_reduce_dim']

def __init__(self, stats_reduce_dim: Optional[int] = None) -> None:
super(L2Norm, self).__init__()
self.stats_reduce_dim = stats_reduce_dim

@brevitas.jit.script_method
def forward(self, x: Tensor):
if self.stats_reduce_dim is None:
# Need to be able to return the max per-channel L2 norm as a scalar
raise NotImplementedError("L2 normalization is not supported per-tensor yet.")
else:
return x.norm(p=2, dim=self.stats_reduce_dim, keepdim=True)
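
As a quick illustration (assumptions: a 4x3x3x3 convolution weight and the (out_channels, -1) view that OverOutputChannelView produces upstream of the stats call), the new L2Norm module reduces over dim 1 and yields one norm per output channel:

import torch
from brevitas.core.stats import L2Norm

w = torch.randn(4, 3, 3, 3)                      # conv weight laid out as (out_ch, in_ch, kH, kW)
l2 = L2Norm(stats_reduce_dim=1)                  # reduce over everything except the output channel
per_channel_norm = l2(w.view(w.shape[0], -1))    # shape (4, 1): one L2 norm per output channel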
11 changes: 11 additions & 0 deletions src/brevitas/core/utils.py
@@ -62,3 +62,14 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
destination=destination, prefix=prefix, keep_vars=keep_vars)
del output_dict[prefix + VALUE_ATTR_NAME]
return output_dict


class SingleArgStatelessBuffer(brevitas.jit.ScriptModule):

def __init__(self, value: torch.Tensor):
super(SingleArgStatelessBuffer, self).__init__()
self.const = StatelessBuffer(torch.tensor(value))

@brevitas.jit.script_method
def forward(self, placeholder):
return self.const()
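
SingleArgStatelessBuffer wraps a constant in a StatelessBuffer and ignores its single argument, so it can stand in for injected modules that are called with a runtime value; in the new quantizer below it replaces int_scaling_impl, which the rescaling quantizer calls with the bit width, pinning the integer scaling to 1. A minimal sketch of the behaviour (the bit-width argument here is only a placeholder):

import torch
from brevitas.core.utils import SingleArgStatelessBuffer

one = SingleArgStatelessBuffer(1.)
bit_width = torch.tensor(8.)           # placeholder argument; it is ignored
assert one(bit_width) == 1.            # the stored constant is returned regardless of the input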
54 changes: 53 additions & 1 deletion src/brevitas/quant/base.py
@@ -7,21 +7,26 @@

from brevitas.core.bit_width import BitWidthConst
from brevitas.core.function_wrapper import OverOutputChannelView
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.function_wrapper import TensorClampSte
from brevitas.core.function_wrapper.ops_ste import CeilSte
from brevitas.core.quant import ClampedBinaryQuant
from brevitas.core.quant.int import DecoupledRescalingIntQuant
from brevitas.core.quant.int_base import DecoupledIntQuant
from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import LogFloatRestrictValue
from brevitas.core.scaling import IntScaling
from brevitas.core.scaling import ParameterPreScalingWeightNorm
from brevitas.core.scaling import ParameterScaling
from brevitas.core.scaling import SCALAR_SHAPE
from brevitas.core.scaling import SCALING_STATS_REDUCE_DIM
from brevitas.core.scaling import StatsFromParameterScaling
from brevitas.core.stats import AbsMax
from brevitas.core.stats import AbsMaxL2
from brevitas.core.stats import L2Norm
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.utils import SingleArgStatelessBuffer
from brevitas.core.zero_point import ParameterFromRuntimeZeroPoint
from brevitas.core.zero_point import StatsFromParameterZeroPoint
from brevitas.core.zero_point import ZeroZeroPoint
@@ -56,7 +61,8 @@
'IntTrunc',
'SignedBinaryClampedConst',
'WeightPerTensorFloatDecoupledL2Param',
'WeightPerChannelFloatDecoupled'
'WeightPerChannelFloatDecoupled',
'WeightNormPerChannelFloatDecoupled',
]


@@ -291,3 +297,49 @@ def scaling_init(scaling_init_impl):
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_per_output_channel = True


class WeightNormPerChannelFloatDecoupled(
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
SolveParameterScalingShape):
"""Experimental narrow per-channel weight normalization-based signed integer quantizer
based on `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed
Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig.
The formulation for weight normalization-based quantization is given below:
`y = clip(round( (g / s) * (w / norm(w)) )) * s`
The default quantizer uses the decoupled rescaling integer quantization arithmetic
where the weight normalization calculation and parameterization are combined with the
scaling factor to become the pre-clipping scaling factor (i.e., `pre_scale`) and the
scaling factor is the post-clipping scaling factor (i.e., `post_scale`). For further
details on the arithmetic, see `ParameterPreScalingWeightNorm`. For further details
on the weight normalization-based quantization technique, see the referenced paper."""

@value
def scaling_init(scaling_init_impl):
return scaling_init_impl()

proxy_class = DecoupledWeightQuantProxyFromInjector
tensor_quant = DecoupledRescalingIntQuant
decoupled_int_quant = DecoupledIntQuant
tensor_clamp_impl = TensorClamp
scaling_impl = ParameterScaling
restrict_scaling_impl = FloatRestrictValue
scaling_stats_impl = AbsMax
scaling_init_impl = ParameterFromStatsScalingInit
parameter_stats_scaling_init_impl = StatsFromParameterScaling
pre_scaling_impl = ParameterPreScalingWeightNorm
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = L2Norm
pre_scaling_shape = this.scaling_shape # TODO: decouple pre_scaling_shape from scaling_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
pre_zero_point_impl = ZeroZeroPoint
bit_width_impl = BitWidthConst
narrow_range = True
signed = True
scaling_stats_input_view_shape_impl = OverOutputChannelView
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_per_output_channel = True
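
Concrete quantizers are derived from this injector by fixing the remaining attributes, as Int8WeightNormL2PerChannelFixedPoint does below with bit_width = 8. As a hypothetical illustration (not part of this commit, untested here), a lower-precision variant would only need to override that attribute:

from brevitas.quant.base import WeightNormPerChannelFloatDecoupled

# Hypothetical 4-bit variant, mirroring how the 8-bit quantizer added in
# fixed_point.py only fixes bit_width on top of the shared injector.
class Int4WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled):
    bit_width = 4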
17 changes: 17 additions & 0 deletions src/brevitas/quant/fixed_point.py
@@ -103,3 +103,20 @@ class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Para
restrict_scaling_impl = PowerOfTwoRestrictValue
int_scaling_impl = PowerOfTwoIntScaling
restrict_value_float_to_int_impl = CeilSte


class Int8WeightNormL2PerChannelFixedPoint(WeightNormPerChannelFloatDecoupled):
"""
Experimental 8-bit narrow signed integer quantizer with learned per-channel scaling factors
and L2 weight normalization based on `Quantized Neural Networks for Low-Precision Accumulation
with Guaranteed Overflow Avoidance` by I. Colbert, A. Pappalardo, and J. Petri-Koenig
(https://arxiv.org/abs/2301.13376). The quantizer learns scaling factors in the float domain and
learns vector parameter g in the log domain with the half-way rounding function. Suitable for
retraining from floating-point depthwise separable weights.
Examples:
>>> from brevitas.nn import QuantConv2d
>>> conv = QuantConv2d(4, 4, 3, groups=4, weight_quant=Int8WeightNormL2PerChannelFixedPoint)
>>> conv.quant_weight()
"""
bit_width = 8
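
Beyond the doctest above, a hedged end-to-end sketch (layer sizes are arbitrary; QuantTensor's .int() accessor is used here to inspect the integer weight representation):

import torch
from brevitas.nn import QuantLinear
from brevitas.quant.fixed_point import Int8WeightNormL2PerChannelFixedPoint

fc = QuantLinear(16, 8, bias=True, weight_quant=Int8WeightNormL2PerChannelFixedPoint)
q_weight = fc.quant_weight()                 # QuantTensor with per-channel scales
int_weight = q_weight.int()                  # integer values; narrow signed 8-bit => within [-127, 127]
assert int_weight.abs().max() <= 127
y = fc(torch.randn(2, 16))                   # forward pass uses the quantized weights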
39 changes: 28 additions & 11 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -16,6 +16,7 @@
from brevitas.nn import QuantLinear
from brevitas.nn.quant_rnn import QuantLSTM
from brevitas.nn.quant_rnn import QuantRNN
from brevitas.quant.fixed_point import Int8WeightNormL2PerChannelFixedPoint
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
@@ -31,10 +32,18 @@
FEATURES = 5
KERNEL_SIZE = 3

WEIGHT_QUANTIZER = {
LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,}
'quant_asym': ShiftedUint8WeightPerTensorFloat,
}

WBIOL_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
}

IO_QUANTIZER = {
'None': None,
@@ -44,25 +53,33 @@
SIGNED_ACT_QUANTIZER = {
'None': None,
'quant_sym': Int8ActPerTensorFloat,
'quant_asym': ShiftedUint8ActPerTensorFloat,}
'quant_asym': ShiftedUint8ActPerTensorFloat,
}

UNSIGNED_ACT_QUANTIZER = {
'None': None,
'quant_sym': Uint8ActPerTensorFloat,}
'quant_sym': Uint8ActPerTensorFloat,
}

BIAS_QUANTIZER = {
'None': None,
'quant_external': Int16Bias,
'quant_internal': Int8BiasPerTensorFloatInternalScaling}
'quant_internal': Int8BiasPerTensorFloatInternalScaling,
}

QUANT_WBIOL_IMPL = [
QuantLinear, QuantConv1d, QuantConv2d, QuantConvTranspose1d, QuantConvTranspose2d]
QuantLinear,
QuantConv1d,
QuantConv2d,
QuantConvTranspose1d,
QuantConvTranspose2d,
]


@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WBIOL_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WBIOL_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('module', QUANT_WBIOL_IMPL, ids=[f'model_type${c.__name__}' for c in QUANT_WBIOL_IMPL])
@pytest_cases.parametrize('is_training', [True, False], ids=[f'is_training${f}' for f in [True, False]])
@@ -121,7 +138,7 @@ def forward(self, x):
@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('bidirectional', [True, False, 'shared_input_hidden'], ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']])
@pytest_cases.parametrize('cifg', [True, False])
@@ -187,7 +204,7 @@ def forward(self, x):
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('signed_act_quantizer', SIGNED_ACT_QUANTIZER.items(), ids=[f'signed_act${c}' for c, _ in SIGNED_ACT_QUANTIZER.items()])
@pytest_cases.parametrize('unsigned_act_quantizer', UNSIGNED_ACT_QUANTIZER.items(), ids=[f'unsigned_act${c}' for c, _ in UNSIGNED_ACT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('bidirectional', [True, False, 'shared_input_hidden'], ids=[f'bidirectional${f}' for f in [True, False, 'shared_input_hidden']])
@pytest_cases.parametrize('cifg', [True, False])
@@ -253,7 +270,7 @@ def forward(self, x):
@pytest_cases.parametrize('io_quantizer', IO_QUANTIZER.items(), ids=[f'io_quant${c}' for c, _ in IO_QUANTIZER.items()])
@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('bidirectional', [True, False], ids=[f'bidirectional${f}' for f in [True, False]])
@pytest_cases.parametrize('num_layers', [1, 2], ids=[f'num_layers${f}' for f in [1, 2]])
@@ -307,7 +324,7 @@ def forward(self, x):
@pytest_cases.parametrize('input_quantized', [True, False], ids=[f'input_quantized${c}' for c in [True, False]])
@pytest_cases.parametrize('bias_quantizer', BIAS_QUANTIZER.items(), ids=[f'bias_quant${c}' for c, _ in BIAS_QUANTIZER.items()])
@pytest_cases.parametrize('signed_act_quantizer', SIGNED_ACT_QUANTIZER.items(), ids=[f'signed_act${c}' for c, _ in SIGNED_ACT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('weight_quantizer', LSTM_WEIGHT_QUANTIZER.items(), ids=[f'weight_quant${c}' for c, _ in LSTM_WEIGHT_QUANTIZER.items()])
@pytest_cases.parametrize('return_quant_tensor', [True, False], ids=[f'return_quant_tensor${f}' for f in [True, False]])
@pytest_cases.parametrize('bidirectional', [True, False], ids=[f'bidirectional${f}' for f in [True, False]])
@pytest_cases.parametrize('num_layers', [1, 2], ids=[f'num_layers${f}' for f in [1, 2]])
