From fe6d2ab37256d0092ae6dc1ade104ebae9373028 Mon Sep 17 00:00:00 2001 From: Samuel Date: Wed, 20 May 2020 06:39:44 +0530 Subject: [PATCH] [MXNET]MaxPool3d and AvgPool3d Ops support added (#5614) --- python/tvm/relay/frontend/mxnet.py | 31 ++++++++++++++++++++- tests/python/frontend/mxnet/test_forward.py | 9 ++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index e6384f711e40..edf668041fd5 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -318,6 +318,34 @@ def _pool2d(new_op, is_avg): new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) return new_op(inputs[0], **new_attrs) + def _pool3d(new_op, is_avg): + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Only 3D kernels are supported for operator Pool3D.') + new_attrs = {} + new_attrs["pool_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full") + if is_avg: + new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) + return new_op(inputs[0], **new_attrs) + + #3D pooling + if len(_infer_shape(inputs[0])) == 5: + if pool_type == "max": + if global_pool: + return _op.nn.global_max_pool3d(inputs[0]) + return _pool3d(_op.nn.max_pool3d, False) + if pool_type == "avg": + if global_pool: + return _op.nn.global_avg_pool3d(inputs[0]) + return _pool3d(_op.nn.avg_pool3d, True) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) + #2D Pooling if pool_type == "max": if global_pool: return _op.nn.global_max_pool2d(inputs[0]) @@ -327,7 +355,8 @@ def _pool2d(new_op, is_avg): return _op.nn.global_avg_pool2d(inputs[0]) return _pool2d(_op.nn.avg_pool2d, True) raise tvm.error.OpNotImplemented( - 'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize())) + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) def _mx_adaptive_avg_pooling(inputs, attrs): diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 81132711f557..6e8acdeab101 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -179,6 +179,14 @@ def test_forward_pooling(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) +def test_forward_pooling3d(): + data = mx.sym.var('data') + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + def test_forward_adaptive_pooling(): data = mx.sym.var('data') mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,)) @@ -1123,6 +1131,7 @@ def verify(shape, blocksize=2): test_forward_pad() test_forward_slice() test_forward_pooling() + test_forward_pooling3d() test_forward_adaptive_pooling() test_forward_lrn() test_forward_ones()