diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4c3144c4382a4..b557e06fb3d56 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -318,16 +318,45 @@ def _pool2d(new_op, is_avg): new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) return new_op(inputs[0], **new_attrs) - if pool_type == "max": - if global_pool: - return _op.nn.global_max_pool2d(inputs[0]) - return _pool2d(_op.nn.max_pool2d, False) - if pool_type == "avg": - if global_pool: - return _op.nn.global_avg_pool2d(inputs[0]) - return _pool2d(_op.nn.avg_pool2d, True) - raise tvm.error.OpNotImplemented( - 'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize())) + def _pool3d(new_op, is_avg): + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Only 3D kernels are supported for operator Pool3D.') + new_attrs = {} + new_attrs["pool_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full") + if is_avg: + new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) + return new_op(inputs[0], **new_attrs) + + #3D pooling + if len(_infer_shape(inputs[0])) == 5: + if pool_type == "max": + if global_pool: + return _op.nn.global_max_pool3d(inputs[0]) + return _pool3d(_op.nn.max_pool3d, False) + if pool_type == "avg": + if global_pool: + return _op.nn.global_avg_pool3d(inputs[0]) + return _pool3d(_op.nn.avg_pool3d, True) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) + else: + if pool_type == "max": + if global_pool: + return _op.nn.global_max_pool2d(inputs[0]) + return _pool2d(_op.nn.max_pool2d, False) + if pool_type == "avg": + if global_pool: + return _op.nn.global_avg_pool2d(inputs[0]) + return _pool2d(_op.nn.avg_pool2d, True) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) def _mx_adaptive_avg_pooling(inputs, attrs): diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 9dd85065c8859..6f5e56b6263af 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -179,6 +179,14 @@ def test_forward_pooling(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) +def test_forward_pooling3d(): + data = mx.sym.var('data') + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + def test_forward_adaptive_pooling(): data = mx.sym.var('data') mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,)) @@ -1121,6 +1129,7 @@ def verify(shape, blocksize=2): test_forward_pad() test_forward_slice() test_forward_pooling() + test_forward_pooling3d() test_forward_adaptive_pooling() test_forward_lrn() test_forward_ones()