Commit
Feat (graph_eq): activation equalization
Giuseppe5 committed Jun 15, 2023
1 parent f7c5096 commit 25703a2
Showing 10 changed files with 620 additions and 205 deletions.
511 changes: 401 additions & 110 deletions src/brevitas/graph/equalize.py

Large diffs are not rendered by default.
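src/brevitas/graph/equalize.py is where the feature itself lives. At a high level, activation equalization rescales each channel of an activation by a factor s and folds the inverse scale into the layer that consumes it, so activation and weight ranges are better balanced before quantization. The snippet below is only a conceptual sketch of that idea using a SmoothQuant-style scale rule; the actual criterion and graph-insertion logic implemented in equalize.py are not rendered on this page and may differ.

import torch

def equalization_scales(act_absmax, weight_absmax, alpha=0.5):
    # Per-input-channel scales trading activation range for weight range.
    # SmoothQuant-style rule, used here purely as an illustration.
    return act_absmax.clamp(min=1e-5).pow(alpha) / weight_absmax.clamp(min=1e-5).pow(1 - alpha)

linear = torch.nn.Linear(4, 8, bias=False)
x = torch.randn(16, 4)
s = equalization_scales(x.abs().amax(dim=0), linear.weight.abs().amax(dim=0))
# Scaling the activation down and the weights up by the same per-channel factor
# leaves the layer output unchanged while shrinking the activation outliers:
out_ref = linear(x)
out_eq = torch.nn.functional.linear(x / s, linear.weight * s)
assert torch.allclose(out_ref, out_eq, atol=1e-5)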

3 changes: 2 additions & 1 deletion src/brevitas/graph/quantize.py
@@ -13,6 +13,7 @@
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat
from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize_impl import act_handler
from brevitas.graph.quantize_impl import add_output_quant_handler
from brevitas.graph.quantize_impl import inp_placeholder_handler
from brevitas.graph.quantize_impl import layer_handler
@@ -301,7 +302,7 @@ def quantize(
    graph_model.eval()
    graph_model = inp_placeholder_handler(
        graph_model, input_quantizer=quant_identity_map.get('signed', None))
    graph_model = layer_handler(graph_model, layer_map=quant_act_map, requantize_output=False)
    graph_model = act_handler(graph_model, layer_map=quant_act_map)
    graph_model = add_output_quant_handler(
        graph_model, quant_identity_map, quant_act_map, unsigned_act_tuple)
    graph_model = layer_handler(
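For context, inp_placeholder_handler, act_handler, add_output_quant_handler and layer_handler are internal steps of the graph quantization flow; users normally drive them through the top-level entry points. A rough sketch of that usage follows; calling both functions with default arguments is an assumption here, not a verified invocation.

import torch.nn as nn

from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.quantize import quantize

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
# preprocess_for_quantize is expected to trace and standardize the model; quantize then
# runs the handler pipeline shown above (act_handler for activations, layer_handler for weights).
graph_model = preprocess_for_quantize(model)
quant_model = quantize(graph_model)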
55 changes: 41 additions & 14 deletions src/brevitas/graph/quantize_impl.py
@@ -10,6 +10,7 @@
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleToModuleByInstance
from brevitas.graph.utils import del_module
from brevitas.graph.utils import get_module

ADD_FNS = [torch.add, operator.add, operator.iadd]
@@ -107,9 +108,7 @@ def are_inputs_quantized_and_aligned(model, node, quantized_modules_list, quant_
    for inp_node in node.all_input_nodes:
        if inp_node.op == 'call_module':
            inp_module = get_module(model, inp_node.target)
            if isinstance(inp_module, tuple(quant_act_map.keys())):
                quantized_modules_list.append(None)
            elif isinstance(inp_module, tuple(PRECISION_PRESERVING_MODULES)) and (
            if isinstance(inp_module, tuple(PRECISION_PRESERVING_MODULES)) and (
                    not same_sign or
                    (same_sign and isinstance(inp_module, tuple(SIGN_PRESERVING_MODULES)))):
                are_inputs_quantized_and_aligned(
@@ -156,15 +155,15 @@ def output_quant_handler(
        rewriters,
        is_sign_preserving,
        quant_identity_map,
        quant_act_map=None,
        quant_act_map,
        unsigned_act_tuple=None):
    """
    Starting from `node`, check whether any of its users requires requantization (i.e., it does
    not have an act_quant attribute). In that case, the function adds a requantization step to
    which all the branches are connected. If another branch has its own requantization step,
    there will be two consecutive requantization steps for that branch.
    """
    if is_sign_preserving and (quant_act_map is None or unsigned_act_tuple is None):
    if is_sign_preserving and unsigned_act_tuple is None:
        raise RuntimeError("Missing information for output_quant_handler")
    quant_module = None
    quant_module_name = None
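The requantization described in the docstring is ultimately a graph rewrite: a single quant identity is inserted after `node` and every branch is rerouted through it. Below is a minimal torch.fx sketch of that rewiring, with a plain Identity standing in for the quant identity module, so it illustrates the mechanics rather than output_quant_handler itself.

import torch
import torch.fx as fx

class ToyBranches(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x):
        y = self.conv(x)
        return y + 1, y * 2  # two branches consume the same output

gm = fx.symbolic_trace(ToyBranches())
conv_node = next(n for n in gm.graph.nodes if n.target == 'conv')
gm.add_submodule('requant', torch.nn.Identity())  # stand-in for the quant identity
with gm.graph.inserting_after(conv_node):
    requant_node = gm.graph.call_module('requant', args=(conv_node,))
conv_node.replace_all_uses_with(requant_node)
requant_node.args = (conv_node,)  # undo the self-reference introduced by the line above
gm.recompile()
out_a, out_b = gm(torch.randn(1, 3, 4, 4))  # both branches now read the requantized value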
@@ -366,18 +365,47 @@ def add_output_quant_handler(model, quant_identity_map, quant_act_map, unsigned_
    return model


def act_handler(model, layer_map):
    for node in model.graph.nodes:
        rewriters = []
        if node.op == 'call_module':
            module = get_module(model, node.target)
            if isinstance(module, tuple(layer_map.keys())):
                if layer_map[type(module)] is not None:
                    quant_module_class, quant_module_kwargs = layer_map[type(module)]
                    quant_module = quant_module_class(**quant_module_kwargs)
                    # Check for activation equalization mul nodes
                    if len(node.users) == 1:
                        user_node = list(node.users.keys())[0]
                        if user_node.name.endswith('act_eq_mul'):
                            # We update activation_impl so that the mul node is executed before quantization
                            act_module = quant_module.act_quant.fused_activation_quant_proxy.activation_impl
                            mul_module = get_module(model, user_node.target)
                            quant_module.act_quant.fused_activation_quant_proxy.activation_impl = torch.nn.Sequential(
                                *[act_module, mul_module])
                            # The mul node added during equalization is removed
                            user_node.replace_all_uses_with(node)
                            model.graph.erase_node(user_node)
                            del_module(model, user_node.target)
                    rewriter = ModuleInstanceToModuleInstance(module, quant_module)
                    rewriters.append(rewriter)
        for rewriter in rewriters:
            model = rewriter.apply(model)
    return model
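The act_eq_mul branch above is the glue between graph equalization and quantization: the scaling mul that equalization left after the activation is folded into the quant activation's activation_impl, so the scale is applied before the quantizer observes the tensor, and the now-redundant mul node is erased from the graph. A toy sketch of the reordering follows; the module names and the plain Sequential are stand-ins for the fused_activation_quant_proxy internals.

import torch

class ScaleMul(torch.nn.Module):
    # Stand-in for the mul module inserted by activation equalization.

    def __init__(self, scale):
        super().__init__()
        self.register_buffer('scale', scale)

    def forward(self, x):
        return x * self.scale

act_impl = torch.nn.ReLU()                       # original activation_impl
mul_module = ScaleMul(torch.tensor([0.5, 2.0]))  # graph node named like '*act_eq_mul'
fused_impl = torch.nn.Sequential(act_impl, mul_module)

x = torch.randn(4, 2)
# Same result as the original graph order (activation, then mul), but the scaling now
# happens inside the quant activation, ahead of the quantization step.
assert torch.allclose(fused_impl(x), mul_module(act_impl(x)))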


def layer_handler(
        model,
        layer_map,
        requantize_output,
        quant_identity_map=dict(),
        quant_act_map=dict(),
        unsigned_act_tuple=dict()):
        model,
        layer_map,
        requantize_output,
        quant_identity_map=None,
        quant_act_map=None,
        unsigned_act_tuple=None):
    """
    Replace FP weight layers with their corresponding quantized version
    """
    if requantize_output and (len(quant_identity_map) == 0 or len(quant_act_map) == 0 or
                              len(unsigned_act_tuple) == 0):
    if requantize_output and (quant_identity_map is None or quant_act_map is None or
                              unsigned_act_tuple is None):
        raise RuntimeError("Missing information to requantize output")
    for node in model.graph.nodes:
        rewriters = []
@@ -409,7 +437,6 @@
                    quant_module_class, quant_module_kwargs = layer_map[type(module)]
                    # Quantize the input if it is not quantized, input_quant is not specified,
                    # and the quant_identity_map is provided.
                    # The last requirement is needed to avoid requantizing the input to activations
                    if not are_inputs_quantized_and_aligned(
                            model, node, [], quant_act_map, same_sign=False
                    ) and not 'input_quant' in quant_module_kwargs and len(quant_identity_map) > 0:
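Both act_handler and layer_handler are driven by a layer_map from float module types to (quant class, constructor kwargs) pairs, with None meaning "leave unquantized"; the FLEXML maps in the next file are the concrete instances used by the FlexML target. A hypothetical minimal pair of maps could look like the following; the quantizer choices are illustrative, not the project defaults.

import torch.nn as nn

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFixedPoint
from brevitas.quant import Uint8ActPerTensorFixedPoint

# Weight layers, consumed by layer_handler.
COMPUTE_LAYER_MAP = {
    nn.Conv2d: (
        qnn.QuantConv2d, {
            'weight_quant': Int8WeightPerTensorFixedPoint, 'return_quant_tensor': True}),
    nn.Linear: (
        qnn.QuantLinear, {
            'weight_quant': Int8WeightPerTensorFixedPoint, 'return_quant_tensor': True})}

# Activations, consumed by act_handler; None leaves the module as float.
QUANT_ACT_MAP = {
    nn.ReLU: (
        qnn.QuantReLU, {
            'act_quant': Uint8ActPerTensorFixedPoint, 'return_quant_tensor': True}),
    nn.Tanh: None}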
22 changes: 1 addition & 21 deletions src/brevitas/graph/target/flexml.py
@@ -4,19 +4,11 @@
from torch import nn

from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.graph.fixed_point import MoveSplitBatchNormBeforeCat
from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.quantize import quantize
from brevitas.graph.quantize import UNSIGNED_ACT_TUPLE
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d
from brevitas.graph.standardize import RemoveStochasticModules
from brevitas.graph.standardize import TorchFunctionalToModule
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
from brevitas.quant import Int8WeightPerTensorFixedPoint
@@ -113,19 +105,7 @@
    nn.Sigmoid: (
        qnn.QuantSigmoid, {
            'act_quant': Uint8ActPerTensorFixedPoint,
            'return_quant_tensor': True,}),
    nn.SiLU: (
        qnn.flexml.FlexMLQuantSwish, {
            'act_quant': Int8ActPerTensorFixedPoint,
            'return_quant_tensor': True,}),
    nn.Hardswish: (
        qnn.flexml.FlexMLQuantHardswish, {
            'act_quant': Int8ActPerTensorFixedPoint,
            'return_quant_tensor': True,}),
    nn.Hardsigmoid: (
        qnn.flexml.FlexMLQuantHardsigmoid, {
            'act_quant': Uint8ActPerTensorFixedPoint,
            'return_quant_tensor': True,})}
            'return_quant_tensor': True,}),}

FLEXML_QUANT_IDENTITY_MAP = {
    'signed':
6 changes: 6 additions & 0 deletions src/brevitas/graph/utils.py
@@ -157,3 +157,9 @@ def get_output_channel_dim(module):

def get_output_channels(module):
    return module.weight.shape[get_output_channel_dim(module)]


def get_node(graph_model, name):
    for node in graph_model.graph.nodes:
        if node.target == name:
            return node
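get_node is a small lookup helper: it returns the fx node whose target matches name, and implicitly None when there is no match. A quick usage sketch:

import torch
import torch.fx as fx

from brevitas.graph.utils import get_node

gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
relu_node = get_node(gm, '1')  # call_module node targeting submodule '1' (the ReLU)
assert relu_node is not None and relu_node.op == 'call_module'
assert get_node(gm, 'does_not_exist') is None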
40 changes: 40 additions & 0 deletions src/brevitas/nn/equalized_layer.py
@@ -0,0 +1,40 @@
import torch

from brevitas.nn.quant_mha import QuantMultiheadAttention


class EqualizedModule(torch.nn.Module):

    def __init__(self, scale_module, layer) -> None:
        super().__init__()
        self.scale = scale_module
        self.layer = layer

    def forward(self, x, *args, **kwargs):
        args = list(args)
        out = x
        if 'key' in kwargs:
            if kwargs['key'].data_ptr() != out.data_ptr():
                raise ValueError(
                    "Cross MHA is not supported for activation equalization. "
                    "Replace kwargs with positional args to avoid this exception.")
        out = self.scale(out)

        pos_inputs = [out]
        # QuantMultiheadAttention is not a subclass of MultiheadAttention
        # We need to preserve the correctness of the forward even after
        # quantization has been applied
        if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)):
            if 'key' not in kwargs:
                pos_inputs.append(out)
                args.pop(0)
            else:
                kwargs['key'] = out
            if 'value' not in kwargs:
                pos_inputs.append(out)
                args.pop(0)
            else:
                kwargs['value'] = out

        out = self.layer(*pos_inputs, *args, **kwargs)
        return out
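A hedged sketch of how EqualizedModule composes the pieces: the scale module runs first, then the wrapped layer, with the extra key/value plumbing only kicking in for (Quant)MultiheadAttention. The ScaleMul below is an illustrative stand-in for whatever scale_module graph equalization actually builds.

import torch

from brevitas.nn.equalized_layer import EqualizedModule

class ScaleMul(torch.nn.Module):
    # Illustrative per-channel scaling module.

    def __init__(self, scale):
        super().__init__()
        self.register_buffer('scale', scale)

    def forward(self, x):
        return x * self.scale

layer = torch.nn.Linear(4, 8)
eq = EqualizedModule(scale_module=ScaleMul(torch.full((4,), 0.5)), layer=layer)

x = torch.randn(2, 4)
# Equivalent to layer(x * 0.5): the input is rescaled before it reaches the layer.
assert torch.allclose(eq(x), layer(x * 0.5))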
5 changes: 4 additions & 1 deletion src/brevitas/nn/quant_scale_bias.py
@@ -31,7 +31,10 @@ def __init__(self, num_features: int, bias: bool, runtime_shape=(1, -1, 1, 1)):
        self.runtime_shape = runtime_shape

    def forward(self, input):
        return input * self.weight.view(self.runtime_shape) + self.bias.view(self.runtime_shape)
        out = input * self.weight.view(self.runtime_shape)
        if self.bias is not None:
            out += self.bias.view(self.runtime_shape)
        return out


class QuantScaleBias(QuantWBIOL, ScaleBias):
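The change above makes the bias term optional in the float ScaleBias forward. A small hedged usage sketch, assuming (as the diff suggests) that bias=False leaves self.bias as None:

import torch

from brevitas.nn.quant_scale_bias import ScaleBias

sb = ScaleBias(num_features=3, bias=False)
x = torch.randn(2, 3, 4, 4)
# With bias disabled the forward reduces to a per-channel scale in NCHW layout.
out = sb(x)
assert torch.allclose(out, x * sb.weight.view(1, -1, 1, 1))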
54 changes: 0 additions & 54 deletions src/brevitas/nn/target/flexml.py
@@ -51,60 +51,6 @@ def is_quant_act_signed(self):
        return self.output_quant.is_quant_act_signed


class FlexMLQuantSwish(QuantNLAL):

    def __init__(
            self,
            act_quant: Optional[ActQuantType] = Int8ActPerTensorFixedPoint,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=nn.SiLU,
            passthrough_act=False,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)


class FlexMLQuantHardsigmoid(QuantNLAL):

    def __init__(
            self,
            act_quant: Optional[ActQuantType] = Uint8ActPerTensorFixedPoint,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=nn.Hardsigmoid,
            passthrough_act=False,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)


class FlexMLQuantHardswish(QuantNLAL):

    def __init__(
            self,
            act_quant: Optional[ActQuantType] = Int8ActPerTensorFixedPoint,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=nn.Hardswish,
            passthrough_act=False,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)


class FlexMLQuantAvgPool2d(QuantLayerMixin, nn.AvgPool2d):

class Int16QuantAvgPoolDivQuant(Int8ActPerTensorFixedPoint):
