Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
KodiaqQ committed Jul 18, 2024
1 parent 387e666 commit 69a3543
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/post_training/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def __init__(self):
self.conv_2 = self._build_conv(2, 3, 2)
self.conv_3 = self._build_conv(1, 2, 3)
self.conv_4 = self._build_conv(2, 3, 1)
self.conv_5 = self._build_conv(3, 2, 2)
self.conv_5 = self._build_conv(3, 2, 1)
self.max_pool = torch.nn.MaxPool2d((2, 2))
self.conv_6 = self._build_conv(2, 3, 1)

def _build_conv(self, in_channels=1, out_channels=2, kernel_size=2):
conv = create_conv(in_channels, out_channels, kernel_size)
Expand All @@ -198,7 +200,9 @@ def forward(self, x):
x_2 = self.conv_3(x)
x_2 = self.conv_4(F.relu(x_2))
x_1_2 = torch.concat([x_1, x_2])
return self.conv_5(F.relu(x_1_2))
x = self.conv_5(F.relu(x_1_2))
x = self.max_pool(x)
return self.conv_6(x)


class LinearMultiShapeModel(nn.Module):
Expand Down

0 comments on commit 69a3543

Please sign in to comment.