Skip to content

Commit

Permalink
[PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support (apache#5220)
Browse files Browse the repository at this point in the history
* [PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support

* Testcases added

* review comments
  • Loading branch information
siju-samuel authored and Trevor Morris committed Apr 16, 2020
1 parent dcc4768 commit 3b2ce81
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
55 changes: 55 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def _impl(inputs, input_types):
return get_relay_op(name)(data0, data1)
return _impl

def _squeeze():
def _impl(inputs, input_types):
data = inputs[0]
if len(inputs) == 1:
axis = None
else:
axis = [int(inputs[1])]

return _op.transform.squeeze(data, axis)
return _impl

def _unsqueeze():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -297,6 +308,26 @@ def _impl(inputs, input_types):
return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode)
return _impl

def _maxpool_3d():
def _impl(inputs, input_types):
data = inputs[0]

pool_size = _infer_shape(inputs[1])
strides = _infer_shape(inputs[2])
padding = _infer_shape(inputs[3])
dilation = _infer_shape(inputs[4])
ceil_mode = int(inputs[5])
if dilation != (1, 1, 1):
msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), )
raise NotImplementedError(msg)

return _op.nn.max_pool3d(data,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode)
return _impl

def _hardtanh():
def _impl(inputs, input_types):
a = inputs[0]
Expand Down Expand Up @@ -631,6 +662,27 @@ def func(x):

return _impl

def _avg_pool3d():
def _impl(inputs, input_types):
data = inputs[0]

pool_size = _infer_shape(inputs[1])
if inputs[2]:
strides = _infer_shape(inputs[2])
else:
strides = pool_size
padding = _infer_shape(inputs[3])

ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])

return _op.nn.avg_pool3d(data,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
return _impl

def _dropout():
def _impl(inputs, input_types):
Expand Down Expand Up @@ -970,6 +1022,7 @@ def _wrap_const(c):
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::to" : _to(),
"aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
Expand All @@ -987,6 +1040,7 @@ def _wrap_const(c):
"aten::max_pool2d" : _maxpool_2d(),
"aten::max_pool2d_with_indices" : _maxpool_2d(),
"aten::max_pool1d" : _maxpool_1d(),
"aten::max_pool3d" : _maxpool_3d(),
"aten::hardtanh" : _hardtanh(),
"aten::hardtanh_" : _hardtanh(),
"aten::_convolution" : _convolution(),
Expand All @@ -1007,6 +1061,7 @@ def _wrap_const(c):
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
"aten::avg_pool2d" : _avg_pool2d(),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
"aten::dropout_" : _dropout(),
"aten::feature_dropout" : _dropout(),
Expand Down
45 changes: 45 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,22 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(Unsqueeze1().float().eval(), input_data=input_data)

def test_forward_squeeze():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]

class Squeeze1(Module):
def forward(self, *args):
return args[0].squeeze()

class Squeeze2(Module):
def forward(self, *args):
return args[0].squeeze(1)

input_data = torch.rand(input_shape).float()
verify_model(Squeeze1().float().eval(), input_data=input_data)
verify_model(Squeeze2().float().eval(), input_data=input_data)

def test_forward_concatenate():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -388,6 +404,20 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)

def test_forward_maxpool3d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10, 10]
input_data = torch.rand(input_shape).float()

verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(),
input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(),
input_data)
verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4],
padding=2,
stride=2).eval(),
input_data)

def test_forward_split():
torch.set_grad_enabled(False)
input_shape = [4, 10]
Expand Down Expand Up @@ -423,6 +453,18 @@ def forward(self, *args):
verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data)
verify_model(AvgPool2D2().float().eval(), input_data=input_data)

def test_forward_avgpool3d():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10, 10]

class AvgPool3D1(Module):
def forward(self, *args):
return torch.nn.functional.avg_pool3d(args[0], kernel_size=[10, 10, 10])

input_data = torch.rand(input_shape).float()
verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data)
verify_model(AvgPool3D1().float().eval(), input_data=input_data)

def test_forward_hardtanh():
torch.set_grad_enabled(False)
input_shape = [10]
Expand Down Expand Up @@ -1071,6 +1113,7 @@ def forward(self, xs):
test_forward_add()
test_forward_subtract()
test_forward_multiply()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
test_forward_relu()
Expand All @@ -1081,6 +1124,7 @@ def forward(self, xs):
test_forward_adaptiveavgpool()
test_forward_maxpool2d()
test_forward_maxpool1d()
test_forward_maxpool3d()
test_forward_hardtanh()
test_forward_conv()
test_forward_conv_transpose()
Expand All @@ -1097,6 +1141,7 @@ def forward(self, xs):
test_forward_sigmoid()
test_forward_dense()
test_forward_avgpool()
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()
test_forward_mean()
Expand Down

0 comments on commit 3b2ce81

Please sign in to comment.