diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 847623bf0b06..1d8842d69d12 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -87,10 +87,12 @@ def _mx_fully_connected(inputs, attrs): def _get_channel_axis(layout, op_name): - if layout == "NCHW": + if layout in ["NCHW", "NCDHW"]: return 1 if layout == "NHWC": return 3 + if layout == "NDHWC": + return 4 raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name)) @@ -149,13 +151,15 @@ def _mx_zeros(inputs, attrs): def _mx_conv(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") - if len(kernel_size) == 2: + if len(kernel_size) == 3: + return _mx_conv3d(inputs, attrs) + elif len(kernel_size) == 2: return _mx_conv2d(inputs, attrs) elif len(kernel_size) == 1: return _mx_conv1d(inputs, attrs) else: raise tvm.error.OpAttributeInvalid( - '1D or 2D kernels only are supported for operator Convolution') + '1D, 2D or 3D kernels only are supported for operator Convolution') def _mx_conv1d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") @@ -226,15 +230,53 @@ def _mx_conv2d(inputs, attrs): return res +def _get_mx_conv3d_attrs(attrs): + kernel_size = attrs.get_int_tuple("kernel") + data_layout = attrs.get_str("layout", "NCDHW") + if "kernel_layout" in attrs.attrs: + kernel_layout = attrs.get_str("kernel_layout") + else: + kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW" + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) + new_attrs["data_layout"] = data_layout + new_attrs["kernel_layout"] = kernel_layout + return new_attrs + + +def _mx_conv3d(inputs, attrs): + kernel_size = attrs.get_int_tuple("kernel") + data_layout = attrs.get_str("layout", "NCDHW") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Only 3D kernels are supported for operator Convolution') + + new_attrs = _get_mx_conv3d_attrs(attrs) + channel_axis = _get_channel_axis(data_layout, "conv3d") + use_bias = not attrs.get_bool("no_bias", False) + res = _op.nn.conv3d(inputs[0], inputs[1], **new_attrs) + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + def _mx_conv_transpose(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") - if len(kernel_size) == 2: + if len(kernel_size) == 3: + return _mx_conv3d_transpose(inputs, attrs) + elif len(kernel_size) == 2: return _mx_conv2d_transpose(inputs, attrs) elif len(kernel_size) == 1: return _mx_conv1d_transpose(inputs, attrs) else: raise tvm.error.OpAttributeInvalid( - '1D or 2D kernels only are supported for operator Convolution') + '1D, 2D or 3D kernels only are supported for operator Convolution') def _mx_conv1d_transpose(inputs, attrs): @@ -300,6 +342,41 @@ def _mx_conv2d_transpose(inputs, attrs): return res +def _mx_conv3d_transpose(inputs, attrs): + if "target_shape" in attrs.attrs: + raise tvm.error.OpAttributeUnImplemented( + 'Attribute "target_shape" is not supported for operator Conv3D-transpose.') + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Non-3D kernels are not supported for operator Conv3D-transpose.') + data_layout = attrs.get_str("layout", "NCDHW") + channel_axis = _get_channel_axis(data_layout, "conv3d_transpose") + + if "kernel_layout" in attrs.attrs: + kernel_layout = attrs.get_str("kernel_layout") + else: + kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW" + + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0, 0, 0)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) + new_attrs["data_layout"] = data_layout + new_attrs["kernel_layout"] = kernel_layout + use_bias = not attrs.get_bool("no_bias", True) + res = _op.nn.conv3d_transpose(inputs[0], inputs[1], **new_attrs) + + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + def _mx_pooling(inputs, attrs): global_pool = attrs.get_bool("global_pool", False) pool_type = attrs.get_str("pool_type") diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 463b50f5b265..00c077f0d2e0 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1033,6 +1033,10 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8, is_depthwise=True) + verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) def test_forward_deconvolution(): def verify(data_shape, kernel_size, stride, pad, num_filter):