From 3b2ce81580c28ea6e6c3508b86f87a72d9a4edf1 Mon Sep 17 00:00:00 2001 From: Samuel Date: Fri, 3 Apr 2020 03:50:41 +0530 Subject: [PATCH] [PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support (#5220) * [PYTORCH]AvgPool3d, MaxPool3d and Squeeze op support * Testcases added * review comments --- python/tvm/relay/frontend/pytorch.py | 55 +++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 45 +++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9a08af904701..977a899ec93d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -57,6 +57,17 @@ def _impl(inputs, input_types): return get_relay_op(name)(data0, data1) return _impl +def _squeeze(): + def _impl(inputs, input_types): + data = inputs[0] + if len(inputs) == 1: + axis = None + else: + axis = [int(inputs[1])] + + return _op.transform.squeeze(data, axis) + return _impl + def _unsqueeze(): def _impl(inputs, input_types): data = inputs[0] @@ -297,6 +308,26 @@ def _impl(inputs, input_types): return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) return _impl +def _maxpool_3d(): + def _impl(inputs, input_types): + data = inputs[0] + + pool_size = _infer_shape(inputs[1]) + strides = _infer_shape(inputs[2]) + padding = _infer_shape(inputs[3]) + dilation = _infer_shape(inputs[4]) + ceil_mode = int(inputs[5]) + if dilation != (1, 1, 1): + msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation), ) + raise NotImplementedError(msg) + + return _op.nn.max_pool3d(data, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode) + return _impl + def _hardtanh(): def _impl(inputs, input_types): a = inputs[0] @@ -631,6 +662,27 @@ def func(x): return _impl +def _avg_pool3d(): + def _impl(inputs, input_types): + data = inputs[0] + + pool_size = _infer_shape(inputs[1]) + if inputs[2]: + strides = _infer_shape(inputs[2]) + else: + strides = pool_size + padding = _infer_shape(inputs[3]) + + ceil_mode = int(inputs[4]) + count_include_pad = int(inputs[5]) + + return _op.nn.avg_pool3d(data, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad) + return _impl def _dropout(): def _impl(inputs, input_types): @@ -970,6 +1022,7 @@ def _wrap_const(c): "aten::ones" : _ones(), "aten::zeros" : _zeros(), "aten::to" : _to(), + "aten::squeeze" : _squeeze(), "aten::unsqueeze" : _unsqueeze(), "aten::cat" : _concatenate(), "aten::slice" : _slice(), @@ -987,6 +1040,7 @@ def _wrap_const(c): "aten::max_pool2d" : _maxpool_2d(), "aten::max_pool2d_with_indices" : _maxpool_2d(), "aten::max_pool1d" : _maxpool_1d(), + "aten::max_pool3d" : _maxpool_3d(), "aten::hardtanh" : _hardtanh(), "aten::hardtanh_" : _hardtanh(), "aten::_convolution" : _convolution(), @@ -1007,6 +1061,7 @@ def _wrap_const(c): "aten::log_softmax" : _log_softmax(), "aten::sigmoid" : _sigmoid(), "aten::avg_pool2d" : _avg_pool2d(), + "aten::avg_pool3d" : _avg_pool3d(), "aten::dropout" : _dropout(), "aten::dropout_" : _dropout(), "aten::feature_dropout" : _dropout(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c75ae6ed96aa..e7c2e0841a87 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -304,6 +304,22 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Unsqueeze1().float().eval(), input_data=input_data) +def test_forward_squeeze(): + torch.set_grad_enabled(False) + input_shape = [2, 1, 10, 1, 10] + + class Squeeze1(Module): + def forward(self, *args): + return args[0].squeeze() + + class Squeeze2(Module): + def forward(self, *args): + return args[0].squeeze(1) + + input_data = torch.rand(input_shape).float() + verify_model(Squeeze1().float().eval(), input_data=input_data) + verify_model(Squeeze2().float().eval(), input_data=input_data) + def test_forward_concatenate(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -388,6 +404,20 @@ def test_forward_maxpool1d(): stride=2).eval(), input_data) +def test_forward_maxpool3d(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10, 10] + input_data = torch.rand(input_shape).float() + + verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), + input_data) + verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), + input_data) + verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], + padding=2, + stride=2).eval(), + input_data) + def test_forward_split(): torch.set_grad_enabled(False) input_shape = [4, 10] @@ -423,6 +453,18 @@ def forward(self, *args): verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data) verify_model(AvgPool2D2().float().eval(), input_data=input_data) +def test_forward_avgpool3d(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10, 10] + + class AvgPool3D1(Module): + def forward(self, *args): + return torch.nn.functional.avg_pool3d(args[0], kernel_size=[10, 10, 10]) + + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data) + verify_model(AvgPool3D1().float().eval(), input_data=input_data) + def test_forward_hardtanh(): torch.set_grad_enabled(False) input_shape = [10] @@ -1071,6 +1113,7 @@ def forward(self, xs): test_forward_add() test_forward_subtract() test_forward_multiply() + test_forward_squeeze() test_forward_unsqueeze() test_forward_concatenate() test_forward_relu() @@ -1081,6 +1124,7 @@ def forward(self, xs): test_forward_adaptiveavgpool() test_forward_maxpool2d() test_forward_maxpool1d() + test_forward_maxpool3d() test_forward_hardtanh() test_forward_conv() test_forward_conv_transpose() @@ -1097,6 +1141,7 @@ def forward(self, xs): test_forward_sigmoid() test_forward_dense() test_forward_avgpool() + test_forward_avgpool3d() test_forward_dropout() test_forward_slice() test_forward_mean()