
[Relay] pytorch frontend support conv1d (apache#6203)
* [Relay] pytorch frontend support conv1d

* add tests for conv1d

Co-authored-by: xutianming.xtm <[email protected]>
2 people authored and Trevor Morris committed Aug 26, 2020
1 parent c9ca04a commit 4897500
Showing 2 changed files with 65 additions and 13 deletions.
23 changes: 19 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
@@ -752,20 +752,26 @@ def _impl(inputs, input_types):
# If groups > 1 but weight_shape[1] != 1, this is group convolution
if groups > 1 and weight_shape[1] == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
weight = _op.transform.reshape(weight, new_weight_shape)

kernel_size = weight_shape[2:]
use_bias = isinstance(bias, _expr.Expr)

if isinstance(strides, _expr.Expr):
strides = _infer_shape(strides)
if len(kernel_size) == 1:
strides = (1, ) + strides

if isinstance(padding, _expr.Expr):
padding = _infer_shape(padding)
if len(kernel_size) == 1:
padding = (0, ) + padding

if isinstance(dilation, _expr.Expr):
dilation = _infer_shape(dilation)
if len(kernel_size) == 1:
dilation = (1, ) + dilation

if use_transpose:
if len(kernel_size) == 3:
@@ -785,6 +791,9 @@ def _impl(inputs, input_types):
data_layout = "NCHW"
kernel_layout = "OIHW"

if len(kernel_size) == 1:
data = _op.expand_dims(data, axis=2)
weight = _op.expand_dims(weight, axis=2)

conv_out = conv_op(data,
weight,
@@ -793,15 +802,21 @@ def _impl(inputs, input_types):
dilation=dilation,
groups=groups,
channels=channels,
kernel_size=kernel_size,
kernel_size=[1] + kernel_size \
if len(kernel_size) == 1 \
else kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_layout="",
out_dtype="")
if use_bias:
return _op.nn.bias_add(conv_out, bias)
res = _op.nn.bias_add(conv_out, bias)
else:
return conv_out
res = conv_out
if len(kernel_size) == 1:
res = _op.squeeze(res, axis=[2])
return res

return _impl

def _softmax():
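The change above maps torch.nn.Conv1d onto Relay's 2-D convolution: the frontend expands both data and weight with expand_dims on axis 2, pads strides/padding/dilation with a leading 1 (or 0 for padding), calls the conv op with kernel_size [1, k], and squeezes axis 2 out of the result. The same trick can be checked numerically in plain PyTorch; the sketch below is illustrative and not part of the patch.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 10)   # NCW input, matching the conv1d tests below
w = torch.rand(6, 3, 7)    # OIW weight of a Conv1d(3, 6, kernel_size=7)

# Reference 1-D convolution.
ref = F.conv1d(x, w)

# The frontend's lowering: add a dummy height axis, run a 2-D convolution
# with a (1, k) kernel, then squeeze the height axis back out.
out = F.conv2d(x.unsqueeze(2), w.unsqueeze(2),
               stride=(1, 1), padding=(0, 0), dilation=(1, 1)).squeeze(2)

assert torch.allclose(ref, out)
```

The grouped case is covered by the generalized depthwise reshape `(groups, channel_multiplier) + tuple(weight_shape[2:])`, which now handles 1-D, 2-D, and 3-D kernel shapes alike.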
55 changes: 46 additions & 9 deletions tests/python/frontend/pytorch/test_forward.py
@@ -702,7 +702,8 @@ def test_forward_hardtanh():

def test_forward_conv():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
conv1d_input_shape = [1, 3, 10]
conv2d_input_shape = [1, 3, 10, 10]

class Conv2D1(Module):
def __init__(self):
@@ -731,23 +732,59 @@ def __init__(self):
def forward(self, *args):
return self.softmax(self.conv(args[0]))

input_data = torch.rand(input_shape).float()
verify_model(Conv2D1().float().eval(), input_data=input_data)
verify_model(Conv2D2().float().eval(), input_data=input_data)
class Conv1D1(Module):
def __init__(self):
super(Conv1D1, self).__init__()
self.conv = torch.nn.Conv1d(3, 6, 7)
self.softmax = torch.nn.Softmax()

def forward(self, *args):
return self.softmax(self.conv(args[0]))

class Conv1D2(Module):
def __init__(self):
super(Conv1D2, self).__init__()
self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
self.softmax = torch.nn.Softmax()

def forward(self, *args):
return self.softmax(self.conv(args[0]))

class Conv1D3(Module):
def __init__(self):
super(Conv1D3, self).__init__()
self.conv = torch.nn.Conv1d(3, 6, 7, groups=3, bias=False)
self.softmax = torch.nn.Softmax()

def forward(self, *args):
return self.softmax(self.conv(args[0]))

conv2d_input_data = torch.rand(conv2d_input_shape).float()
verify_model(Conv2D1().float().eval(), input_data=conv2d_input_data)
verify_model(Conv2D2().float().eval(), input_data=conv2d_input_data)
# depth wise conv with channel mult 2
verify_model(Conv2D3().float().eval(), input_data=input_data)
verify_model(Conv2D3().float().eval(), input_data=conv2d_input_data)
# group conv
verify_model(torch.nn.Conv2d(8, 8, kernel_size=(3, 3),
stride=(1, 1), groups=2).eval(),
input_data=torch.randn((1, 8, 16, 16)))

conv1d_input_data = torch.rand(conv1d_input_shape).float()
verify_model(Conv1D1().float().eval(), input_data=conv1d_input_data)
verify_model(Conv1D2().float().eval(), input_data=conv1d_input_data)
verify_model(Conv1D3().float().eval(), input_data=conv1d_input_data)

def test_forward_conv_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=input_data)
verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=input_data)
conv2d_input_shape = [1, 3, 10, 10]
conv2d_input_data = torch.rand(conv2d_input_shape).float()
verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)

conv1d_input_shape = [1, 3, 10]
conv1d_input_data = torch.rand(conv1d_input_shape).float()
verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)


def test_forward_threshold():
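The new test cases drive these modules through verify_model, which traces the module, imports it with the PyTorch frontend, and compares TVM's output against PyTorch. A minimal stand-alone sketch of that flow (input name and target are illustrative; assumes a TVM build that includes this commit):

```python
import torch
import tvm
from tvm import relay

# A Conv1d model equivalent to the Conv1D1 test case above.
model = torch.nn.Conv1d(3, 6, 7).float().eval()
inp = torch.rand(1, 3, 10)
scripted = torch.jit.trace(model, inp).eval()

# Import into Relay. With this patch, the printed module should contain
# nn.conv2d with a (1, 7) kernel followed by squeeze(axis=[2]).
mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3, 10))])
print(mod)

# Compile for CPU; the resulting module can then be run through the graph
# runtime and compared against model(inp).
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
```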
