diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 54f13c6881515..cb2ecc9764950 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
     return topi.generic.schedule_conv2d_transpose_nchw(outs)
 
 
+@reg.register_legalize("nn.conv2d_transpose")
+def legalize_conv2d_transpose(attrs, inputs, types):
+    """Legalize conv2d_transpose op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current Transposed convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
+
+
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # bias_add
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 2de0257aa841f..35b2c053f8cff 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
 @register_relay_attr_node
 class BinaryDenseAttrs(Attrs):
     """Attributes used in bitserial dense operators"""
+
+
+@register_relay_attr_node
+class Conv2DTransposeAttrs(Attrs):
+    """Attributes used in Transposed Conv2D operators"""
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 08c5eb0d5cfca..b54efaac0bde4 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
                                              (10, 15, 3, 3), "float32")
 
     # infer by shape of w, mixed precision
-    n, c, h, w = tvm.var("n"), 10, 10, 12
-    x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
+    n, h, w, c = tvm.var("n"), 10, 10, 12
+    x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
     w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
     y = relay.nn.conv2d_transpose(x, w,
                                   output_padding=(1, 1),
@@ -323,7 +323,7 @@
         (n, 15, 15, 11), "float32")
 
 
-def test_conv2d_transpose_run():
+def test_conv2d_transpose_nchw_run():
     dshape = (1, 3, 18, 18)
     kshape = (3, 10, 3, 3)
     oshape = (1, 10, 37, 37)
@@ -348,6 +348,33 @@
         tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
+def test_conv2d_transpose_nhwc_run():
+    dshape_nhwc = (1, 18, 18, 3)
+    kshape_hwoi = (3, 3, 10, 3)
+    oshape_nhwc = (1, 37, 37, 10)
+    x = relay.var("x", shape=dshape_nhwc)
+    w = relay.var("w")
+    # kshape and kernel_layout should have swapped IO.
+    # kshape is HWOI and kernel_layout is HWIO
+    y = relay.nn.conv2d_transpose(x, w,
+                                  channels=10, kernel_size=(3, 3), strides=(2, 2),
+                                  padding=(1, 1), output_padding=(2, 2),
+                                  data_layout="NHWC", kernel_layout="HWIO")
+    func = relay.Function([x, w], y)
+    dtype = "float32"
+    data = np.random.uniform(size=dshape_nhwc).astype(dtype)
+    kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
+    # use true kshape layout here - HWOI
+    c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
+    d_np = np.zeros(shape=oshape_nhwc)
+    d_np[:, 0:c_np.shape[1], 0:c_np.shape[2], :] = c_np
+    ref_res = d_np
+
+    for target, ctx in ctx_list():
+        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res1 = intrp1.evaluate(func)(data, kernel)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+
 def test_upsampling_infer_type():
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -819,7 +846,8 @@ def test_bitpack_infer_type():
     test_pad_infer_type()
     test_pad_run()
     test_conv2d_transpose_infer_type()
-    test_conv2d_transpose_run()
+    test_conv2d_transpose_nchw_run()
+    test_conv2d_transpose_nhwc_run()
     test_conv2d_run()
     test_conv2d_winograd()
     test_bitserial_conv2d_infer_type()
diff --git a/topi/python/topi/nn/conv2d_transpose.py b/topi/python/topi/nn/conv2d_transpose.py
index 2f3e323337079..068829d55f771 100644
--- a/topi/python/topi/nn/conv2d_transpose.py
+++ b/topi/python/topi/nn/conv2d_transpose.py
@@ -17,12 +17,17 @@
 # pylint: disable=invalid-name, unused-variable, unused-argument
 """Transposed 2D convolution operators (sometimes called Deconvolution)."""
 from __future__ import absolute_import as _abs
+
+import logging
+
 import tvm
+from tvm import relay
 from .dilate import dilate
 from .pad import pad
 from .util import get_pad_tuple
 from ..util import simplify
 
+logger = logging.getLogger('topi')
 
 @tvm.target.generic_func
 def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
@@ -102,3 +107,65 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
                           axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
     return Output
+
+
+@tvm.target.generic_func
+def conv2d_transpose_legalize(attrs, inputs, types):
+    """Legalizes Transposed 2D convolution op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current Transposed 2D convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    if attrs['data_layout'] == 'NHWC':
+        data, kernel = inputs
+        kernel_layout = attrs['kernel_layout']
+        # Convert kernel layout to IOHW.
+        # kernel_layout names the actual kernel layout with I and O swapped.
+        if kernel_layout == 'HWIO':
+            # actual input kernel layout is HWOI
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
+        elif kernel_layout == 'HWOI':
+            # actual input kernel layout is HWIO
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
+        elif kernel_layout == 'IOHW':
+            # actual input kernel layout is OIHW
+            # output kernel layout will be IOHW
+            kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
+        elif kernel_layout == 'OIHW':
+            # actual input kernel layout is IOHW
+            # output kernel layout will be IOHW
+            pass
+        else:
+            raise ValueError("Invalid kernel_layout {}".format(kernel_layout))
+
+        logger.warning("Legalize conv2d_transpose - NHWC schedule is absent. "
+                       "Inserting layout transforms to "
+                       "fall back to NCHW. This can result in performance degradation.")
+
+        # Set new attrs for conv2d_transpose.
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+        new_attrs['data_layout'] = 'NCHW'
+        # The kernel tensor is now IOHW, so kernel_layout (with I and O swapped) is OIHW.
+        new_attrs['kernel_layout'] = 'OIHW'
+
+        # Convert data to NCHW.
+        data = relay.transpose(data, axes=(0, 3, 1, 2))
+        deconv = relay.nn.conv2d_transpose(data, kernel, **new_attrs)
+        # Convert back to the original NHWC layout.
+        out = relay.transpose(deconv, axes=(0, 2, 3, 1))
+        return out
+
+    return None
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 6c4a0e30cee8e..6d7340070113f 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -24,7 +24,7 @@
 from .conv2d_hwcn_python import conv2d_hwcn_python
 from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
-from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
+from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
diff --git a/topi/python/topi/testing/conv2d_transpose_nchw_python.py b/topi/python/topi/testing/conv2d_transpose_python.py
similarity index 65%
rename from topi/python/topi/testing/conv2d_transpose_nchw_python.py
rename to topi/python/topi/testing/conv2d_transpose_python.py
index 60b9d69c81de1..50c43eb70e3e0 100644
--- a/topi/python/topi/testing/conv2d_transpose_nchw_python.py
+++ b/topi/python/topi/testing/conv2d_transpose_python.py
@@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
                                        padded_a_np[n, c], w_np[c, f], mode='valid')
             b_np[n, f] += out
     return b_np
+
+
+def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
+    """Transposed convolution operator in NHWC layout.
+
+    Parameters
+    ----------
+    a_nhwc : numpy.ndarray
+        4-D with shape [batch, in_height, in_width, in_channel]
+
+    weight : numpy.ndarray
+        4-D in formats HWIO, HWOI, OIHW or IOHW
+
+    weight_format : str
+        ['HWIO', 'HWOI', 'OIHW', 'IOHW']
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str
+        Padding size, or ['VALID', 'SAME']
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_height, out_width, out_channel]
+    """
+    assert a_nhwc.ndim == 4, "a_nhwc number of dimensions should be 4"
+    assert weight.ndim == 4, "weight number of dimensions should be 4"
+
+    a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))
+
+    # conv2d_transpose_nchw_python needs kernel layout to be IOHW
+    if weight_format == 'HWIO':
+        w_iohw = np.transpose(weight, (2, 3, 0, 1))
+    elif weight_format == 'HWOI':
+        w_iohw = np.transpose(weight, (3, 2, 0, 1))
+    elif weight_format == 'OIHW':
+        w_iohw = np.transpose(weight, (1, 0, 2, 3))
+    elif weight_format == 'IOHW':
+        w_iohw = weight
+    else:
+        raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')
+
+    res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
+    res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
+    return res_nhwc
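For reference, here is a minimal sketch of how the new legalization fires end to end. It is illustrative only, assuming a TVM build that exposes `tvm.IRModule` and `relay.transform.Legalize` (older releases spelled the module `relay.Module`), and it mirrors the shapes used in `test_conv2d_transpose_nhwc_run`:

```python
import tvm
from tvm import relay

# NHWC data; the kernel variable is left untyped so its shape is
# inferred from channels/kernel_size, as in the test above.
x = relay.var("x", shape=(1, 18, 18, 3))
w = relay.var("w")
# kernel_layout is declared "HWIO" although the tensor is HWOI, because
# conv2d_transpose reads I and O in kernel_layout as swapped (see the
# comments in conv2d_transpose_legalize).
y = relay.nn.conv2d_transpose(x, w, channels=10, kernel_size=(3, 3),
                              strides=(2, 2), padding=(1, 1),
                              data_layout="NHWC", kernel_layout="HWIO")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
# After Legalize, the op runs in NCHW between two relay.transpose calls.
print(relay.transform.Legalize()(mod))
```

The printed module should show the rewrite this patch installs: a transpose of the data to NCHW, a conv2d_transpose with kernel_layout "OIHW", and a transpose of the result back to NHWC.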