This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Compression] Add bias correction feature for PTQ quantizer #5603

Merged
merged 17 commits on Jun 29, 2023
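In broad terms, bias correction for post-training quantization measures the systematic shift that quantization error introduces in a layer's output on calibration data and folds that shift back into the layer's bias, so the quantized layer's mean output tracks the full-precision one. Below is a minimal sketch of the per-channel statistic in plain PyTorch; the names are illustrative and the wrapper changes further down implement the accumulated, multi-batch version of the same idea.

```python
import torch

def bias_correction(fp32_out: torch.Tensor, quant_out: torch.Tensor) -> torch.Tensor:
    """Per-channel mean of (full-precision output - quantized output).

    Both tensors have shape (..., C); the mean runs over every dimension
    except the last, matching the reduction used in the wrapper code below.
    """
    reduce_dims = tuple(range(fp32_out.dim() - 1))
    return (fp32_out - quant_out).mean(dim=reduce_dims)

layer = torch.nn.Linear(16, 8)
x = torch.randn(32, 16)
with torch.no_grad():
    fp32_out = layer(x)
    quant_out = fp32_out + 0.01 * torch.randn_like(fp32_out)  # stand-in for a fake-quantized forward pass
    layer.bias += bias_correction(fp32_out, quant_out)        # fold the measured shift into the bias
```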
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -116,6 +116,7 @@
r'https://docs\.nvidia\.com/deeplearning/',
r'https://cla\.opensource\.microsoft\.com',
r'https://www\.docker\.com/',
r'https://nlp.stanford.edu/projects/glove/',

# remove after 3.0 release
r'https://nni\.readthedocs\.io/en/v2\.10/compression/overview\.html',
88 changes: 74 additions & 14 deletions nni/contrib/compression/base/apply_method.py
@@ -13,28 +13,87 @@ def bypass(target: torch.Tensor, target_space: TargetSpace):


def lsq_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace):
qmax: int = target_space.qmax
qmin: int = target_space.qmin
if target_space._scaler is not None:
scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True) # type: ignore
else:
scale = target_space.scale

class LSQClampRound(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, target: torch.Tensor, scale: torch.Tensor):
ctx.save_for_backward(target, scale)
quantize_target = torch.clamp(target / scale, qmin, qmax)
dequantize_target = torch.round(quantize_target) * scale
return dequantize_target

@staticmethod
def backward(ctx: Any, grad_output: Any) -> Any:
target, scale = ctx.saved_tensors
grad_scale_factor = 1.0 / ((qmax * target.numel()) ** 0.5) if (qmax * target.numel()) ** 0.5 != 0 else 1.0

q_target = target / scale
# compute index
ind_neg = (q_target < qmin).float()
ind_pos = (q_target > qmax).float()
ind_mid = (1.0 - ind_neg.float() - ind_pos.float())
# scale gradient
grad_scale = (ind_neg * qmin + ind_pos * qmax + ind_mid * (-q_target + torch.round(q_target))) * grad_output * grad_scale_factor
if target_space._scaler is None:
grad_scale = grad_scale.sum().expand(scale.size())
grad_target = grad_output * ind_mid

return grad_target, grad_scale

return LSQClampRound.apply(target, scale)


def lsq_plus_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace):
qmax: int = target_space.qmax
qmin: int = target_space.qmin
if target_space._scaler is not None:
scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True) # type: ignore
zero_point = target_space._scaler.expand(target_space.zero_point, target_space.shape, keepdim=True) # type: ignore
else:
scale = target_space.scale
zero_point = target_space.zero_point

class LSQPlusClampRound(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, target: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
ctx.save_for_backward(target, scale, zero_point)
new_target = torch.clamp(target / scale + zero_point, qmin, qmax)
dequantized_target = (torch.round(new_target) - zero_point) * scale

return dequantized_target

@staticmethod
def backward(ctx: Any, grad_output: Any) -> Any:
target, scale, zero_point = ctx.saved_tensors
grad_scale_factor = 1.0 / ((qmax * target.numel()) ** 0.5) if (qmax * target.numel()) ** 0.5 != 0 else 1.0
#
q_target = target / scale
# compute index
ind_neg = ((q_target + zero_point) < qmin).float()
ind_pos = ((q_target + zero_point) > qmax).float()
ind_mid = (1.0 - ind_neg.float() - ind_pos.float())
# scale gradient
grad_scale = (ind_neg * (qmin - zero_point) + ind_pos * (qmax - zero_point) + \
ind_mid * (-q_target - zero_point + \
torch.round(q_target + zero_point))) * grad_output * grad_scale_factor
# zero_point gradient
grad_zp = (ind_neg * -scale + ind_pos * -scale + ind_mid * 0.0) * grad_output
# target gradient
grad_target = grad_output * ind_mid

if target_space._scaler is None:
grad_scale = grad_scale.sum().expand(scale.size())
grad_zp = grad_zp.sum().expand(zero_point.size())

return grad_target, grad_scale, grad_zp

return LSQPlusClampRound.apply(target, scale, zero_point)


class DoferaGradClampRound(torch.autograd.Function):
@@ -196,4 +255,5 @@ def slim_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
'dorefa_clamp_round_output': DoferaGradClampRound.dorefa_clamp_round_output,
"lsq_clamp_round": lsq_clamp_round,
'bnn_clamp_round': BNNClampRound.apply,
'lsq_plus_clamp_round': lsq_plus_clamp_round
}
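For readers unfamiliar with LSQ+: the forward path of `LSQPlusClampRound` above is ordinary affine fake quantization, and only the backward pass is special (the learned scale and zero point receive gradients through the indicator masks). A minimal sketch of the forward arithmetic, independent of the NNI classes, with illustrative values:

```python
import torch

def affine_fake_quant(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
                      qmin: int, qmax: int) -> torch.Tensor:
    # quantize: shift/scale onto the integer grid, clamp to the representable range, round
    q = torch.round(torch.clamp(x / scale + zero_point, qmin, qmax))
    # dequantize: undo the shift and rescale
    return (q - zero_point) * scale

x = torch.randn(4, 8)
y = affine_fake_quant(x, scale=torch.tensor(0.05), zero_point=torch.tensor(3.0), qmin=-128, qmax=127)
```

The `1 / sqrt(qmax * numel)` factor in both backward passes is the gradient scaling recommended in the LSQ paper, which keeps the step-size updates at a magnitude comparable to the weight updates.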
3 changes: 2 additions & 1 deletion nni/contrib/compression/base/target_space.py
@@ -100,7 +100,8 @@ def _register_target(self):
assert hasattr(self._wrapper.module, self._target_name)
target = getattr(self._wrapper.module, self._target_name)
if isinstance(target, torch.nn.parameter.Parameter):
self._wrapper.register_parameter(self._target_name, torch.nn.Parameter(target.detach().clone()))
self._wrapper.register_parameter(self._target_name, torch.nn.Parameter(target.detach().clone(),
requires_grad=target.requires_grad))
elif isinstance(target, torch.Tensor):
self._wrapper.register_buffer(self._target_name, target.detach().clone())
elif target is None:
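One note on the `requires_grad=target.requires_grad` change above: `torch.nn.Parameter` defaults to `requires_grad=True`, so re-registering a copy of a frozen parameter without forwarding the flag would silently make it trainable again. A quick standalone illustration:

```python
import torch

frozen = torch.nn.Parameter(torch.randn(4), requires_grad=False)

copy_default = torch.nn.Parameter(frozen.detach().clone())
copy_preserved = torch.nn.Parameter(frozen.detach().clone(), requires_grad=frozen.requires_grad)

print(copy_default.requires_grad)    # True  -- the frozen flag is lost
print(copy_preserved.requires_grad)  # False -- matches the original
```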
69 changes: 64 additions & 5 deletions nni/contrib/compression/base/wrapper.py
@@ -6,6 +6,7 @@
from collections import defaultdict
import logging
import inspect
import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal

import torch
@@ -65,6 +66,13 @@ def __init__(self, module: torch.nn.Module, module_name: str, config: Dict[str,
self.module_forward = self.module.forward
self.name = module_name
self.config = config if config is not None else {}
# config storage
self.configs = {}
for mode, value in self.config.items():
if mode not in self.configs:
self.configs[mode] = []
self.configs[mode].append(value)

assert all(k in ['pruning', 'quantization', 'distillation'] for k in self.config)

# the arguments' name of self.module.forward
@@ -106,7 +114,7 @@ def register_bias(self):
bias = self.module.bias
self.is_register_bias = True
if isinstance(bias, nn.parameter.Parameter):
self.register_parameter('bias', torch.nn.Parameter(bias.detach().clone()))
self.register_parameter('bias', torch.nn.Parameter(bias.detach().clone(), requires_grad=bias.requires_grad))
delattr(self.module, 'bias')
self.module.register_buffer('bias', bias.data)
elif isinstance(bias, torch.Tensor):
@@ -160,12 +168,14 @@ def unwrap(self):
for target_name, target_space in self.pruning_target_spaces.items():
if target_space.type == TargetType.PARAMETER and isinstance(target_space.target, torch.nn.Parameter):
delattr(self.module, target_name)
self.module.register_parameter(target_name, torch.nn.Parameter(target_space.target.detach().clone()))
self.module.register_parameter(target_name, torch.nn.Parameter(target_space.target.detach().clone(), \
requires_grad=target_space.target.requires_grad))

for target_name, target_space in self.quantization_target_spaces.items():
if target_space.type == TargetType.PARAMETER and isinstance(target_space.target, torch.nn.Parameter):
delattr(self.module, target_name)
self.module.register_parameter(target_name, torch.nn.Parameter(target_space.target.detach().clone()))
self.module.register_parameter(target_name, torch.nn.Parameter(target_space.target.detach().clone(), \
requires_grad=target_space.target.requires_grad))

self.module.forward = self.module_forward
delattr(self.module, '_nni_wrapper')
@@ -174,7 +184,8 @@ def unwrap(self):
delattr(self.module, 'bias')
nni_original_bias = self.bias
if isinstance(nni_original_bias, nn.parameter.Parameter):
self.module.register_parameter('bias', torch.nn.Parameter(nni_original_bias.detach().clone()))
self.module.register_parameter('bias', torch.nn.Parameter(nni_original_bias.detach().clone(), \
requires_grad=nni_original_bias.requires_grad))
elif isinstance(nni_original_bias, torch.Tensor):
self.module.register_buffer('bias', nni_original_bias.detach().clone())
if len(self.fused_modules) > 0 and self.is_bias == 'None' and check_bias(self.module) == 'Tensor':
@@ -389,6 +400,14 @@ def forward(self, *args, **kwargs):
if len(self.fused_modules) > 0:
params_dict, activation_func_lis = fuse_modules(self, params_dict, *args, **kwargs)

# obtain original output
original_outputs = None
if getattr(self, "is_bias_correction", False) and check_bias(self.module) == 'Tensor':
for target_name, original_param in params_dict.items():
setattr(self.module, target_name, original_param * 1.0)

original_outputs = self.module_forward(*args, **kwargs)

params_dict = self.patch_params(params_dict)
for target_name, patched_param in params_dict.items():
# NOTE: here using copy_ will cause `backward through the graph a second time` error, don't know why.
@@ -403,10 +422,49 @@
#fuse activation func
for activation_module in activation_func_lis:
outputs = activation_module._nni_wrapper.module_forward(outputs)
if original_outputs is not None:
original_outputs = activation_module._nni_wrapper.module_forward(original_outputs)

outputs = self.patch_outputs(outputs)

if getattr(self, "is_bias_correction", False) and check_bias(self.module) == 'Tensor':
assert isinstance(original_outputs, torch.Tensor) and isinstance(outputs, torch.Tensor), \
f"Bias correction is only applied to variables with tensor output types, but got {(type(original_outputs), type(outputs))}"
element_num = functools.reduce(lambda x,y: x * y, list(original_outputs.shape[:-1]))
dim_sum = tuple(range(len(original_outputs.shape[:-1])))
bias_correction = torch.sum(original_outputs - outputs, dim=dim_sum)
if not hasattr(self, 'bias_correction'):
setattr(self, 'bias_correction', bias_correction)
else:
self.bias_correction += bias_correction
if not hasattr(self, 'bias_element_num'):
setattr(self, 'bias_element_num', element_num)
else:
self.bias_element_num: int = self.bias_element_num + element_num

torch.cuda.empty_cache()

return outputs

def update_bias(self):
assert hasattr(self, "bias_element_num")
assert check_bias(self.module) == 'Tensor'

bias_correction = getattr(self, "bias_correction", None)
element_num = getattr(self, "bias_element_num", 0)
assert bias_correction is not None
assert element_num > 0

bias_correction /= element_num ## compute mean

if 'bias' in self.quantization_target_spaces:
target_space = self.quantization_target_spaces['bias']
assert target_space.target is not None and \
list(target_space.target.size()) == list(bias_correction.size())
target_space.target.data += bias_correction.detach().clone()
else:
self.module.bias.data += bias_correction.detach().clone()


class IdentityModuleWrapper(ModuleWrapper): # only available for batchnorm
'''
@@ -494,7 +552,8 @@ def create_module_wrapper(model: nn.Module, module: nn.Module, module_name: str,
raise ValueError(f'Using two fused_modules_pair for {module_name} is not supported')
wrapper.unfreeze()
target_spaces = wrapper.extend_target_spaces(config, mode)
wrapper.config = update_config(wrapper.config, {mode: config})
wrapper.configs = update_config(wrapper.configs, {mode: config})

if len(fused_modules_pair) > 0:
wrapper.fused_modules = fused_modules
else:
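Condensing the bias-correction bookkeeping from `forward()` and `update_bias()` above into a single calibration loop may make the flow easier to follow. The sketch below uses a toy linear layer and a crude min-max weight fake-quantizer as stand-ins for the wrapper's patched parameters; none of it is the NNI implementation itself, only a mirror of its data flow.

```python
import torch
import torch.nn.functional as F

layer = torch.nn.Linear(16, 8)

def fake_quant(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    # crude symmetric min-max fake quantization, purely illustrative
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale

bias_correction = torch.zeros(8)
element_num = 0

# Accumulation phase (mirrors the bias-correction branch in forward()): for each
# calibration batch, sum the per-channel output difference and count the reduced elements.
for _ in range(10):
    x = torch.randn(32, 16)
    with torch.no_grad():
        original_out = layer(x)
        quant_out = F.linear(x, fake_quant(layer.weight), layer.bias)
        bias_correction += (original_out - quant_out).sum(dim=0)
        element_num += original_out.shape[0]

# Correction phase (mirrors update_bias()): average, then fold into the bias.
with torch.no_grad():
    layer.bias += bias_correction / element_num
```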
3 changes: 2 additions & 1 deletion nni/contrib/compression/quantization/__init__.py
@@ -6,5 +6,6 @@
from .dorefa_quantizer import DoReFaQuantizer
from .lsq_quantizer import LsqQuantizer
from .ptq_quantizer import PtqQuantizer
from .lsqplus_quantizer import LsqPlusQuantizer

__all__ = ["QATQuantizer", "BNNQuantizer", "DoReFaQuantizer", "LsqQuantizer", "PtqQuantizer"]
__all__ = ["QATQuantizer", "BNNQuantizer", "DoReFaQuantizer", "LsqQuantizer", "PtqQuantizer", "LsqPlusQuantizer"]
7 changes: 4 additions & 3 deletions nni/contrib/compression/quantization/lsq_quantizer.py
@@ -35,7 +35,7 @@ class LsqQuantizer(Quantizer):
A list of dict, each dict configure which module need to be quantized, and how to quantize.
Please refer :doc:`Compression Config Specification </compression/config_list>` for more information.
evaluator
TODO: {evaluator_docstring}
{evaluator_docstring}

Examples
--------
@@ -95,9 +95,10 @@ def mean_reduce_func(converted_target: Tensor) -> torch.Tensor:
if self.is_init or not self.check_target(wrapper, target_name):
return
target_space = wrapper.quantization_target_spaces[target_name]
init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5)
# init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5)
init_target = torch.tensor([0.01]).to(target.device)
if not target_space._scaler:
target_space.scale.data = init_target # type: ignore
target_space.scale.data = init_target # type: ignore
target_space.zero_point = torch.tensor(0.0).to(target.device)
else:
new_target = init_target.expand(target.shape).to(target.device)