Skip to content

Commit

Permalink
add dilation in x86 NCHWc depthwise conv support (#4962)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjliu1998 committed Aug 13, 2020
1 parent abfa79d commit 2aaea26
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 12 additions & 6 deletions python/tvm/topi/x86/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,18 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,

strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HSTR, WSTR = strides
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (filter_height, filter_width))

dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
assert (dh, dw) == (1, 1), "Does not support dilation"

out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1
out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1
dilated_kernel_h = (filter_height - 1) * dh + 1
dilated_kernel_w = (filter_width - 1) * dw + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
HPAD = pad_top + pad_down
WPAD = pad_left + pad_right

out_height = (in_height + HPAD - dilated_kernel_h) // HSTR + 1
out_width = (in_width + WPAD - dilated_kernel_w) // WSTR + 1

cfg.define_split("tile_ic", in_channel, num_outputs=2)
cfg.define_split("tile_oc", out_channel, num_outputs=2)
Expand All @@ -140,7 +145,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
te.placeholder((batch, in_channel, in_height, in_width), dtype=data.dtype),
te.placeholder((out_channel, channel_multiplier, filter_height, filter_width),
dtype=kernel.dtype),
strides, padding, out_dtype)
strides, (pad_top, pad_down), out_dtype)
if cfg.is_fallback:
_fallback_schedule(cfg, wkl)

Expand Down Expand Up @@ -172,6 +177,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
else:
data_pad = data


# depthconv stage
idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
Expand All @@ -184,7 +190,7 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation,
(data_pad[
b,
idxdiv(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block),
oh*HSTR+kh, ow*WSTR+kw,
oh*HSTR+kh*dh, ow*WSTR+kw*dw,
idxmod(idxdiv(oco * out_channel_block + oci, channel_multiplier), in_channel_block)]
.astype(out_dtype) *
kernel[oco, 0, kh, kw, 0, oci].astype(out_dtype)),
Expand Down
8 changes: 5 additions & 3 deletions tests/python/topi/python/test_topi_depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def depthwise_conv2d_with_workload_NCHWc(batch, in_channel, in_height, channel_m
filter_width = filter_height
stride_h = stride_w = stride

assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
#assert dilation == 1, "depthwise_conv2d_NCHWc currently does not support dilation."
assert channel_multiplier == 1, "depthwise_conv2d_NCHWc currently does not support channel multiplier > 1."
pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
padding_args = (pad_h, pad_w)
Expand Down Expand Up @@ -306,7 +306,7 @@ def check_device(device):
# declare
DepthwiseConv2d = topi.x86.depthwise_conv2d_NCHWc(Input, Filter,
(stride_h, stride_w),
padding_args,
padding,
(dilation, dilation),
in_layout,
out_layout, dtype)
Expand All @@ -329,8 +329,9 @@ def get_ref_data():
input_np = np.random.uniform(size=input_shape).astype(dtype)
filter_np = np.random.uniform(size=filter_shape).astype(dtype)
# correctness with scipy
dw_np = tvm.topi.testing.dilate_python(filter_np, (1, 1, dilation, dilation)).astype(dtype)
depthwise_conv2d_scipy = tvm.topi.testing.depthwise_conv2d_python_nchw(
input_np, filter_np, stride, padding)
input_np, dw_np, stride, padding)
relu_scipy = np.maximum(depthwise_conv2d_scipy, 0)
return (_transform_data(input_np, ic_block),
_transform_kernel(filter_np, oc_block),
Expand Down Expand Up @@ -389,6 +390,7 @@ def test_depthwise_conv2d():
# depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)

# NCHW[x]c
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME", dilation=2)
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "VALID")

Expand Down

0 comments on commit 2aaea26

Please sign in to comment.