Skip to content

Commit

Permalink
Fix the meaning of conv{1,2}d_transpose output_padding parameter. (#5758
Browse files Browse the repository at this point in the history
)

* Add output_padding to generic

* Add output_padding to the reference impl

* Add output_padding to arm_cpu

* Add output_padding to the test

* Add output_padding for cuda

* Add output_padding for x86

* Make use of the new output_padding argument in Relay

* Adjust conv2d_transpose Relay test

* Fix lint errors

* Fix the VTA declaration of conv2d_transpose

* support for output padding in conv2d transpose

* some output padding will break IR pass

* Fix new conv2d_transpose test

* Update tophub

* Fix conv1d output_padding too.

* Fix the conv1d_transpose reference function.

* Fix the cuda impl

* fix the topi test for conv1d

* format

* Add tests for conv1d_transpose output_padding and some check that the values are valid.

* Add check in the implementations

* Add checks to the implementations of conv2d

* Make use of the output_padding argument from topi in relay.

* Fix relay tests asking for invalid output_padding

* Fix line length

* Fix vta tests

* Update tophub references

* Trigger CI

Co-authored-by: Thierry Moreau <[email protected]>
  • Loading branch information
abergeron and tmoreau89 authored Jun 30, 2020
1 parent 2e04393 commit bc22fb9
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 120 deletions.
7 changes: 3 additions & 4 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,15 @@

# the version of each package
PACKAGE_VERSION = {
'arm_cpu': "v0.06",
'arm_cpu': "v0.07",
'llvm': "v0.04",

'cuda': "v0.08",
'cuda': "v0.09",
'rocm': "v0.05",
'opencl': "v0.04",
'mali': "v0.06",
'intel_graphics': "v0.02",

'vta': "v0.08",
'vta': "v0.09",
'amd_apu': "v0.01",
}

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
reg.register_strategy("nn.conv2d_transpose", strategy.conv2d_transpose_strategy)
reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.conv2d_transpose")
def legalize_conv2d_transpose(attrs, inputs, types):
"""Legalize conv2d_transpose op.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def conv2d_transpose(data,
Layout of the output, by default, out_layout is the same as data_layout
output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output.
Used to disambiguate the output shape.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Expand Down Expand Up @@ -562,7 +562,7 @@ def conv1d_transpose(data,
Layout of the output, by default, out_layout is the same as data_layout
output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output.
Used to disambiguate the output shape.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Expand Down
9 changes: 3 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,9 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype):
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
out = topi_compute(
inputs[0], inputs[1], strides, padding, out_dtype)
output_padding = get_const_tuple(attrs.output_padding)
out = topi.nn.pad(out, [0, 0, 0, 0],
[0, 0, output_padding[0], output_padding[1]])
out = topi_compute(
inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
return [out]
return compute_conv2d_transpose

Expand Down Expand Up @@ -502,9 +500,8 @@ def _compute_conv1d_tranpsoe(attrs, inputs, out_type):
strides = get_const_tuple(attrs.strides)
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "") else out_dtype)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype)
output_padding = get_const_tuple(attrs.output_padding)
out = topi.nn.pad(out, [0, 0, 0], [0, 0, output_padding[0]])
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
return [out]
return _compute_conv1d_tranpsoe

Expand Down
37 changes: 18 additions & 19 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,21 +704,18 @@ def test_conv2d_transpose_infer_type():
def test_conv2d_transpose_nchw_run():
dshape = (1, 3, 18, 18)
kshape = (3, 10, 3, 3)
oshape = (1, 10, 37, 37)
oshape = (1, 10, 36, 36)
x = relay.var("x", shape=dshape)
w = relay.var("w")
y = relay.nn.conv2d_transpose(x, w,
channels=10, kernel_size=(3,3), strides=(2,2),
padding=(1,1), output_padding=(2, 2))
padding=(1,1), output_padding=(1, 1))
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape).astype(dtype)
kernel = np.random.uniform(size=kshape).astype(dtype)
c_np = topi.testing.conv2d_transpose_nchw_python(
data, kernel, 2, 1)
d_np = np.zeros(shape=oshape)
d_np[:,:,0:c_np.shape[2],0:c_np.shape[3]] = c_np
ref_res = d_np
ref_res = topi.testing.conv2d_transpose_nchw_python(
data, kernel, 2, 1, (1, 1))

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
Expand All @@ -729,43 +726,45 @@ def test_conv2d_transpose_nchw_run():
def test_conv2d_transpose_nhwc_run():
dshape_nhwc = (1, 18, 18, 3)
kshape_hwoi = (3, 3, 10, 3)
oshape_nhwc = (1, 37, 37, 10)
oshape_nhwc = (1, 36, 36, 10)
x = relay.var("x", shape=dshape_nhwc)
w = relay.var("w")
# kshape and kernel_layout should have swapped IO.
# kshape is HWOI and kernel_layout is HWIO
y = relay.nn.conv2d_transpose(x, w,
channels=10, kernel_size=(3, 3), strides=(2, 2),
padding=(1, 1), output_padding=(2, 2),
padding=(1, 1), output_padding=(1, 1),
data_layout="NHWC", kernel_layout="HWIO")
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
# use true kshape layout here - HWOI
c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
d_np = np.zeros(shape=oshape_nhwc)
d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np

ref_res = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI',
2, 1, output_padding=(1, 1))

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)


def test_conv1d_transpose_ncw_run():
dshape = (1, 3, 18)
kshape = (3, 10, 3)
oshape = (1, 10, 37)
oshape = (1, 10, 36)
x = relay.var("x", shape=dshape)
w = relay.var("w")
y = relay.nn.conv1d_transpose(x, w,
channels=10, kernel_size=(3,), strides=(2,),
padding=(1,), output_padding=(2,))
padding=(1,), output_padding=(1,))
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape).astype(dtype)
kernel = np.random.uniform(size=kshape).astype(dtype)
c_np = topi.testing.conv1d_transpose_ncw_python(
data, kernel, 2, 1)
d_np = np.zeros(shape=oshape)
d_np[:,:,0:c_np.shape[2]] = c_np
ref_res = d_np
ref_res = topi.testing.conv1d_transpose_ncw_python(
data, kernel, 2, 1, output_padding=(1,))

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
Expand Down
26 changes: 18 additions & 8 deletions topi/python/topi/arm_cpu/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
from ..util import get_const_tuple, traverse_inline
from .conv2d_spatial_pack import schedule_conv2d_spatial_pack_nchw



@autotvm.register_topi_compute("conv2d_transpose_nchw.arm_cpu")
def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype):
def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype,
output_padding):
"""Transposed 2D convolution nchw forward operator.
Parameters
Expand All @@ -47,27 +50,34 @@ def conv2d_transpose_nchw(cfg, Input, Filter, strides, padding, out_dtype):
out_dtype: str
The output data type. This is used for mixed precision.
output_padding : tuple of int
Used to get the right output shape in gradients
Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2)
return _decl_spatial_pack(cfg, Input, Filter, strides, padding, "NCHW", out_dtype, 2,
output_padding)

def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile):
def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile,
output_padding):
assert layout == "NCHW", "Only support NCHW"
out_dtype = out_dtype or data.dtype

N, CI, IH, IW = get_const_tuple(data.shape)
_, CO, KH, KW = get_const_tuple(kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
opad_h, opad_w = output_padding
assert opad_h < HSTR and opad_w < WSTR

pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW))
bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom
bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
bpad_top, bpad_bottom = KH - 1 - pad_top, KH - 1 - pad_bottom + opad_h
bpad_left, bpad_right = KW - 1 - pad_left, KW - 1 - pad_right + opad_w

OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH
OW = (IW - 1) * WSTR - pad_left - pad_right + KW
OH = (IH - 1) * HSTR - pad_top - pad_bottom + KH + opad_h
OW = (IW - 1) * WSTR - pad_left - pad_right + KW + opad_w

dilated_input = dilate(data, [1, 1, HSTR, WSTR])
data_pad = pad(dilated_input, [0, 0, bpad_top, bpad_left], [0, 0, bpad_bottom, bpad_right])
Expand Down
13 changes: 10 additions & 3 deletions topi/python/topi/cuda/conv1d_transpose_ncw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from ..util import get_const_tuple, traverse_inline

@autotvm.task.register_topi_compute("conv1d_transpose_nchw.cuda")
def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype,
output_padding):
"""Transposed 1D convolution ncw forward operator.
Parameters
Expand All @@ -43,6 +44,8 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
string: ['VALID', 'SAME']
out_dtype: str
The output type. This is used in mixed precision
output_padding : ints
Used to disambiguate the output shape.
Returns
-------
Expand All @@ -51,13 +54,17 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype):
"""
if isinstance(stride, (tuple, list)):
stride = stride[0]
if isinstance(output_padding, (tuple, list)):
output_padding = output_padding[0]
assert output_padding < stride
cfg.stride = stride
cfg.output_padding = output_padding
batch, inp_channels, inp_width = get_const_tuple(data.shape)
_, out_channels, kernel_size = get_const_tuple(kernel.shape)
pad_left, pad_right = nn.get_pad_tuple1d(padding, kernel_size)
out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right
out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding
pad_left = kernel_size - 1 - pad_left
pad_right = kernel_size - 1 - pad_right
pad_right = kernel_size - 1 - pad_right + output_padding
dilated_width = stride * (inp_width - 1) + 1
data = te.compute(
(batch, inp_channels, pad_left + dilated_width + pad_right),
Expand Down
12 changes: 9 additions & 3 deletions topi/python/topi/cuda/conv2d_transpose_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from ..util import get_const_tuple, traverse_inline



@autotvm.register_topi_compute("conv2d_transpose_nchw.cuda")
def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype,
output_padding):
"""Transposed 2D convolution nchw forward operator.
Parameters
Expand All @@ -43,6 +45,8 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
Padding size, or ['VALID', 'SAME']
out_dtype: str
The output type. This is used in mixed precision
output_padding : tuple of two ints
Used to disambiguate output shape.
Returns
-------
Expand All @@ -52,18 +56,20 @@ def conv2d_transpose_nchw(cfg, data, kernel, stride, padding, out_dtype):
batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
_, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
stride_height, stride_width = stride
outpad_height, outpad_width = output_padding
assert outpad_height < stride_height and outpad_width < stride_width
cfg.stride = stride
pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
padding, (kernel_height, kernel_width))

out_width = (inp_width - 1) * stride_width + \
kernel_width - pad_left - pad_right
kernel_width - pad_left - pad_right + outpad_width
pad_left = kernel_width - 1 - pad_left
pad_right = kernel_width - 1 - pad_right
dilated_width = stride_width * (inp_width - 1) + 1

out_height = (inp_height - 1) * stride_height + \
kernel_height - pad_top - pad_bottom
kernel_height - pad_top - pad_bottom + outpad_height
pad_top = kernel_height - 1 - pad_top
pad_bottom = kernel_height - 1 - pad_bottom
dilated_height = stride_height * (inp_height - 1) + 1
Expand Down
13 changes: 11 additions & 2 deletions topi/python/topi/nn/conv1d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from .util import get_pad_tuple1d


def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype,
output_padding):
"""Transposed 1D convolution ncw forward operator.
Parameters
Expand All @@ -43,22 +44,30 @@ def conv1d_transpose_ncw(data, kernel, stride, padding, out_dtype):
out_dtype : str
The output data type. This is used for mixed precision.
output_padding : ints
Used to recover the actual output shape in case there are more
than one possible shape. Must be smaller than stride.
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, out_channel, out_width]
"""

# dilate and pad
if isinstance(stride, (tuple, list)):
stride = stride[0]
if isinstance(output_padding, (tuple, list)):
output_padding = output_padding[0]
batch, channels_in, data_width = data.shape
_, channels_out, kernel_width = kernel.shape
assert output_padding < stride
channels_out = simplify(channels_out)
data = dilate(data, [1, 1, stride], name='data_dilate')
pad_left, pad_right = get_pad_tuple1d(padding, (kernel_width,))
pad_left = kernel_width - 1 - pad_left
pad_right = kernel_width - 1 - pad_right
pad_right = kernel_width - 1 - pad_right + output_padding
data = pad(data, [0, 0, pad_left], [0, 0, pad_right], name='data_pad')

# transpose kernel, switch kernel layout to IOW
Expand Down
Loading

0 comments on commit bc22fb9

Please sign in to comment.