Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 19, 2020
1 parent c5b45ee commit 021b85a
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,28 +702,34 @@ def forward(self, x):


def test_adaptive_pool3d():
inp = torch.rand((1, 32, 16, 16, 16))
verify_model(torch.nn.AdaptiveMaxPool3d((1, 1, 1)).eval(), inp)
verify_model(torch.nn.AdaptiveMaxPool3d((2, 2, 2)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((1, 1, 1)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((2, 2, 2)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((4, 8, 8)).eval(), inp)
verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
for ishape in [(1, 32, 16, 16, 16),
(1, 32, 9, 15, 15),
(1, 32, 13, 7, 7)]:
inp = torch.rand(ishape)
verify_model(torch.nn.AdaptiveMaxPool3d((1, 1, 1)).eval(), inp)
verify_model(torch.nn.AdaptiveMaxPool3d((2, 2, 2)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((1, 1, 1)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((2, 2, 2)).eval(), inp)
verify_model(torch.nn.AdaptiveAvgPool3d((4, 8, 8)).eval(), inp)
verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)


def test_conv3d():
inp = torch.rand((1, 32, 16, 16, 16))
verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3),
padding=(1, 1, 1)).eval(),
inp),
verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5),
padding=(2, 2, 2)).eval(),
inp),
verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(),
inp)
# downsample
verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(),
inp)
for ishape in [(1, 32, 16, 16, 16),
(1, 32, 9, 15, 15),
(1, 32, 13, 7, 7)]:
inp = torch.rand(ishape)
verify_model(torch.nn.Conv3d(32, 16, (3, 3, 3),
padding=(1, 1, 1)).eval(),
inp),
verify_model(torch.nn.Conv3d(32, 16, (5, 5, 5),
padding=(2, 2, 2)).eval(),
inp),
verify_model(torch.nn.Conv3d(32, 16, kernel_size=1).eval(),
inp)
# downsample
verify_model(torch.nn.Conv3d(32, 16, kernel_size=1, stride=2).eval(),
inp)


# Model tests
Expand Down

0 comments on commit 021b85a

Please sign in to comment.