Skip to content

Commit

Permalink
[Relay][Legalize] Legalize conv2d_transpose for NHWC
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Nov 22, 2019
1 parent 70017ef commit cf65f02
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 5 deletions.
20 changes: 20 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,26 @@ def schedule_conv2d_transpose(attrs, outs, target):
return topi.generic.schedule_conv2d_transpose_nchw(outs)


@reg.register_legalize("nn.conv2d_transpose")
def legalize_conv2d_transpose(attrs, inputs, types):
"""Legalize conv2d_transpose op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)

reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# bias_add
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,8 @@ class BinaryConv2DAttrs(Attrs):
@register_relay_attr_node
class BinaryDenseAttrs(Attrs):
"""Attributes used in bitserial dense operators"""


@register_relay_attr_node
class Conv2DTransposeAttrs(Attrs):
"""Attributes used in Transposed Conv2D operators"""
36 changes: 32 additions & 4 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ def test_conv2d_transpose_infer_type():
(10, 15, 3, 3), "float32")

# infer by shape of w, mixed precision
n, c, h, w = tvm.var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
n, h, w, c = tvm.var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
y = relay.nn.conv2d_transpose(x, w,
output_padding=(1, 1),
Expand All @@ -323,7 +323,7 @@ def test_conv2d_transpose_infer_type():
(n, 15, 15, 11), "float32")


def test_conv2d_transpose_run():
def test_conv2d_transpose_nchw_run():
dshape = (1, 3, 18, 18)
kshape = (3, 10, 3, 3)
oshape = (1, 10, 37, 37)
Expand All @@ -348,6 +348,33 @@ def test_conv2d_transpose_run():
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)


def test_conv2d_transpose_nhwc_run():
dshape_nhwc = (1, 18, 18, 3)
kshape_hwoi = (3, 3, 10, 3)
oshape_nhwc = (1, 37, 37, 10)
x = relay.var("x", shape=dshape_nhwc)
w = relay.var("w")
# kshape and kernel_layout should have swapped IO.
# kshape is HWOI and kernel_layout is HWIO
y = relay.nn.conv2d_transpose(x, w,
channels=10, kernel_size=(3, 3), strides=(2, 2),
padding=(1, 1), output_padding=(2, 2),
data_layout="NHWC", kernel_layout="HWIO")
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
# use true kshape layout here - HWOI
c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
d_np = np.zeros(shape=oshape_nhwc)
d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np
ref_res = d_np

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)


def test_upsampling_infer_type():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
Expand Down Expand Up @@ -819,7 +846,8 @@ def test_bitpack_infer_type():
test_pad_infer_type()
test_pad_run()
test_conv2d_transpose_infer_type()
test_conv2d_transpose_run()
test_conv2d_transpose_nchw_run()
test_conv2d_transpose_nhwc_run()
test_conv2d_run()
test_conv2d_winograd()
test_bitserial_conv2d_infer_type()
Expand Down
67 changes: 67 additions & 0 deletions topi/python/topi/nn/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs

import logging

import tvm
from tvm import relay
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify

logger = logging.getLogger('topi')

@tvm.target.generic_func
def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
Expand Down Expand Up @@ -102,3 +107,65 @@ def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype)
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")

return Output


@tvm.target.generic_func
def conv2d_transpose_legalize(attrs, inputs, types):
"""Legalizes Transposed 2D convolution op.
Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current Transposed 2D convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if attrs['data_layout'] == 'NHWC':
data, kernel = inputs
kernel_layout = attrs['kernel_layout']
# Convert Kernel layout to IOHW
# kernel_layout is different from input kernel layout - IO is swapped
if kernel_layout == 'HWIO':
# input kernel layout is swapped to HWOI
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
elif kernel_layout == 'HWOI':
# input kernel layout is swapped to HWIO
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
elif kernel_layout == 'IOHW':
# input kernel layout is swapped to OIHW
# output kernel layout will be IOHW
kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
elif kernel_layout == 'OIHW':
# input kernel layout is swapped to IOHW
# output kernel layout will be IOHW
pass
else:
raise ValueError("Invalid kernel_layout {}".format(kernel_layout))

logger.warning("Legalize conv2d_transpose - NHWC schedule is absent. "
+ "Inserting layout transforms to "
+ "fallback to NCHW. This can result in performance degradation.")

# Set new attrs for conv2d_transpose.
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs['data_layout'] = 'NCHW'
# layout of kernel should be IOHW, but kernel_layout should be swapped - OIHW
new_attrs['kernel_layout'] = 'OIHW'

# Convert data to NCHW.
data = relay.transpose(data, axes=(0, 3, 1, 2))
deconv = relay.nn.conv2d_transpose(data, kernel, **new_attrs)
# Convert back to original NHWC layout.
out = relay.transpose(deconv, axes=(0, 2, 3, 1))
return out

return None
2 changes: 1 addition & 1 deletion topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
from .dilate_python import dilate_python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,50 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding):
padded_a_np[n, c], w_np[c, f], mode='valid')
b_np[n, f] += out
return b_np


def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
"""Transposed convolution operator in NHWC layout.
Parameters
----------
a_nhwc : numpy.ndarray
4-D with shape [batch, in_height, in_width, in_channel]
weight : numpy.ndarray
4-D in formats HWIO, HWOI, OIHW or IOHW
weight_format : str
['HWIO', 'HWOI', 'OIHW', 'IOHW']
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert a_nhwc.ndim == 4, "a_nhwc number of dimensions should be 4"
assert weight.ndim == 4, "weight number of dimensions should be 4"

a_nchw = np.transpose(a_nhwc, (0, 3, 1, 2))

# conv2d_transpose_nchw_python needs kernel layout to be IOHW
if weight_format == 'HWIO':
w_iohw = np.transpose(weight, (2, 3, 0, 1))
elif weight_format == 'HWOI':
w_iohw = np.transpose(weight, (3, 2, 0, 1))
elif weight_format == 'OIHW':
w_iohw = np.transpose(weight, (1, 0, 2, 3))
elif weight_format == 'IOHW':
w_iohw = weight
else:
raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')

res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc

0 comments on commit cf65f02

Please sign in to comment.