Skip to content

Commit

Permalink
Fix the VTA declaration of conv2d_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
abergeron committed Nov 14, 2019
1 parent 5c508a0 commit 6ee61d5
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions vta/python/vta/top/vta_conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,26 @@
from ..environment import get_env

@autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct')
def _declatation_conv2d_transpose(cfg,
def _declaration_conv2d_transpose(cfg,
data,
kernel,
strides,
padding,
out_dtype):
out_dtype,
output_padding=(0, 0)):
ishape = get_const_tuple(data.shape)
kshape = get_const_tuple(kernel.shape)
b, c_i, i_h, i_w, t_b, t_ci = ishape
c_o, _, k_h, k_w, t_co, t_ci = kshape
stride_h, stride_w = strides
opad_h, opad_w = output_padding

# derive padding parameters
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (k_h, k_w))
bpad_top = k_h - 1 - fpad_top
bpad_bottom = k_h - 1 - fpad_bottom
bpad_bottom = k_h - 1 - fpad_bottom + opad_h
bpad_left = k_w - 1 - fpad_left
bpad_right = k_w - 1 - fpad_right
bpad_right = k_w - 1 - fpad_right + opad_w

# padding stage
dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
Expand All @@ -53,8 +55,8 @@ def _declatation_conv2d_transpose(cfg,
[0, 0, bpad_bottom, bpad_right, 0, 0])

# convolution transpose stage
out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h
out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w
out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h
out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w
oshape = (b, c_o, out_h, out_w, t_b, t_co)
d_c = tvm.reduce_axis((0, c_i), name='d_c')
d_h = tvm.reduce_axis((0, k_h), name='d_h')
Expand Down

0 comments on commit 6ee61d5

Please sign in to comment.