Commit
Change the meaning of conv3d_transpose output_padding to match conv{1,2}d_transpose (apache#6065)

* Change the meaning of output_padding to correspond to conv{1,2}d_transpose

* Fix long lines

* Fix the relay test

* Add missing doc.

* Fix size ordering problem
abergeron authored and wjliu1998 committed Aug 13, 2020
1 parent 1b88105 commit e42019c
Showing 7 changed files with 58 additions and 38 deletions.
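For context on the semantics being adopted: output_padding no longer appends zeros to the finished tensor (as the old conv3d_transpose compute did) but enters the output-shape arithmetic directly, as in conv{1,2}d_transpose. A strided forward convolution maps several input sizes to the same output size, so the transposed direction needs an extra parameter to say which shape to produce. A minimal plain-Python sketch of that ambiguity (illustrative names, not TVM API; symmetric padding assumed):

def conv_out(i, k, s, p):
    # standard convolution output size
    return (i + 2 * p - k) // s + 1

def conv_transpose_out(i, k, s, p, opad):
    # transposed-convolution output size, matching the formulas in the hunks below
    return (i - 1) * s + k - 2 * p + opad

k, s, p = 3, 2, 1
for i in (7, 8):
    print(i, "->", conv_out(i, k, s, p))         # both print 4
for opad in (0, 1):
    print(conv_transpose_out(4, k, s, p, opad))  # 7, then 8

Since s consecutive input sizes collapse to one forward output size, output_padding only needs the range [0, s - 1]; that is exactly what the new opad < stride asserts in the hunks below enforce.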
7 changes: 2 additions & 5 deletions python/tvm/relay/op/strategy/generic.py
@@ -364,15 +364,12 @@ def compute_conv3d_transpose(attrs, inputs, out_dtype):
         """Compute definition of conv3d_transpose"""
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
+        output_padding = get_const_tuple(attrs.output_padding)
         out_dtype = attrs.out_dtype
         out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                      else out_dtype)
         out = topi_compute(
-            inputs[0], inputs[1], strides, padding, out_dtype)
-        output_padding = get_const_tuple(attrs.output_padding)
-        out = topi.nn.pad(out,
-                          [0, 0, 0, 0, 0],
-                          [0, 0, output_padding[0], output_padding[1], output_padding[2]])
+            inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
         return [out]
     return compute_conv3d_transpose
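This strategy hunk is the heart of the commit: output_padding used to be applied after the fact, as a topi.nn.pad of zeros on the back, bottom, and right of the computed tensor; it is now forwarded into the backend compute, which widens its internal padding instead. The resulting shape is identical, but the meaning changes: with nonzero convolution padding the extra border elements are now produced by the convolution itself rather than being literal zeros, matching what conv{1,2}d_transpose already does.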
14 changes: 10 additions & 4 deletions python/tvm/topi/cuda/conv3d_transpose_ncdhw.py
@@ -26,7 +26,8 @@


 @autotvm.register_topi_compute("conv3d_transpose_ncdhw.cuda")
-def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
+def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype,
+                           output_padding):
     """Transposed 3D convolution ncdhw forward operator.

     Parameters
@@ -43,6 +44,8 @@ def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
         Padding size, or ['VALID', 'SAME']
     out_dtype: str
         The output type. This is used in mixed precision
+    output_padding : tuple of three ints
+        Used to disambiguate output shape
     Returns
     -------
@@ -52,24 +55,27 @@ def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype):
     batch, inp_channels, inp_depth, inp_height, inp_width = get_const_tuple(data.shape)
     _, out_channels, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape)
     stride_depth, stride_height, stride_width = stride
+    outpad_depth, outpad_height, outpad_width = output_padding
+    assert (outpad_height < stride_height and outpad_width < stride_width and
+            outpad_depth < stride_depth)
     cfg.stride = stride
     pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = nn.get_pad_tuple3d(
         padding, (kernel_depth, kernel_height, kernel_width))

     out_depth = (inp_depth - 1) * stride_depth + \
-        kernel_depth - pad_front - pad_back
+        kernel_depth - pad_front - pad_back + outpad_depth
     pad_front = kernel_depth - 1 - pad_front
     pad_back = kernel_depth - 1 - pad_back
     dilated_depth = stride_depth * (inp_depth - 1) + 1

     out_width = (inp_width - 1) * stride_width + \
-        kernel_width - pad_left - pad_right
+        kernel_width - pad_left - pad_right + outpad_width
     pad_left = kernel_width - 1 - pad_left
     pad_right = kernel_width - 1 - pad_right
     dilated_width = stride_width * (inp_width - 1) + 1

     out_height = (inp_height - 1) * stride_height + \
-        kernel_height - pad_top - pad_bottom
+        kernel_height - pad_top - pad_bottom + outpad_height
     pad_top = kernel_height - 1 - pad_top
     pad_bottom = kernel_height - 1 - pad_bottom
     dilated_height = stride_height * (inp_height - 1) + 1
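All three spatial axes above follow one shape rule; restated as a standalone helper (an illustrative sketch, not TVM code):

def out_size(in_size, kernel, stride, pad_lo, pad_hi, opad):
    # output extent per axis, with the constraint from the new assert above
    assert opad < stride
    return (in_size - 1) * stride + kernel - pad_lo - pad_hi + opad

print(out_size(24, 3, 3, 0, 0, 2))  # 74, as in one of the new test cases below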
22 changes: 14 additions & 8 deletions python/tvm/topi/nn/conv3d_transpose.py
@@ -25,7 +25,7 @@
 from ..util import simplify


-def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype, output_padding):
     """Transposed 3D convolution ncdhw forward operator.

     Parameters
@@ -45,31 +45,37 @@ def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype):
     out_dtype : str
         The output data type. This is used for mixed precision.
+    output_padding : tuple of ints
+        Used to get the right output shape for gradients
     Returns
     -------
     Output : tvm.te.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    return declaration_conv3d_transpose_impl(Input, Filter, strides, padding, out_dtype)
+    return declaration_conv3d_transpose_impl(Input, Filter, strides, padding,
+                                             out_dtype, output_padding)


-def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding):
     """Preprocess data and kernel to make the compute pattern
     of conv3d_transpose the same as conv3d"""
     batch, in_c, in_d, in_h, in_w = data.shape
     _, out_c, filter_d, filter_h, filter_w = kernel.shape
     stride_d, stride_h, stride_w = strides
+    opad_d, opad_h, opad_w = output_padding
+    assert opad_d < stride_d and opad_h < stride_h and opad_w < stride_w
     # dilate data
     data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name='data_dilate')
     # pad data
     fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d(
         padding, (filter_d, filter_h, filter_w))
     bpad_front = filter_d - 1 - fpad_front
-    bpad_back = filter_d - 1 - fpad_back
+    bpad_back = filter_d - 1 - fpad_back + opad_d
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w
     data_pad = pad(data_dilate, \
                    [0, 0, bpad_front, bpad_top, bpad_left], \
                    [0, 0, bpad_back, bpad_bottom, bpad_right], \
@@ -82,10 +88,10 @@ def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
     return data_pad, kernel_transform


-def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype):
+def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype, output_padding):
     """Implementation of conv3d transpose"""
     data_pad, kernel_transform = \
-        conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
+        conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype, output_padding)
     batch, in_c, in_d, in_h, in_w = data_pad.shape
     out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape
     stride_d, stride_h, stride_w = strides
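The preprocess above is the dilate-then-pad trick: a transposed convolution equals an ordinary stride-1 convolution with the flipped kernel over the dilated, padded input, and output_padding simply lands on the back, bottom, and right pads. A 1-D NumPy sketch of the same idea (illustrative only, simplified from the 3-D code):

import numpy as np

def conv_transpose_1d(x, w, stride, pad, opad):
    k = len(w)
    # dilate: insert stride - 1 zeros between input elements
    d = np.zeros((len(x) - 1) * stride + 1)
    d[::stride] = x
    # pad: k - 1 - pad in front, k - 1 - pad + opad in back
    d = np.pad(d, (k - 1 - pad, k - 1 - pad + opad))
    wf = w[::-1]  # flip the kernel, then correlate at stride 1
    return np.array([np.dot(d[i:i + k], wf) for i in range(len(d) - k + 1)])

x = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 1.0, 1.0])
print(conv_transpose_1d(x, w, stride=2, pad=0, opad=1))
# length (3 - 1) * 2 + 3 + 1 = 8, matching the shape formula in the computes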
18 changes: 13 additions & 5 deletions python/tvm/topi/testing/conv3d_transpose_ncdhw_python.py
@@ -21,7 +21,7 @@
 from tvm.topi.nn.util import get_pad_tuple3d


-def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
+def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 3d convolution operator in NCDHW layout.

     Parameters
@@ -38,6 +38,9 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
     padding : int or str
         Padding size
+    output_padding : int or list/tuple of three ints
+        Used to disambiguate output shape.
     Returns
     -------
     b_np : np.ndarray
@@ -49,6 +52,11 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
         stride_d = stride_h = stride_w = stride
     else:
         stride_d, stride_h, stride_w = stride
+    if isinstance(output_padding, int):
+        opad_d = opad_h = opad_w = output_padding
+    else:
+        opad_d, opad_h, opad_w = output_padding
+    assert opad_d < stride_d and opad_h < stride_h and opad_w < stride_w

     # dilate stage
     dilated_a_np = tvm.topi.testing.dilate_python(a_np, [1, 1, stride_d, stride_h, stride_w])
@@ -58,19 +66,19 @@ def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding):
         padding, (filter_d, filter_h, filter_w))

     bpad_front = filter_d - 1 - fpad_front
-    bpad_back = filter_d - 1 - fpad_back
+    bpad_back = filter_d - 1 - fpad_back + opad_d
     bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
+    bpad_bottom = filter_h - 1 - fpad_bottom + opad_h
     bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
+    bpad_right = filter_w - 1 - fpad_right + opad_w

     padded_a_np = np.zeros((batch,
                             in_c,
                             dilated_a_np.shape[2]+bpad_front+bpad_back,
                             dilated_a_np.shape[3]+bpad_top+bpad_bottom,
                             dilated_a_np.shape[4]+bpad_left+bpad_right))

-    padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_back,
+    padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_front,
                 bpad_top:dilated_a_np.shape[3]+bpad_top,
                 bpad_left:dilated_a_np.shape[4]+bpad_left] = dilated_a_np
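Note that the final hunk also fixes the slice bound when copying the dilated data into the padded buffer: the upper offset must be bpad_front (the same offset the slice starts at), not bpad_back. The two were interchangeable only while front and back padding were always equal, which a nonzero opad_d breaks; this is the "Fix size ordering problem" item from the commit message. A usage sketch of the updated reference function (assuming a TVM checkout of this vintage, where tvm.topi.testing is importable):

import numpy as np
import tvm.topi.testing

a = np.random.uniform(size=(1, 2, 4, 4, 4)).astype("float32")  # NCDHW input
w = np.random.uniform(size=(2, 3, 3, 3, 3)).astype("float32")  # in_c, out_c, kd, kh, kw
b = tvm.topi.testing.conv3d_transpose_ncdhw_python(a, w, (2, 2, 2), 0, (1, 0, 1))
print(b.shape)  # per axis (4 - 1) * 2 + 3 + opad -> (1, 3, 10, 9, 10)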
5 changes: 3 additions & 2 deletions python/tvm/topi/x86/conv3d_transpose.py
@@ -23,9 +23,10 @@
 from .. import nn
 from .conv3d import conv3d_ncdhw, schedule_conv3d_ncdhw

-def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype):
+def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype, output_padding):
     data_pad, kernel_transform = \
-        nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype)
+        nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding,
+                                             out_dtype, output_padding)

     # reuse conv3d_ncdhw implementation
     return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1),
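The x86 backend stays a thin wrapper: once the shared preprocess has dilated and padded the data, the transposed convolution is just a unit-stride conv3d_ncdhw over the transformed kernel, so output_padding only needs to be threaded through to the preprocess and the reused convolution itself is untouched.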
3 changes: 1 addition & 2 deletions tests/python/relay/test_op_level2.py
@@ -663,8 +663,7 @@ def test_conv3d_transpose_ncdhw_run():

     data = np.random.uniform(size=dshape).astype(dtype)
     kernel = np.random.uniform(size=kshape).astype(dtype)
-
-    ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1)
+    ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0)

     for target, ctx in ctx_list():
         intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
27 changes: 15 additions & 12 deletions tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
@@ -32,7 +32,7 @@
     "gpu": (topi.cuda.conv3d_transpose_ncdhw, topi.cuda.schedule_conv3d_transpose_ncdhw),
 }

-def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
+def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, output_padding):
     in_depth, in_height, in_width = in_size
     kernel_depth, kernel_height, kernel_width = kernel
     stride_depth, stride_height, stride_width = stride
@@ -49,7 +49,7 @@ def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
-        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding)
+        b_np = tvm.topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding, output_padding)
         c_np = np.maximum(b_np, 0)
         return a_np, w_np, b_np, c_np

@@ -66,7 +66,7 @@ def check_device(device):
         B = fcompute(A, W,
                      [stride_depth, stride_height, stride_width],
                      [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right],
-                     A.dtype)
+                     A.dtype, output_padding)
         C = topi.nn.relu(B)
         s1 = fschedule([B])
         s2 = fschedule([C])
@@ -86,15 +86,18 @@ def check_device(device):


 def test_conv3d_transpose_ncdhw():
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1))
-    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0))
-    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (2, 2, 2))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0), (1, 0, 2))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0))
+    verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1))

 if __name__ == "__main__":
     test_conv3d_transpose_ncdhw()
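The updated list keeps every existing configuration with output_padding (0, 0, 0) and adds three cases that exercise the new parameter: (2, 2, 2) and the uneven (1, 0, 2) under stride (3, 3, 3), plus (1, 1, 1) under stride (2, 2, 2); all satisfy the new opad < stride assertion.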
