Feat (a2q+): adding new super resolution models to brevitas_examples #811

Merged · 2 commits · Jan 26, 2024
2 changes: 2 additions & 0 deletions src/brevitas_examples/super_resolution/README.md
@@ -20,10 +20,12 @@ Note that this is a difference from many academic works that train only on the Y
| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth) | x2 | int8 | (u)int8 | 30.96 |
| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth) | x2 | int8 | (u)int8 | 30.79 |
| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth) | x2 | int8 | (u)int8 | 30.56 |
| [quant_espcn_x2_w8a8_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth) | x2 | int8 | (u)int8 | 31.24 |
||
| [quant_espcn_x2_w4a4_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth) | x2 | int4 | (u)int4 | 30.30 |
| [quant_espcn_x2_w4a4_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth) | x2 | int4 | (u)int4 | 30.27 |
| [quant_espcn_x2_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth) | x2 | int4 | (u)int4 | 30.24 |
| [quant_espcn_x2_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth) | x2 | int4 | (u)int4 | 30.95 |
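
As a usage note (editorial, not part of this diff): any checkpoint in the table above can be fetched with stock `torch.hub` utilities. A minimal sketch, with the URL copied from the new w8a8 A2Q+ row:

```python
import torch

# Fetch the released A2Q+ checkpoint listed in the table above.
url = (
    "https://github.com/Xilinx/brevitas/releases/download/"
    "super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth")
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu", progress=True)
```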


## Train
46 changes: 37 additions & 9 deletions src/brevitas_examples/super_resolution/models/__init__.py
@@ -7,6 +7,7 @@
from torch import hub
import torch.nn as nn

from .common import CommonIntAccumulatorAwareZeroCenterWeightQuant
from .espcn import *

model_impl = {
@@ -43,18 +44,45 @@
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=13)}
acc_bit_width=13),
'quant_espcn_x2_w4a4_a2q_plus_13b':
partial(
quant_espcn,
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=13,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
'quant_espcn_x2_w8a8_a2q_plus_16b':
partial(
quant_espcn,
upscale_factor=2,
weight_bit_width=8,
act_bit_width=8,
acc_bit_width=16,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)}

root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res_r1'
root_url = 'https://github.com/Xilinx/brevitas/releases/download'

model_url = {
'float_espcn_x2': f'{root_url}/float_espcn_x2-2f85a454.pth',
'quant_espcn_x2_w4a4_a2q_13b': f'{root_url}/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth',
'quant_espcn_x2_w4a4_a2q_32b': f'{root_url}/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth',
'quant_espcn_x2_w4a4_base': f'{root_url}/quant_espcn_x2_w4a4_base-80658e6d.pth',
'quant_espcn_x2_w8a8_a2q_16b': f'{root_url}/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth',
'quant_espcn_x2_w8a8_a2q_32b': f'{root_url}/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth',
'quant_espcn_x2_w8a8_base': f'{root_url}/quant_espcn_x2_w8a8_base-f761e4a1.pth'}
'float_espcn_x2':
f'{root_url}/super_res_r1/float_espcn_x2-2f85a454.pth',
'quant_espcn_x2_w4a4_a2q_13b':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth',
'quant_espcn_x2_w4a4_a2q_32b':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth',
'quant_espcn_x2_w4a4_base':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth',
'quant_espcn_x2_w8a8_a2q_16b':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth',
'quant_espcn_x2_w8a8_a2q_32b':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth',
'quant_espcn_x2_w8a8_base':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth',
'quant_espcn_x2_w4a4_a2q_plus_13b':
f'{root_url}/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth',
'quant_espcn_x2_w8a8_a2q_plus_16b':
f'{root_url}/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth'}


def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
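
The body of `get_model_by_name` is collapsed in this view. A minimal sketch of what the registries above imply it does (`model_impl`, `model_url`, and the `hub` import are from this file; the body itself is assumed, not copied from the PR):

```python
def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
    if name not in model_impl:
        raise ValueError(f"Unknown model name: {name}")
    model = model_impl[name]()  # each entry is a pre-configured partial
    if pretrained:
        state_dict = hub.load_state_dict_from_url(
            model_url[name], map_location="cpu", progress=True)
        model.load_state_dict(state_dict)
    return model
```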
10 changes: 8 additions & 2 deletions src/brevitas_examples/super_resolution/models/common.py
@@ -12,6 +12,7 @@
import brevitas.nn as qnn
from brevitas.nn.quant_layer import WeightQuantType
from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloat
@@ -26,9 +27,14 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):


class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
"""A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance"""
restrict_scaling_impl = FloatRestrictValue # backwards compatibility
pre_scaling_min_val = 1e-10
scaling_min_val = 1e-10
bit_width = None


class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
"""A2Q+: Improving Accumulator-Aware Weight Quantization"""
bit_width = None


class CommonIntActQuant(Int8ActPerTensorFloat):
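
To show the new quantizer in context: a hedged sketch of wiring `CommonIntAccumulatorAwareZeroCenterWeightQuant` into a quantized convolution. The `weight_accumulator_bit_width` keyword is assumed to follow the same `weight_`-prefix routing convention used by the existing A2Q layers in these examples:

```python
import brevitas.nn as qnn

from brevitas_examples.super_resolution.models.common import \
    CommonIntAccumulatorAwareZeroCenterWeightQuant

# Sketch only: drop the A2Q+ quantizer in where the original A2Q quantizer was
# used; a 16-bit accumulator constraint matches quant_espcn_x2_w8a8_a2q_plus_16b.
conv = qnn.QuantConv2d(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    padding=1,
    weight_bit_width=8,
    weight_accumulator_bit_width=16,  # assumed kwarg routing, see note above
    weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)
```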
4 changes: 2 additions & 2 deletions src/brevitas_examples/super_resolution/models/espcn.py
@@ -164,15 +164,15 @@ def float_espcn(upscale_factor: int, num_channels: int = 3) -> FloatESPCN:


def quant_espcn(
upcsale_factor: int,
upscale_factor: int,
num_channels: int = 3,
weight_bit_width: int = 8,
act_bit_width: int = 8,
acc_bit_width: int = 32,
weight_quant: WeightQuantType = CommonIntWeightPerChannelQuant) -> QuantESPCN:
""" """
return QuantESPCN(
upscale_factor=upcsale_factor,
upscale_factor=upscale_factor,
num_channels=num_channels,
act_bit_width=act_bit_width,
acc_bit_width=acc_bit_width,
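
For context, calling this factory directly mirrors the `partial(...)` entries registered in `models/__init__.py`; for example, the new w4a4 A2Q+ configuration (import paths assumed from the examples layout):

```python
from brevitas_examples.super_resolution.models.common import \
    CommonIntAccumulatorAwareZeroCenterWeightQuant
from brevitas_examples.super_resolution.models.espcn import quant_espcn

# Equivalent to the 'quant_espcn_x2_w4a4_a2q_plus_13b' registry entry.
model = quant_espcn(
    upscale_factor=2,
    weight_bit_width=4,
    act_bit_width=4,
    acc_bit_width=13,
    weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)
```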
55 changes: 49 additions & 6 deletions src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -5,12 +5,52 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
import brevitas.nn as qnn
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

EPS = 1e-10


def _get_a2q_module(module: nn.Module):
for submod in module.modules():
if isinstance(submod, AccumulatorAwareParameterPreScaling):
return submod
return None


def _calc_a2q_acc_bit_width(
weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool):
"""Using the closed-form bounds on accumulator bit-width as derived in
`A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance`.
This function returns the minimum accumulator bit-width that can be used
without risk of overflow."""
assert weight_max_l1_norm.numel() == 1
input_is_signed = float(input_is_signed)
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width


def _calc_a2q_plus_acc_bit_width(
weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool):
"""Using the closed-form bounds on accumulator bit-width as derived in `A2Q+:
Improving Accumulator-Aware Weight Quantization`. This function returns the
minimum accumulator bit-width that can be used without risk of overflow,
assuming that the floating-point weights are zero-centered."""
input_is_signed = float(input_is_signed)
assert weight_max_l1_norm.numel() == 1
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
input_range = pow(2., input_bit_width) - 1. # 2^N - 1.
min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width


def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
assert isinstance(module, qnn.QuantConv2d), "Error: function only supports QuantConv2d."

@@ -24,12 +64,15 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3))

# using the closed-form bounds on accumulator bit-width
weight_max_l1_norm = quant_weight_per_channel_l1_norm.max()
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = torch.ceil(min_bit_width)
min_bit_width = _calc_a2q_acc_bit_width(
quant_weight_per_channel_l1_norm.max(),
input_bit_width=input_bit_width,
input_is_signed=input_is_signed)
if isinstance(_get_a2q_module(module), AccumulatorAwareZeroCenterParameterPreScaling):
min_bit_width = _calc_a2q_plus_acc_bit_width(
quant_weight_per_channel_l1_norm.max(),
input_bit_width=input_bit_width,
input_is_signed=input_is_signed)
return min_bit_width
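
As a quick numeric sanity check on the two bounds (editorial example with made-up values): for an unsigned 8-bit input and a maximum per-channel weight L1 norm of 200, the A2Q formula gives 17 bits while the zero-centered A2Q+ formula gives 16:

```python
import torch

max_l1_norm = torch.tensor(200.)  # hypothetical max per-channel L1 norm
input_bit_width = torch.tensor(8.)  # unsigned input, so input_is_signed = 0

# A2Q: ceil(alpha + log2(1 + 2^-alpha) + 1), with alpha = log2(||w||_1) + N
alpha = torch.log2(max_l1_norm) + input_bit_width
a2q_bits = torch.ceil(alpha + torch.log2(1. + 2. ** -alpha) + 1.)

# A2Q+: ceil(log2(||w||_1 * (2^N - 1) + 2)), assuming zero-centered weights
a2q_plus_bits = torch.ceil(torch.log2(max_l1_norm * (2. ** input_bit_width - 1.) + 2.))

print(int(a2q_bits), int(a2q_plus_bits))  # 17 16
```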

