Skip to content

Commit

Permalink
Test (graph eq): test equalization through BN
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 17, 2023
1 parent ca5a87f commit b279a09
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,47 @@
import torch.nn as nn


@pytest_cases.fixture
def bnconv_model():
class BNConvModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.bn = nn.BatchNorm2d(3)
# Simulate statistics gathering
self.bn.running_mean.data = torch.randn_like(self.bn.running_mean)
self.bn.running_var.data = torch.abs(torch.randn_like(self.bn.running_var))
# Simulate learned parameters
self.bn.weight.data = torch.randn_like(self.bn.weight)
self.bn.bias.data = torch.randn_like(self.bn.bias)
self.conv = nn.Conv2d(3, 16, kernel_size=3)
def forward(self, x):
x = self.bn(x)
x = self.conv(x)
return x
return BNConvModel


@pytest_cases.fixture
def convbn_model():
class ConvBNModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 128, kernel_size=3)
self.bn = nn.BatchNorm2d(128)
# Simulate statistics gathering
self.bn.running_mean.data = torch.randn_like(self.bn.running_mean)
self.bn.running_var.data = torch.abs(torch.randn_like(self.bn.running_var))
# Simulate learned parameters
self.bn.weight.data = torch.randn_like(self.bn.weight)
self.bn.bias.data = torch.randn_like(self.bn.bias)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
return ConvBNModel


@pytest_cases.fixture
def residual_model():
class ResidualModel(nn.Module):
Expand Down Expand Up @@ -58,4 +99,5 @@ def forward(self, x):
return x
return ResidualSrcsAndSinkModel

toy_model = fixture_union('toy_model', [ 'residual_model', 'srcsinkconflict_model', 'mul_model'])
toy_model = fixture_union('toy_model', [ 'residual_model', 'srcsinkconflict_model', 'mul_model',
'convbn_model', 'bnconv_model'])

0 comments on commit b279a09

Please sign in to comment.