Skip to content

Commit

Permalink
[MXNET]MaxPool3d and AvgPool3d Ops support added (#5614)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored May 20, 2020
1 parent 9a9fe97 commit 78e5aa1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
31 changes: 30 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,34 @@ def _pool2d(new_op, is_avg):
new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
return new_op(inputs[0], **new_attrs)

def _pool3d(new_op, is_avg):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 3:
raise tvm.error.OpAttributeInvalid(
'Only 3D kernels are supported for operator Pool3D.')
new_attrs = {}
new_attrs["pool_size"] = kernel_size
new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full")
if is_avg:
new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
return new_op(inputs[0], **new_attrs)

#3D pooling
if len(_infer_shape(inputs[0])) == 5:
if pool_type == "max":
if global_pool:
return _op.nn.global_max_pool3d(inputs[0])
return _pool3d(_op.nn.max_pool3d, False)
if pool_type == "avg":
if global_pool:
return _op.nn.global_avg_pool3d(inputs[0])
return _pool3d(_op.nn.avg_pool3d, True)
raise tvm.error.OpNotImplemented(
'Operator {} Pooling is not supported for frontend MXNet.' \
.format(pool_type.capitalize()))
#2D Pooling
if pool_type == "max":
if global_pool:
return _op.nn.global_max_pool2d(inputs[0])
Expand All @@ -327,7 +355,8 @@ def _pool2d(new_op, is_avg):
return _op.nn.global_avg_pool2d(inputs[0])
return _pool2d(_op.nn.avg_pool2d, True)
raise tvm.error.OpNotImplemented(
'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
'Operator {} Pooling is not supported for frontend MXNet.' \
.format(pool_type.capitalize()))


def _mx_adaptive_avg_pooling(inputs, attrs):
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ def test_forward_pooling():
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))

def test_forward_pooling3d():
data = mx.sym.var('data')
mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))

mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max')
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))

def test_forward_adaptive_pooling():
data = mx.sym.var('data')
mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,))
Expand Down Expand Up @@ -1123,6 +1131,7 @@ def verify(shape, blocksize=2):
test_forward_pad()
test_forward_slice()
test_forward_pooling()
test_forward_pooling3d()
test_forward_adaptive_pooling()
test_forward_lrn()
test_forward_ones()
Expand Down

0 comments on commit 78e5aa1

Please sign in to comment.