Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MXNET] conv3d and conv3d_transpose added #5814

Merged
merged 1 commit into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 82 additions & 5 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ def _mx_fully_connected(inputs, attrs):


def _get_channel_axis(layout, op_name):
if layout == "NCHW":
if layout in ["NCHW", "NCDHW"]:
return 1
if layout == "NHWC":
return 3
if layout == "NDHWC":
return 4
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name))

Expand Down Expand Up @@ -149,13 +151,15 @@ def _mx_zeros(inputs, attrs):

def _mx_conv(inputs, attrs):
    """Dispatch an MXNet Convolution to the 1D/2D/3D Relay converter.

    The dimensionality is inferred from the length of the "kernel"
    attribute; anything other than 1, 2 or 3 spatial dims is rejected.

    Parameters
    ----------
    inputs : list
        Relay expressions: data, weight, and optionally bias.
    attrs : StrAttrsDict
        The MXNet operator attributes.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If the kernel is not 1-, 2- or 3-dimensional.
    """
    kernel_size = attrs.get_int_tuple("kernel")
    if len(kernel_size) == 3:
        return _mx_conv3d(inputs, attrs)
    elif len(kernel_size) == 2:
        return _mx_conv2d(inputs, attrs)
    elif len(kernel_size) == 1:
        return _mx_conv1d(inputs, attrs)
    else:
        raise tvm.error.OpAttributeInvalid(
            '1D, 2D or 3D kernels only are supported for operator Convolution')

def _mx_conv1d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
Expand Down Expand Up @@ -226,15 +230,53 @@ def _mx_conv2d(inputs, attrs):
return res


def _get_mx_conv3d_attrs(attrs):
kernel_size = attrs.get_int_tuple("kernel")
data_layout = attrs.get_str("layout", "NCDHW")
if "kernel_layout" in attrs.attrs:
kernel_layout = attrs.get_str("kernel_layout")
else:
kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW"
new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["kernel_size"] = kernel_size
new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1))
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
new_attrs["kernel_layout"] = kernel_layout
return new_attrs


def _mx_conv3d(inputs, attrs):
    """Convert an MXNet 3D Convolution to Relay conv3d (+ optional bias_add).

    Parameters
    ----------
    inputs : list
        Relay expressions: data, weight, and bias when ``no_bias`` is False.
    attrs : StrAttrsDict
        The MXNet operator attributes.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If the kernel attribute is not 3-dimensional.
    """
    if len(attrs.get_int_tuple("kernel")) != 3:
        raise tvm.error.OpAttributeInvalid(
            'Only 3D kernels are supported for operator Convolution')
    data_layout = attrs.get_str("layout", "NCDHW")
    # Validate the layout up front (raises on unsupported layouts).
    channel_axis = _get_channel_axis(data_layout, "conv3d")

    conv = _op.nn.conv3d(inputs[0], inputs[1], **_get_mx_conv3d_attrs(attrs))
    # MXNet folds the bias into Convolution; Relay models it as a
    # separate bias_add along the channel axis.
    if not attrs.get_bool("no_bias", False):
        assert len(inputs) == 3
        conv = _op.nn.bias_add(conv, inputs[2], axis=channel_axis)
    return conv


def _mx_conv_transpose(inputs, attrs):
    """Dispatch an MXNet Deconvolution to the 1D/2D/3D Relay converter.

    The dimensionality is inferred from the length of the "kernel"
    attribute; anything other than 1, 2 or 3 spatial dims is rejected.

    Parameters
    ----------
    inputs : list
        Relay expressions: data, weight, and optionally bias.
    attrs : StrAttrsDict
        The MXNet operator attributes.

    Raises
    ------
    tvm.error.OpAttributeInvalid
        If the kernel is not 1-, 2- or 3-dimensional.
    """
    kernel_size = attrs.get_int_tuple("kernel")
    if len(kernel_size) == 3:
        return _mx_conv3d_transpose(inputs, attrs)
    elif len(kernel_size) == 2:
        return _mx_conv2d_transpose(inputs, attrs)
    elif len(kernel_size) == 1:
        return _mx_conv1d_transpose(inputs, attrs)
    else:
        raise tvm.error.OpAttributeInvalid(
            '1D, 2D or 3D kernels only are supported for operator Convolution')


def _mx_conv1d_transpose(inputs, attrs):
Expand Down Expand Up @@ -300,6 +342,41 @@ def _mx_conv2d_transpose(inputs, attrs):
return res


def _mx_conv3d_transpose(inputs, attrs):
    """Convert an MXNet 3D Deconvolution to Relay conv3d_transpose.

    Parameters
    ----------
    inputs : list
        Relay expressions: data, weight, and bias when ``no_bias`` is False.
    attrs : StrAttrsDict
        The MXNet operator attributes.

    Raises
    ------
    tvm.error.OpAttributeUnImplemented
        If the unsupported ``target_shape`` attribute is present.
    tvm.error.OpAttributeInvalid
        If the kernel attribute is not 3-dimensional.
    """
    if "target_shape" in attrs.attrs:
        raise tvm.error.OpAttributeUnImplemented(
            'Attribute "target_shape" is not supported for operator Conv3D-transpose.')
    kernel_size = attrs.get_int_tuple("kernel")
    if len(kernel_size) != 3:
        raise tvm.error.OpAttributeInvalid(
            'Non-3D kernels are not supported for operator Conv3D-transpose.')
    data_layout = attrs.get_str("layout", "NCDHW")
    # Validate the layout up front (raises on unsupported layouts).
    channel_axis = _get_channel_axis(data_layout, "conv3d_transpose")

    # An explicitly-given kernel_layout wins; otherwise derive it from
    # the data layout (channel-last data pairs with DHWIO weights).
    if "kernel_layout" in attrs.attrs:
        kernel_layout = attrs.get_str("kernel_layout")
    else:
        kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW"

    conv_attrs = {
        "channels": attrs.get_int("num_filter"),
        "kernel_size": kernel_size,
        "strides": attrs.get_int_tuple("stride", (1, 1, 1)),
        "output_padding": attrs.get_int_tuple("adj", (0, 0, 0)),
        "padding": attrs.get_int_tuple("pad", (0, 0, 0)),
        "dilation": attrs.get_int_tuple("dilate", (1, 1, 1)),
        "groups": attrs.get_int("num_group", 1),
        "data_layout": data_layout,
        "kernel_layout": kernel_layout,
    }
    out = _op.nn.conv3d_transpose(inputs[0], inputs[1], **conv_attrs)

    # NOTE: unlike Convolution, MXNet's Deconvolution defaults to
    # no_bias=True, so a bias is only added when explicitly enabled.
    if not attrs.get_bool("no_bias", True):
        assert len(inputs) == 3
        out = _op.nn.bias_add(out, inputs[2], axis=channel_axis)
    return out


def _mx_pooling(inputs, attrs):
global_pool = attrs.get_bool("global_pool", False)
pool_type = attrs.get_str("pool_type")
Expand Down
4 changes: 4 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,10 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8,
is_depthwise=True)
verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)
verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2)
verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2)

def test_forward_deconvolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
Expand Down