Skip to content

Commit

Permalink
Feat (graph eq): equalize through BN
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 17, 2023
1 parent c5c3c7e commit ca5a87f
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.Linear)
torch.nn.Linear,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d)

_scale_invariant_layers = (
torch.nn.Dropout,
Expand Down Expand Up @@ -79,6 +82,10 @@ def _channel_range(inp):
# correct corner case where where all weights along a channel have the same value
# e.g. when a mean/torch.nn.AvgPool/torch.nn.AdaptiveAvgPool is converted to a depth-wise conv
out = torch.where(out == 0., torch.mean(inp, dim=1), out)

# convert to positive range, in case any of the values are negative,
# highly likely in cases when there is only one value per channel, such as in Batch Norm
out = torch.abs(out)
return out


Expand All @@ -94,6 +101,8 @@ def _get_size(axes):
def _get_input_axis(module):
if isinstance(module, torch.nn.Linear):
return 1
elif isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
return 0
elif isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
if module.groups == 1:
return 1
Expand All @@ -113,7 +122,8 @@ def _get_input_axis(module):


def _get_output_axis(module):
if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
return 0
elif isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranpose3d)):
return 1
Expand Down Expand Up @@ -158,6 +168,9 @@ def _cross_layer_equalization(srcs, sinks):
for module, axis in sink_axes.items():
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)
if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
additive_factor = module.running_mean.data * module.weight.data / torch.sqrt(module.running_var.data + module.eps)
module.bias.data = module.bias.data + additive_factor * (scaling_factors - 1)
module.weight.data = module.weight.data * torch.reshape(scaling_factors, src_broadcast_size)


Expand Down

0 comments on commit ca5a87f

Please sign in to comment.