Skip to content

Commit

Permalink
Test (graph_eq): tests for activation equalization
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed May 3, 2023
1 parent 4c19504 commit 7f83831
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 6 deletions.
28 changes: 23 additions & 5 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_
name_to_module[name] = module
for i in range(3):
for region in regions:
scale_factors_region = _cross_layer_equalization([name_to_module[n] for n in region[0]],
[name_to_module[n] for n in region[1]],
merge_bias,
bias_shrinkage,
scale_computation_type)
scale_factors_region = _cross_layer_equalization(
[name_to_module[n] for n in region[0]], [name_to_module[n] for n in region[1]],
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
scale_computation_type=scale_computation_type)
if i == 0:
scale_factors_regions.append(scale_factors_region)
return scale_factors_regions
Expand Down Expand Up @@ -85,9 +85,11 @@ def __init__(self) -> None:
self.bn.weight.data = torch.randn_like(self.bn.weight)
self.bn.bias.data = torch.randn_like(self.bn.bias)
self.conv = nn.Conv2d(3, 16, kernel_size=3)
self.relu = nn.ReLU()

def forward(self, x):
x = self.bn(x)
x = self.relu(x)
x = self.conv(x)
return x

Expand All @@ -109,9 +111,11 @@ def __init__(self) -> None:
self.linear = nn.Linear(3, 24)
self.mha = nn.MultiheadAttention(
24, 3, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first)
self.relu = nn.ReLU()

def forward(self, x):
x = self.linear(x)
x = self.relu(x)
x, _ = self.mha(x, x, x)
return x

Expand All @@ -136,9 +140,11 @@ def __init__(self) -> None:
self.layernorm.bias.data = torch.randn_like(self.layernorm.bias)
self.mha = nn.MultiheadAttention(
3, 3, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first)
self.relu = nn.ReLU()

def forward(self, x):
x = self.layernorm(x)
x = self.relu(x)
x, _ = self.mha(x, x, x)
return x

Expand All @@ -160,9 +166,11 @@ def __init__(self) -> None:
self.mha = nn.MultiheadAttention(
3, 1, 0.1, bias=bias, add_bias_kv=add_bias_kv, batch_first=batch_first)
self.linear = nn.Linear(3, 6)
self.relu = nn.ReLU()

def forward(self, x):
x, _ = self.mha(x, x, x)
x = self.relu(x)
x = self.linear(x)
return x

Expand All @@ -178,9 +186,11 @@ def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3)
self.conv_0 = nn.Conv2d(16, 16, kernel_size=1, groups=16)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
return x

Expand All @@ -202,9 +212,11 @@ def __init__(self) -> None:
# Simulate learned parameters
self.bn.weight.data = torch.randn_like(self.bn.weight)
self.bn.bias.data = torch.randn_like(self.bn.bias)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.bn(x)
return x

Expand All @@ -220,10 +232,12 @@ def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=1)
self.conv_0 = nn.Conv2d(16, 3, kernel_size=1)
self.relu = nn.ReLU()

def forward(self, x):
start = x
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = start + x
return x
Expand All @@ -244,11 +258,13 @@ def __init__(self) -> None:
self.conv_start = nn.Conv2d(3, 3, kernel_size=1)
self.conv = nn.Conv2d(3, 3, kernel_size=1)
self.conv_0 = nn.Conv2d(3, 3, kernel_size=1)
self.relu = nn.ReLU()

def forward(self, x):
start = self.conv_start(x)
x = self.conv_0(start)
x = start + x
x = self.relu(x)
x = self.conv(x)
return x

Expand All @@ -265,11 +281,13 @@ def __init__(self) -> None:
self.conv_1 = nn.Conv2d(3, 3, kernel_size=1)
self.conv_0 = nn.Conv2d(3, 3, kernel_size=1)
self.conv_end = nn.Conv2d(3, 3, kernel_size=1)
self.relu = nn.ReLU()

def forward(self, x):
x_0 = self.conv_0(x)
x_1 = self.conv_1(x)
x = x_0 * x_1
x = self.relu(x)
x = self.conv_end(x)
return x

Expand Down
82 changes: 81 additions & 1 deletion tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from brevitas.fx import symbolic_trace
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _is_supported_module
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module

from .equalization_fixtures import *
Expand Down Expand Up @@ -131,4 +134,81 @@ def test_models(toy_model, merge_bias, request):
assert torch.allclose(expected_out, out, atol=ATOL)
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
assert all([shape != () for shape in shape_scale_regions])
assert any([shape != () for shape in shape_scale_regions])


@pytest.mark.parametrize("merge_bias", [True, False])
def test_act_equalization_models(toy_model, merge_bias, request):
test_id = request.node.callspec.id

if 'mha' in test_id:
in_shape = IN_SIZE_LINEAR
else:
in_shape = IN_SIZE_CONV

model_class = toy_model
model = model_class()
inp = torch.randn(in_shape)

model.eval()
expected_out = model(inp)
model = symbolic_trace(model)

with activation_equalization_mode(model, 0.5, True) as aem:
regions = aem.graph_act_eq.regions
model(inp)
scale_factor_regions = aem.scale_factors
shape_scale_regions = [scale.shape for scale in scale_factor_regions]

out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

# This region is made up of a residual branch, so no regions are found for act equalization
if 'srcsinkconflict_mode' not in test_id:
assert len(regions) > 0
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
assert any([shape != () for shape in shape_scale_regions])


@pytest_cases.parametrize(
"model_dict", [(model_name, coverage) for model_name, coverage in MODELS.items()],
ids=[model_name for model_name, _ in MODELS.items()])
@pytest.mark.parametrize("merge_bias", [True, False])
def test_act_equalization_torchvision_models(model_dict: dict, merge_bias: bool):
model, coverage = model_dict

if model == 'googlenet' and torch_version == version.parse('1.8.1'):
pytest.skip(
'Skip because of PyTorch error = AttributeError: \'function\' object has no attribute \'GoogLeNetOutputs\' '
)
if 'vit' in model and torch_version < version.parse('1.13'):
pytest.skip(
f'ViT supported from torch version 1.13, current torch version is {torch_version}')

try:
model = getattr(models, model)(pretrained=True, transform_input=False)
except TypeError:
model = getattr(models, model)(pretrained=True)

torch.manual_seed(SEED)
inp = torch.randn(IN_SIZE_CONV)
model.eval()

model = symbolic_trace(model)
model = TorchFunctionalToModule().apply(model)
model = DuplicateSharedStatelessModule().apply(model)
expected_out = model(inp)

with activation_equalization_mode(model, 0.5, True) as aem:
model(inp)

scale_factor_regions = aem.scale_factors
shape_scale_regions = [scale.shape for scale in scale_factor_regions]

out = model(inp)

assert torch.allclose(expected_out, out, atol=ATOL)
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
assert any([shape != () for shape in shape_scale_regions])

0 comments on commit 7f83831

Please sign in to comment.