From a9bc48652273552949140ced9e6804ed83e63a4d Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 27 Mar 2020 13:21:49 -0700 Subject: [PATCH 1/4] Fixed conv transpose parsing. --- python/tvm/relay/frontend/pytorch.py | 7 ++++- tests/python/frontend/pytorch/test_forward.py | 28 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 92f917dfa557..3a6e15b8e02b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -251,7 +251,7 @@ def _impl(inputs, input_types): def _convolution(): def _impl(inputs, input_types): # Use transpose or normal - use_transpose = True if inputs[6] == "1" else False + use_transpose = True if inputs[6] == 1 else False data = inputs[0] weight = inputs[1] @@ -260,6 +260,7 @@ def _impl(inputs, input_types): padding = inputs[4] dilation = inputs[5] + if isinstance(weight, _expr.Expr): inferred_shape = _infer_shape(weight) weight_shape = [] @@ -268,6 +269,10 @@ def _impl(inputs, input_types): else: assert "data type {} could not be parsed in conv op" % (type(weight)) + # Transposed convolutions have IOHW layout. + if use_transpose: + weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0] + channels = weight_shape[0] groups = int(inputs[8]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a5557cede031..528ea685413d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -448,6 +448,33 @@ def forward(self, *args): input_data=torch.randn((1, 8, 16, 16))) +def test_forward_conv_transpose(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class Conv2D1(Module): + def __init__(self): + super(Conv2D1, self).__init__() + self.conv = torch.nn.ConvTranspose2d(3, 6, 7, bias=True) + self.softmax = torch.nn.Softmax() + + def forward(self, *args): + return self.softmax(self.conv(args[0])) + + class Conv2D2(Module): + def __init__(self): + super(Conv2D2, self).__init__() + self.conv = torch.nn.ConvTranspose2d(3, 12, 3, bias=False) + self.softmax = torch.nn.Softmax() + + def forward(self, *args): + return self.softmax(self.conv(args[0])) + + input_data = torch.rand(input_shape).float() + verify_model(Conv2D1().float().eval(), input_data=input_data) + verify_model(Conv2D2().float().eval(), input_data=input_data) + + def test_forward_threshold(): torch.set_grad_enabled(False) input_shape = [1, 3] @@ -1050,6 +1077,7 @@ def forward(self, xs): test_forward_maxpool1d() test_forward_hardtanh() test_forward_conv() + test_forward_conv_transpose() test_forward_threshold() test_forward_contiguous() test_forward_batchnorm() From c88079066a5ed48d8169a227f54574721ba18cad Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 27 Mar 2020 13:25:34 -0700 Subject: [PATCH 2/4] small format change. --- python/tvm/relay/frontend/pytorch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3a6e15b8e02b..6a267110de86 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -260,7 +260,6 @@ def _impl(inputs, input_types): padding = inputs[4] dilation = inputs[5] - if isinstance(weight, _expr.Expr): inferred_shape = _infer_shape(weight) weight_shape = [] From c9e383b401af782566993cffb0955a817b91e79c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 27 Mar 2020 13:31:53 -0700 Subject: [PATCH 3/4] Chage test module names. --- tests/python/frontend/pytorch/test_forward.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 528ea685413d..129d5f21420a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -452,18 +452,18 @@ def test_forward_conv_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - class Conv2D1(Module): + class ConvTranspose2D1(Module): def __init__(self): - super(Conv2D1, self).__init__() + super(ConvTranspose2D1, self).__init__() self.conv = torch.nn.ConvTranspose2d(3, 6, 7, bias=True) self.softmax = torch.nn.Softmax() def forward(self, *args): return self.softmax(self.conv(args[0])) - class Conv2D2(Module): + class ConvTranspose2D2(Module): def __init__(self): - super(Conv2D2, self).__init__() + super(ConvTranspose2D2, self).__init__() self.conv = torch.nn.ConvTranspose2d(3, 12, 3, bias=False) self.softmax = torch.nn.Softmax() @@ -471,8 +471,8 @@ def forward(self, *args): return self.softmax(self.conv(args[0])) input_data = torch.rand(input_shape).float() - verify_model(Conv2D1().float().eval(), input_data=input_data) - verify_model(Conv2D2().float().eval(), input_data=input_data) + verify_model(ConvTranspose2D1().float().eval(), input_data=input_data) + verify_model(ConvTranspose2D2().float().eval(), input_data=input_data) def test_forward_threshold(): From 43363ed9edb6d05638c24bba8f592f7f6299913d Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 27 Mar 2020 16:15:20 -0700 Subject: [PATCH 4/4] Simplified test syntax. --- tests/python/frontend/pytorch/test_forward.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 129d5f21420a..1878266a6a86 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -451,28 +451,9 @@ def forward(self, *args): def test_forward_conv_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - - class ConvTranspose2D1(Module): - def __init__(self): - super(ConvTranspose2D1, self).__init__() - self.conv = torch.nn.ConvTranspose2d(3, 6, 7, bias=True) - self.softmax = torch.nn.Softmax() - - def forward(self, *args): - return self.softmax(self.conv(args[0])) - - class ConvTranspose2D2(Module): - def __init__(self): - super(ConvTranspose2D2, self).__init__() - self.conv = torch.nn.ConvTranspose2d(3, 12, 3, bias=False) - self.softmax = torch.nn.Softmax() - - def forward(self, *args): - return self.softmax(self.conv(args[0])) - input_data = torch.rand(input_shape).float() - verify_model(ConvTranspose2D1().float().eval(), input_data=input_data) - verify_model(ConvTranspose2D2().float().eval(), input_data=input_data) + verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data) + verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data) def test_forward_threshold():