diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py index c570887cc..47f833b43 100644 --- a/src/brevitas/nn/quant_scale_bias.py +++ b/src/brevitas/nn/quant_scale_bias.py @@ -31,7 +31,10 @@ def __init__(self, num_features: int, bias: bool, runtime_shape=(1, -1, 1, 1)): self.runtime_shape = runtime_shape def forward(self, input): - return input * self.weight.view(self.runtime_shape) + self.bias.view(self.runtime_shape) + out = input * self.weight.view(self.runtime_shape) + if self.bias: + out += self.bias.view(self.runtime_shape) + return out class QuantScaleBias(QuantWBIOL, ScaleBias):