From 69a3543c119ceef19deb5bb7cff911435d346e81 Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Thu, 18 Jul 2024 13:21:37 +0200 Subject: [PATCH] Add test --- tests/post_training/test_templates/helpers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/post_training/test_templates/helpers.py b/tests/post_training/test_templates/helpers.py index da969214914..c25d138fd8c 100644 --- a/tests/post_training/test_templates/helpers.py +++ b/tests/post_training/test_templates/helpers.py @@ -184,7 +184,9 @@ def __init__(self): self.conv_2 = self._build_conv(2, 3, 2) self.conv_3 = self._build_conv(1, 2, 3) self.conv_4 = self._build_conv(2, 3, 1) - self.conv_5 = self._build_conv(3, 2, 2) + self.conv_5 = self._build_conv(3, 2, 1) + self.max_pool = torch.nn.MaxPool2d((2, 2)) + self.conv_6 = self._build_conv(2, 3, 1) def _build_conv(self, in_channels=1, out_channels=2, kernel_size=2): conv = create_conv(in_channels, out_channels, kernel_size) @@ -198,7 +200,9 @@ def forward(self, x): x_2 = self.conv_3(x) x_2 = self.conv_4(F.relu(x_2)) x_1_2 = torch.concat([x_1, x_2]) - return self.conv_5(F.relu(x_1_2)) + x = self.conv_5(F.relu(x_1_2)) + x = self.max_pool(x) + return self.conv_6(x) class LinearMultiShapeModel(nn.Module):