Skip to content

Commit

Permalink
Fix new conv2d_transpose test
Browse files Browse the repository at this point in the history
  • Loading branch information
abergeron committed Nov 27, 2019
1 parent cab70d6 commit 18979ed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 2 additions & 4 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,8 @@ def test_conv2d_transpose_nhwc_run():
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
# use true kshape layout here - HWOI
c_np = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI', 2, 1)
d_np = np.zeros(shape=oshape_nhwc)
d_np[:,0:c_np.shape[1],0:c_np.shape[2],:] = c_np
ref_res = d_np
ref_res = topi.testing.conv2d_transpose_nhwc_python(data, kernel, 'HWOI',
2, 1, output_padding=(2, 2))

for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
Expand Down
6 changes: 4 additions & 2 deletions topi/python/topi/testing/conv2d_transpose_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding=(0,
return b_np


def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding):
def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding,
output_padding=(0, 0)):
"""Transposed convolution operator in NHWC layout.
Parameters
Expand Down Expand Up @@ -118,6 +119,7 @@ def conv2d_transpose_nhwc_python(a_nhwc, weight, weight_format, stride, padding)
else:
raise ValueError('Valid weight_formats are HWIO, HWOI, OIHW or IOHW')

res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding)
res_nchw = conv2d_transpose_nchw_python(a_nchw, w_iohw, stride, padding,
output_padding=output_padding)
res_nhwc = np.transpose(res_nchw, (0, 2, 3, 1))
return res_nhwc

0 comments on commit 18979ed

Please sign in to comment.