From 4897500c80ae2cf01dbef446c3ea3fd825179f89 Mon Sep 17 00:00:00 2001 From: Tianming Xu Date: Thu, 6 Aug 2020 04:01:27 +0800 Subject: [PATCH] [Relay] pytorch frontend support conv1d (#6203) * [Relay] pytorch frontend support conv1d * add tests for conv1d Co-authored-by: xutianming.xtm --- python/tvm/relay/frontend/pytorch.py | 23 ++++++-- tests/python/frontend/pytorch/test_forward.py | 55 ++++++++++++++++--- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 57b64ace64ae1..3dfdb2f70e7f8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -752,7 +752,7 @@ def _impl(inputs, input_types): # If groups > 1 but weight_shape[1] != 1, this is group convolution if groups > 1 and weight_shape[1] == 1: channel_multiplier = channels // groups - new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3]) + new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:]) weight = _op.transform.reshape(weight, new_weight_shape) kernel_size = weight_shape[2:] @@ -760,12 +760,18 @@ def _impl(inputs, input_types): if isinstance(strides, _expr.Expr): strides = _infer_shape(strides) + if len(kernel_size) == 1: + strides = (1, ) + strides if isinstance(padding, _expr.Expr): padding = _infer_shape(padding) + if len(kernel_size) == 1: + padding = (0, ) + padding if isinstance(dilation, _expr.Expr): dilation = _infer_shape(dilation) + if len(kernel_size) == 1: + dilation = (1, ) + dilation if use_transpose: if len(kernel_size) == 3: @@ -785,6 +791,9 @@ def _impl(inputs, input_types): data_layout = "NCHW" kernel_layout = "OIHW" + if len(kernel_size) == 1: + data = _op.expand_dims(data, axis=2) + weight = _op.expand_dims(weight, axis=2) conv_out = conv_op(data, weight, @@ -793,15 +802,21 @@ def _impl(inputs, input_types): dilation=dilation, groups=groups, channels=channels, - kernel_size=kernel_size, + kernel_size=[1] + kernel_size \ + if len(kernel_size) == 1 \ + else kernel_size, data_layout=data_layout, kernel_layout=kernel_layout, out_layout="", out_dtype="") if use_bias: - return _op.nn.bias_add(conv_out, bias) + res = _op.nn.bias_add(conv_out, bias) else: - return conv_out + res = conv_out + if len(kernel_size) == 1: + res = _op.squeeze(res, axis=[2]) + return res + return _impl def _softmax(): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6a572db0bc29b..ab9cca1d4b65e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -702,7 +702,8 @@ def test_forward_hardtanh(): def test_forward_conv(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] + conv1d_input_shape = [1, 3, 10] + conv2d_input_shape = [1, 3, 10, 10] class Conv2D1(Module): def __init__(self): @@ -731,23 +732,59 @@ def __init__(self): def forward(self, *args): return self.softmax(self.conv(args[0])) - input_data = torch.rand(input_shape).float() - verify_model(Conv2D1().float().eval(), input_data=input_data) - verify_model(Conv2D2().float().eval(), input_data=input_data) + class Conv1D1(Module): + def __init__(self): + super(Conv1D1, self).__init__() + self.conv = torch.nn.Conv1d(3, 6, 7) + self.softmax = torch.nn.Softmax() + + def forward(self, *args): + return self.softmax(self.conv(args[0])) + + class Conv1D2(Module): + def __init__(self): + super(Conv1D2, self).__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + self.softmax = torch.nn.Softmax() + + def forward(self, *args): + return self.softmax(self.conv(args[0])) + + class Conv1D3(Module): + def __init__(self): + super(Conv1D3, self).__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, groups=3, bias=False) + self.softmax = torch.nn.Softmax() + + def forward(self, *args): + return self.softmax(self.conv(args[0])) + + conv2d_input_data = torch.rand(conv2d_input_shape).float() + verify_model(Conv2D1().float().eval(), input_data=conv2d_input_data) + verify_model(Conv2D2().float().eval(), input_data=conv2d_input_data) # depth wise conv with channel mult 2 - verify_model(Conv2D3().float().eval(), input_data=input_data) + verify_model(Conv2D3().float().eval(), input_data=conv2d_input_data) # group conv verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), groups=2).eval(), input_data=torch.randn((1, 8, 16, 16))) + conv1d_input_data = torch.rand(conv1d_input_shape).float() + verify_model(Conv1D1().float().eval(), input_data=conv1d_input_data) + verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data) + verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data) def test_forward_conv_transpose(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - input_data = torch.rand(input_shape).float() - verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data) - verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data) + conv2d_input_shape = [1, 3, 10, 10] + conv2d_input_data = torch.rand(conv2d_input_shape).float() + verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data) + verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data) + + conv1d_input_shape = [1, 3, 10] + conv1d_input_data = torch.rand(conv1d_input_shape).float() + verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data) + verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data) def test_forward_threshold():