[TOPI] Fix x86 conv2d template when tuning with unpacked layout (#5938)
* fix x86 conv2d and conv2d_transpose template

* address comments
merrymercy authored Jul 2, 2020
1 parent aa035f4 commit 512ed39
Showing 3 changed files with 10 additions and 8 deletions.
topi/python/topi/x86/conv2d_avx_1x1.py (2 changes: 1 addition, 1 deletion)
@@ -73,6 +73,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         s[data_vec].parallel(parallel_axis)
         data_vec = data_vec.op.input_tensors[0]

+    oc_bn = cfg["tile_oc"].size[-1]
     if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
             kernel_vec.name == 'kernel_vec':
         # data and kernel are not pre-computed, schedule layout transform here.
@@ -84,7 +85,6 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):

         oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
         s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
-        oc_bn = cfg["tile_oc"].size[-1]
         if oc_bn > 1:
             s[kernel_vec].vectorize(oc_block)
         parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
topi/python/topi/x86/conv2d_avx_common.py (2 changes: 1 addition, 1 deletion)
@@ -95,6 +95,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
         s[data_vec].parallel(parallel_axis)
         data_vec = data_vec.op.input_tensors[0]

+    oc_bn = cfg["tile_oc"].size[-1]
     if isinstance(kernel_vec.op, tvm.te.ComputeOp) and \
             kernel_vec.name == 'kernel_vec':
         # data and kernel are not pre-computed, schedule layout transform here.
@@ -106,7 +107,6 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):

         oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
         s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
-        oc_bn = cfg["tile_oc"].size[-1]
         if oc_bn > 1:
             s[kernel_vec].vectorize(oc_block)
         parallel_axis = s[kernel_vec].fuse(oc_chunk, oh)
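Both AVX templates above apply the same one-line move: the `oc_bn` lookup is hoisted out of the branch that only runs when the kernel layout transform is built inside the template. The sketch below illustrates that pattern in plain Python. It assumes, based on the commit title and the collapsed parts of the diff, that `oc_bn` is also read by scheduling code later in the function that runs regardless of this branch; the function name, the flag argument, and the plain-dict config are illustrative stand-ins, not TVM APIs.

# Sketch only: illustrates the hoist in conv2d_avx_1x1.py / conv2d_avx_common.py.
# `cfg` is a plain dict standing in for the AutoTVM config entity.
def schedule_kernel_sketch(cfg, kernel_built_in_template):
    # New location: the block size is defined no matter which path runs below.
    oc_bn = cfg["tile_oc"]

    if kernel_built_in_template:
        # Old location: `oc_bn` was assigned only on this path, so any later
        # use outside the branch would hit an unbound name whenever this
        # branch is skipped.
        pass  # ... schedule the kernel layout transform here ...

    # Later uses are now safe on both paths.
    return "vectorize" if oc_bn > 1 else "scalar"

print(schedule_kernel_sketch({"tile_oc": 16}, kernel_built_in_template=False))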
topi/python/topi/x86/conv2d_transpose.py (14 changes: 8 additions, 6 deletions)
@@ -40,14 +40,16 @@ def _callback(op):
             conv_out = op.input_tensors[0]
             # retrieve data
             data_vec = conv_out.op.input_tensors[0]
-            data_pad = data_vec.op.input_tensors[0]
-            data_dilate = data_pad.op.input_tensors[0]
-            s[data_dilate].compute_inline()
-            s[data_pad].compute_inline()
+            if isinstance(data_vec, te.ComputeOp):
+                data_pad = data_vec.op.input_tensors[0]
+                data_dilate = data_pad.op.input_tensors[0]
+                s[data_dilate].compute_inline()
+                s[data_pad].compute_inline()
             # retrieve kernel
             kernel_vec = conv_out.op.input_tensors[1]
-            kernel_transform = kernel_vec.op.input_tensors[0]
-            s[kernel_transform].compute_inline()
+            if isinstance(kernel_vec, te.ComputeOp):
+                kernel_transform = kernel_vec.op.input_tensors[0]
+                s[kernel_transform].compute_inline()

     traverse_inline(s, outs[0].op, _callback)
     return s
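The conv2d_transpose change wraps the compute_inline calls in isinstance checks so the traversal callback no longer assumes that the data and kernel stages feeding `conv_out` were produced inside the template. A rough sketch of that guard pattern follows; it uses small stand-in classes rather than `tvm.te` objects, checks the producing op's type, and its names (`inline_data_stages`, the `schedule` list) are purely illustrative.

# Sketch only: mirrors the guarded traversal added in conv2d_transpose.py,
# with stand-in classes instead of tvm.te tensors/ops.
class PlaceholderOp:
    input_tensors = []

class ComputeOp:
    def __init__(self, *inputs):
        self.input_tensors = list(inputs)

class Tensor:
    def __init__(self, op):
        self.op = op

def inline_data_stages(schedule, data_vec):
    # Only walk into the producer chain when the stage was actually built by
    # the template; otherwise there is nothing to inline, and indexing
    # input_tensors[0] on the placeholder would fail.
    if isinstance(data_vec.op, ComputeOp):
        data_pad = data_vec.op.input_tensors[0]
        schedule.append(("compute_inline", data_pad))

placeholder_input = Tensor(PlaceholderOp())                   # external / pre-computed stage
template_input = Tensor(ComputeOp(Tensor(PlaceholderOp())))   # stage built inside the template

s = []
inline_data_stages(s, placeholder_input)  # guarded: no-op
inline_data_stages(s, template_input)     # inlines the padding stage
print(len(s))  # -> 1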
