diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 4328f6a02..7d20bc340 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -56,14 +56,14 @@ nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d, - nn.Identity) + nn.Identity, + nn.ReLU, + nn.LeakyReLU) _scale_invariant_op = (torch.mul, operator.mul, operator.imul, operator.__mul__, operator.__imul__) _select_op = (operator.getitem, operator.__getitem__) -_scale_invariant_activations = (torch.nn.ReLU,) - _scale_varying_activations = ( torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU) @@ -534,8 +534,7 @@ def _is_supported_module(graph_model: GraphModule, node: Node) -> bool: def _is_scale_invariant_module(graph_model: GraphModule, node: Node) -> bool: return node.op == 'call_module' and isinstance( - get_module(graph_model, node.target), - _scale_invariant_layers + _scale_invariant_activations) + get_module(graph_model, node.target), _scale_invariant_layers) def _is_scale_varying_activation(graph_model, node):