Feat (core)!: deprecate bitwidth-less bias (#839)
Breaking change: all bias quantizers now require their bit width to be specified.
Giuseppe5 authored Feb 28, 2024
1 parent 6079b12 commit 8d95bbd
Showing 16 changed files with 48 additions and 85 deletions.
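The practical migration is mechanical: any layer configured with a bias quantizer that used to infer its bit width from the accumulator (IntBias being the main one) must move to a quantizer that declares bit_width itself, which is exactly the IntBias → Int32Bias swap this commit applies to the example models. A minimal sketch with a toy layer (sizes arbitrary):

    from brevitas.nn import QuantLinear
    from brevitas.quant import Int32Bias

    # Before: bias_quant=IntBias picked up its bit width from the (now removed)
    # input_bit_width argument threaded through the layer's forward pass.
    # After: the bit width is a property of the quantizer itself (here, 32).
    fc = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)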
5 changes: 1 addition & 4 deletions src/brevitas/nn/mixin/base.py
@@ -242,16 +242,13 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
     @staticmethod
     def gate_params_fwd(gate, quant_input):
         acc_scale = None
-        acc_bit_width = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor):
-            acc_bit_width = None  # TODO
         if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
             acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
             acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
             acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width)
+        quant_bias = gate.bias_quant(gate.bias, acc_scale)
         return quant_weight_ih, quant_weight_hh, quant_bias
 
     def reset_parameters(self) -> None:
10 changes: 5 additions & 5 deletions src/brevitas/nn/mixin/parameter.py
@@ -176,14 +176,13 @@ def quant_bias(self):
         if self.bias is None:
             return None
         scale = self.quant_bias_scale()
-        bit_width = self.quant_bias_bit_width()
-        quant_bias = self.bias_quant(self.bias, scale, bit_width)
+        quant_bias = self.bias_quant(self.bias, scale)
         return quant_bias
 
     def quant_bias_scale(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).scale
         else:
             if self._cached_bias is None:
@@ -197,7 +196,8 @@ def quant_bias_zero_point(self):
         if self.bias is None:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+
+        if not self.bias_quant.requires_input_scale:
             bias_quant = self.bias_quant(self.bias)
             if isinstance(bias_quant, QuantTensor):
                 return bias_quant.zero_point
@@ -215,7 +215,7 @@ def quant_bias_bit_width(self):
         if self.bias is None or not self.is_bias_quant_enabled:
             return None
-        if not self.bias_quant.requires_input_scale and not self.bias_quant.requires_input_bit_width:
+        if not self.bias_quant.requires_input_scale:
             return self.bias_quant(self.bias).bit_width
         else:
             if self._cached_bias is None:
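What the simplified predicate buys, sketched below: when the quantizer scales the bias internally, scale, zero-point and bit width can all be computed from the bias alone, with no cached forward pass needed. Int8BiasPerTensorFloatInternalScaling is used purely for illustration; the printed value is indicative.

    from brevitas.nn import QuantLinear
    from brevitas.quant import Int8BiasPerTensorFloatInternalScaling

    fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
    # No input scale is required, so this resolves directly; expected tensor(8.)
    print(fc.quant_bias_bit_width())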
3 changes: 2 additions & 1 deletion src/brevitas/nn/quant_layer.py
@@ -323,9 +323,10 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             output_signed = quant_input.signed or quant_weight.signed
 
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
+            quant_bias = self.bias_quant(self.bias, output_scale)
             if not self.training and self.cache_inference_quant_bias and isinstance(quant_bias,
                                                                                     QuantTensor):
+
                 self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
         output_tensor = self.inner_forward_impl(
             _unpack_quant_tensor(quant_input),
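For orientation, what bias_quant computes at output_scale is plain signed fake-quantization; a standalone sketch, not Brevitas's implementation (which also handles zero-points, pluggable clamping and export handlers):

    import torch

    def fake_quant_bias(bias: torch.Tensor, scale: torch.Tensor, bit_width: int) -> torch.Tensor:
        # Round the bias to integers at the accumulator scale, clamp to the
        # signed range of the declared bit width, then dequantize.
        n = 2 ** (bit_width - 1)
        return torch.clamp(torch.round(bias / scale), -n, n - 1) * scale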
37 changes: 10 additions & 27 deletions src/brevitas/proxy/parameter_quant.py
@@ -33,7 +33,6 @@ def forward(self, x: torch.Tensor) -> QuantTensor:
 
 @runtime_checkable
 class BiasQuantProxyProtocol(QuantProxyProtocol, Protocol):
-    requires_input_bit_width: bool
     requires_input_scale: bool
 
     def forward(
@@ -161,13 +160,6 @@ class BiasQuantProxyFromInjector(ParameterQuantProxyFromInjector, BiasQuantProxyProtocol):
     def tracked_parameter_list(self):
         return [m.bias for m in self.tracked_module_list if m.bias is not None]
 
-    @property
-    def requires_input_bit_width(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_bit_width
-        else:
-            return False
-
     @property
     def requires_input_scale(self) -> bool:
         if self.is_quant_enabled:
@@ -179,42 +171,33 @@ def scale(self):
         if self.requires_input_scale:
             return None
         zhs = self._zero_hw_sentinel()
-        scale = self.__call__(self.tracked_parameter_list[0], zhs, zhs).scale
+        scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
         return scale
 
     def zero_point(self):
         zhs = self._zero_hw_sentinel()
-        zero_point = self.__call__(self.tracked_parameter_list[0], zhs, zhs).zero_point
+        zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
         return zero_point
 
     def bit_width(self):
-        if self.requires_input_bit_width:
-            return None
         zhs = self._zero_hw_sentinel()
-        bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
+        bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width
 
-    def forward(
-            self,
-            x: Tensor,
-            input_scale: Optional[Tensor] = None,
-            input_bit_width: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def forward(self,
+                x: Tensor,
+                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
                 raise RuntimeError("Input scale required")
-            if self.requires_input_bit_width and input_bit_width is None:
-                raise RuntimeError("Input bit-width required")
-            if self.requires_input_scale and self.requires_input_bit_width:
-                input_scale = input_scale.view(-1)
-                out, out_scale, out_zp, out_bit_width = impl(x, input_scale, input_bit_width)
-            elif self.requires_input_scale and not self.requires_input_bit_width:
+
+            if self.requires_input_scale:
                 input_scale = input_scale.view(-1)
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
-            elif not self.requires_input_scale and not self.requires_input_bit_width:
-                out, out_scale, out_zp, out_bit_width = impl(x)
             else:
-                raise RuntimeError("Internally defined bit-width required")
+                out, out_scale, out_zp, out_bit_width = impl(x)
+
             return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
         else:
             return x
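The resulting calling convention, restated as a sketch (fc and acc_scale are placeholders carried over from the examples above): input_scale is mandatory exactly when the injected quantizer declares requires_input_scale, and the bit width always comes from the quantizer.

    # Quantizer with requires_input_scale = True (e.g. Int32Bias):
    quant_bias = fc.bias_quant(fc.bias, input_scale=acc_scale)

    # Internally scaled quantizer: no extra argument at all.
    quant_bias = fc.bias_quant(fc.bias)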
1 change: 0 additions & 1 deletion src/brevitas/quant/fixed_point.py
@@ -178,7 +178,6 @@ class Int8BiasPerTensorFixedPointInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFixedPointInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int4WeightPerTensorFixedPointDecoupled(WeightPerTensorFloatDecoupledL2Param):
1 change: 0 additions & 1 deletion src/brevitas/quant/none.py
@@ -31,7 +31,6 @@ class NoneBiasQuant(BiasQuantSolver):
     """
     quant_type = QuantType.FP
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class NoneTruncQuant(TruncQuantSolver):
6 changes: 0 additions & 6 deletions src/brevitas/quant/scaled_int.py
@@ -85,7 +85,6 @@ class IntBias(IntQuant, BiasQuantSolver):
     """
     tensor_clamp_impl = TensorClamp
     requires_input_scale = True
-    requires_input_bit_width = True
 
 
 class Int8Bias(IntBias):
@@ -98,7 +97,6 @@ class Int8Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8Bias)
     """
     bit_width = 8
-    requires_input_bit_width = False
 
 
 class Int16Bias(IntBias):
@@ -111,7 +109,6 @@ class Int16Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int16Bias)
     """
     bit_width = 16
-    requires_input_bit_width = False
 
 
 class Int24Bias(IntBias):
@@ -124,7 +121,6 @@ class Int24Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int24Bias)
     """
     bit_width = 24
-    requires_input_bit_width = False
 
 
 class Int32Bias(IntBias):
@@ -137,7 +133,6 @@ class Int32Bias(IntBias):
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int32Bias)
     """
     bit_width = 32
-    requires_input_bit_width = False
 
 
 class Int8BiasPerTensorFloatInternalScaling(IntQuant,
@@ -153,7 +148,6 @@ class Int8BiasPerTensorFloatInternalScaling(IntQuant,
     >>> fc = QuantLinear(10, 5, bias=True, bias_quant=Int8BiasPerTensorFloatInternalScaling)
     """
     requires_input_scale = False
-    requires_input_bit_width = False
 
 
 class Int8WeightPerTensorFloat(NarrowIntQuant,
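A bias width Brevitas does not ship can now be obtained with nothing more than an explicit bit_width on a subclass, mirroring the classes above; Int12Bias is a hypothetical name used for illustration.

    from brevitas.quant import IntBias

    class Int12Bias(IntBias):
        # IntBias still requires the input (accumulator) scale; the bit width
        # is now always declared on the quantizer itself.
        bit_width = 12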
22 changes: 4 additions & 18 deletions src/brevitas/quant/solver/bias.py
@@ -1,9 +1,7 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from brevitas.core.function_wrapper import Identity
 from brevitas.core.quant import PrescaledRestrictIntQuant
-from brevitas.core.quant import PrescaledRestrictIntQuantWithInputBitWidth
 from brevitas.core.quant import RescalingIntQuant
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import value
@@ -34,29 +32,17 @@ def scaling_per_output_channel_shape(module):
         return (module.out_channels,)
 
 
-class SolveBiasBitWidthImplFromEnum(ExtendedInjector):
-
-    @value
-    def bit_width_impl(bit_width_impl_type, requires_input_bit_width):
-        if not requires_input_bit_width:
-            return solve_bit_width_impl_from_enum(bit_width_impl_type)
-        else:
-            return Identity
-
-
 class SolveBiasTensorQuantFromEnum(SolveIntQuantFromEnum):
 
     @value
-    def tensor_quant(quant_type, requires_input_bit_width, requires_input_scale):
+    def tensor_quant(quant_type, requires_input_scale):
         if quant_type == QuantType.FP:
             return None
         elif quant_type == QuantType.INT:
-            if not requires_input_bit_width and requires_input_scale:
+            if requires_input_scale:
                 return PrescaledRestrictIntQuant
-            elif not requires_input_bit_width and not requires_input_scale:
+            else:
                 return RescalingIntQuant
-            else:  # requires_input_bit_width == True
-                return PrescaledRestrictIntQuantWithInputBitWidth
         elif quant_type == QuantType.TERNARY:
             raise RuntimeError(f'{quant_type} not supported.')
         elif quant_type == QuantType.BINARY:
@@ -75,7 +61,7 @@ class BiasQuantSolver(SolveScalingStatsInputViewShapeImplFromEnum,
                       SolveParameterScalingImplFromEnum,
                       SolveParameterTensorClampImplFromEnum,
                       SolveParameterScalingInitFromEnum,
-                      SolveBiasBitWidthImplFromEnum,
+                      SolveBitWidthImplFromEnum,
                       SolveBiasScalingPerOutputChannelShapeFromModule,
                       SolveBiasScalingStatsInputConcatDimFromModule,
                       SolveBiasTensorQuantFromEnum,
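The solver's decision for INT bias quantizers is now strictly two-way: quantizers that require the input scale resolve to PrescaledRestrictIntQuant, all others to RescalingIntQuant. This leaves PrescaledRestrictIntQuantWithInputBitWidth without users, which is why its import and the Identity-based SolveBiasBitWidthImplFromEnum solver are dropped above.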
2 changes: 1 addition & 1 deletion src/brevitas_examples/bnn_pynq/README.md
@@ -20,7 +20,7 @@ These pretrained models and training scripts are courtesy of
 | CNV_1W1A | 8 bit | 1 bit | 1 bit | CIFAR10 | 84.22% |
 | CNV_1W2A | 8 bit | 1 bit | 2 bit | CIFAR10 | 87.80% |
 | CNV_2W2A | 8 bit | 2 bit | 2 bit | CIFAR10 | 89.03% |
-| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.60% |
+| RESNET18_4W4A | 8 bit (assumed) | 4 bit | 4 bit | CIFAR10 | 92.61% |
 
 ## Train
 
4 changes: 2 additions & 2 deletions src/brevitas_examples/bnn_pynq/models/resnet.py
@@ -9,7 +9,7 @@
 import brevitas.nn as qnn
 from brevitas.quant import Int8WeightPerChannelFloat
 from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 from brevitas.quant import TruncTo8bit
 from brevitas.quant_tensor import QuantTensor
 
@@ -121,7 +121,7 @@ def __init__(
             act_bit_width=8,
             weight_bit_width=8,
             round_average_pool=False,
-            last_layer_bias_quant=IntBias,
+            last_layer_bias_quant=Int32Bias,
             weight_quant=Int8WeightPerChannelFloat,
             first_layer_weight_quant=Int8WeightPerChannelFloat,
             last_layer_weight_quant=Int8WeightPerTensorFloat):
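The IntBias → Int32Bias swap here (repeated in the example models below) is the canonical migration for this breaking change: bare IntBias no longer carries a bit width, while Int32Bias pins the bias to 32 bits, the accumulator width commonly assumed for 8-bit integer inference.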
@@ -34,7 +34,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import CommonIntWeightPerChannelQuant
 from .common import CommonIntWeightPerTensorQuant
@@ -181,7 +181,7 @@ def __init__(
             in_channels,
             num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_quant=last_layer_weight_quant,
             weight_bit_width=last_layer_bit_width)
@@ -30,7 +30,7 @@
 from brevitas.nn import QuantLinear
 from brevitas.nn import QuantReLU
 from brevitas.nn import TruncAvgPool2d
-from brevitas.quant import IntBias
+from brevitas.quant import Int32Bias
 
 from .common import *
 
@@ -300,7 +300,7 @@ def __init__(
             in_features=in_channels,
             out_features=num_classes,
             bias=True,
-            bias_quant=IntBias,
+            bias_quant=Int32Bias,
             weight_bit_width=bit_width,
             weight_quant=CommonIntWeightPerTensorQuant)
12 changes: 6 additions & 6 deletions src/brevitas_examples/imagenet_classification/qat/README.md
@@ -5,12 +5,12 @@ and by no means a direct mapping to hardware should be assumed.
 
 Below in the table is a list of example pretrained models made available for reference.
 
-| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Top5 | Pretrained model | Retrained from |
-|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
-| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 71.14 | 90.10 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 73.52 | 91.46 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.42 | 92.04 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
-| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 75.01 | 92.33 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| Name | Cfg | Scaling Type | First layer weights | Weights | Activations | Avg pool | Top1 | Pretrained model | Retrained from |
+|--------------|-----------------------|----------------------------|---------------------|---------|-------------|----------|-------|-------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
+| MobileNet V1 | quant_mobilenet_v1_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 70.95 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_mobilenet_v1_4b-r1/quant_mobilenet_v1_4b-0100a667.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 w/ Hadamard classifier | quant_proxylessnas_mobile14_hadamard_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 72.87 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_hadamard_4b-r0/quant_proxylessnas_mobile14_hadamard_4b-4acbfa9f.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b | Floating-point per channel | 8 bit | 4 bit | 4 bit | 4 bit | 74.39 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b-r0/quant_proxylessnas_mobile14_4b-e10882e1.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
+| ProxylessNAS Mobile14 | quant_proxylessnas_mobile14_4b5b | Floating-point per channel | 8 bit | 4 bit, 5 bit | 4 bit, 5 bit | 4 bit | 74.94 | [Download](https://github.com/Xilinx/brevitas/releases/download/quant_proxylessnas_mobile14_4b5b-r0/quant_proxylessnas_mobile14_4b5b-2bdf7f8d.pth) | [link](https://github.com/osmr/imgclsmob/tree/master/pytorch) |
 
 
 To evaluate a pretrained quantized model on ImageNet:
4 changes: 4 additions & 0 deletions src/brevitas_examples/imagenet_classification/utils.py
@@ -5,6 +5,8 @@
 import torchvision.transforms as transforms
 from tqdm import tqdm
 
+from brevitas.quant_tensor import QuantTensor
+
 SEED = 123456
 
 MEAN = [0.485, 0.456, 0.406]
@@ -81,6 +83,8 @@ def print_accuracy(top1, prefix=''):
         images = images.to(dtype)
 
         output = model(images)
+        if isinstance(output, QuantTensor):
+            output = output.value
         # measure accuracy
         acc1, = accuracy(output, target, stable=stable)
         top1.update(acc1[0], images.size(0))
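Where more call sites need the same unwrap, it can be factored into a helper; a small sketch under the same assumption that only the dequantized value is wanted:

    from brevitas.quant_tensor import QuantTensor

    def to_value(x):
        # Return the tensor behind a QuantTensor, or pass a plain tensor through.
        return x.value if isinstance(x, QuantTensor) else x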
5 changes: 0 additions & 5 deletions tests/brevitas/nn/test_hadamard.py
@@ -1,9 +1,4 @@
 import torch
-from torch.nn import Module
-
 from brevitas.nn import HadamardClassifier
-from brevitas.quant import IntBias
-from brevitas.quant_tensor import QuantTensor
-
 OUTPUT_FEATURES = 10
 INPUT_FEATURES = 5