diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 92f917dfa557..6a267110de86 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -251,7 +251,7 @@ def _impl(inputs, input_types): def _convolution(): def _impl(inputs, input_types): # Use transpose or normal - use_transpose = True if inputs[6] == "1" else False + use_transpose = True if inputs[6] == 1 else False data = inputs[0] weight = inputs[1] @@ -268,6 +268,10 @@ def _impl(inputs, input_types): else: assert "data type {} could not be parsed in conv op" % (type(weight)) + # Transposed convolutions have IOHW layout. + if use_transpose: + weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0] + channels = weight_shape[0] groups = int(inputs[8]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a5557cede031..1878266a6a86 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -448,6 +448,14 @@ def forward(self, *args): input_data=torch.randn((1, 8, 16, 16))) +def test_forward_conv_transpose(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data) + verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data) + + def test_forward_threshold(): torch.set_grad_enabled(False) input_shape = [1, 3] @@ -1050,6 +1058,7 @@ def forward(self, xs): test_forward_maxpool1d() test_forward_hardtanh() test_forward_conv() + test_forward_conv_transpose() test_forward_threshold() test_forward_contiguous() test_forward_batchnorm()