Skip to content

Commit

Permalink
[Relay][Frontend][Pytorch] Fixed ConvTranspose2D parsing (#5157)
Browse files Browse the repository at this point in the history
* Fixed conv transpose parsing.

* small format change.

* Chage test module names.

* Simplified test syntax.
  • Loading branch information
jwfromm authored Mar 28, 2020
1 parent dada676 commit 9c80662
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _impl(inputs, input_types):
def _convolution():
def _impl(inputs, input_types):
# Use transpose or normal
use_transpose = True if inputs[6] == "1" else False
use_transpose = True if inputs[6] == 1 else False

data = inputs[0]
weight = inputs[1]
Expand All @@ -268,6 +268,10 @@ def _impl(inputs, input_types):
else:
assert "data type {} could not be parsed in conv op" % (type(weight))

# Transposed convolutions have IOHW layout.
if use_transpose:
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]

channels = weight_shape[0]
groups = int(inputs[8])

Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,14 @@ def forward(self, *args):
input_data=torch.randn((1, 8, 16, 16)))


def test_forward_conv_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)


def test_forward_threshold():
torch.set_grad_enabled(False)
input_shape = [1, 3]
Expand Down Expand Up @@ -1050,6 +1058,7 @@ def forward(self, xs):
test_forward_maxpool1d()
test_forward_hardtanh()
test_forward_conv()
test_forward_conv_transpose()
test_forward_threshold()
test_forward_contiguous()
test_forward_batchnorm()
Expand Down

0 comments on commit 9c80662

Please sign in to comment.