From 34da389fedc749ce64baeb74a2b582daf6fef3e6 Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Mon, 6 Jan 2020 11:53:47 +0800 Subject: [PATCH] [CONV] Asymmetric padding (#4511) * [CONV] Asymmetric padding * fix lint error * update for legalize, rocm and cudnn * add more test cases * change more symmetric padding * change conv2d winograd tests according to original cases * remove 'alter_op_layout.h' header in bitserial.cc --- include/tvm/relay/attrs/nn.h | 26 ++++-- python/tvm/contrib/pickle_memoize.py | 2 +- python/tvm/relay/frontend/tensorflow.py | 19 +--- src/relay/op/nn/bitserial.cc | 7 +- src/relay/op/nn/convolution.cc | 27 +++--- src/relay/op/nn/convolution.h | 6 +- tests/python/contrib/test_nnpack.py | 19 ++-- topi/python/topi/arm_cpu/conv2d.py | 30 ++++--- topi/python/topi/bifrost/conv2d.py | 8 +- topi/python/topi/cuda/conv2d.py | 17 ++-- topi/python/topi/cuda/conv2d_winograd.py | 8 +- topi/python/topi/intel_graphics/conv2d.py | 6 +- topi/python/topi/mali/conv2d.py | 8 +- topi/python/topi/nn/conv2d.py | 88 +++++++++++++------ topi/python/topi/rocm/conv2d.py | 10 ++- .../python/topi/testing/conv2d_hwcn_python.py | 27 +++--- .../python/topi/testing/conv2d_nchw_python.py | 34 +++---- .../python/topi/testing/conv2d_nhwc_python.py | 28 +++--- .../testing/deformable_conv2d_nchw_python.py | 20 ++--- topi/python/topi/x86/conv2d.py | 32 ++++--- topi/python/topi/x86/conv2d_alter_op.py | 7 +- topi/python/topi/x86/conv2d_int8.py | 7 +- topi/python/topi/x86/depthwise_conv2d.py | 6 +- topi/tests/python/test_topi_conv2d_NCHWc.py | 31 +++++-- topi/tests/python/test_topi_conv2d_int8.py | 30 +++++-- topi/tests/python/test_topi_conv2d_nchw.py | 37 ++++++-- topi/tests/python/test_topi_conv2d_nhwc.py | 5 ++ .../tests/python/test_topi_conv2d_winograd.py | 25 +++++- topi/tests/python/test_topi_conv3d_ncdhw.py | 2 - 29 files changed, 338 insertions(+), 234 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index a2cad94320d7..f50d8fcfdb97 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -67,7 +67,10 @@ struct Conv2DAttrs : public tvm::AttrsNode { .describe("Specifies the strides of the convolution."); TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding supports both symmetric and asymmetric modes:" + "one int : same padding used on all sides" + "two ints : bottom and right use the same padding as top and left" + "four ints : padding widths in the order (top, left, bottom, right)"); TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); TVM_ATTR_FIELD(groups).set_default(1) @@ -138,7 +141,10 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { .describe("Specifies the strides of the convolution."); TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding supports both symmetric and asymmetric modes:" + "one int : same padding used on all sides" + "two ints : bottom and right use the same padding as top and left" + "four ints : padding widths in the order (top, left, bottom, right)"); TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); TVM_ATTR_FIELD(groups).set_default(1) @@ -288,10 +294,17 @@ struct
Conv2DTransposeAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) .describe("The strides of the convolution."); TVM_ATTR_FIELD(output_padding).set_default(Array({0, 0})) - .describe("Zero-padding added to one side of the output."); + .describe("Zero-padding added to one side of the output." + "Padding supports both symmetric and asymmetric modes:" + "one int : same padding used on all sides" + "two ints : bottom and right use the same padding as top and left" + "four ints : padding widths in the order (top, left, bottom, right)"); TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding supports both symmetric and asymmetric modes:" + "one int : same padding used on all sides" + "two ints : bottom and right use the same padding as top and left" + "four ints : padding widths in the order (top, left, bottom, right)"); TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); TVM_ATTR_FIELD(groups).set_default(1) @@ -817,7 +830,10 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode { .describe("Specifies the strides of the convolution."); TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); + "Padding supports both symmetric and asymmetric modes:" + "one int : same padding used on all sides" + "two ints : bottom and right use the same padding as top and left" + "four ints : padding widths in the order (top, left, bottom, right)"); TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); TVM_ATTR_FIELD(deformable_groups).set_default(1) diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index b5abf9b9b7a6..5c16419f8e14 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -84,7 +84,7 @@ def memoize(key, save_at_exit=False): """ def _register(f): """Registration function""" - allow_types = (string_types, int, float) + allow_types = (string_types, int, float, tuple) fkey = key + "." + f.__name__ + ".pkl" if fkey not in Cache.cache_by_key: Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8a6e5b778283..dceadbf6dbe1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -372,24 +372,7 @@ def _impl(inputs, attr, params): pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - if opname != 'conv_transpose': - if attr['data_format'] == 'NHWC': - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) - else: - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) - - attr['padding'] = [0, 0] - else: - attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] - + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ 'valid.'
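Every hunk in this patch funnels the accepted padding forms (one int, two ints, four ints, or the 'VALID'/'SAME' strings in the Python helpers) into an explicit (top, left, bottom, right) tuple. A minimal sketch of that normalization rule, assuming it matches the behavior of topi.nn.util.get_pad_tuple as exercised by the hunks below (normalize_padding is an illustrative stand-in, not the library helper):

    def normalize_padding(padding, kernel):
        """Return (pad_top, pad_left, pad_bottom, pad_right) for any accepted form."""
        kh, kw = kernel  # kernel size as plain ints
        if isinstance(padding, (tuple, list)) and len(padding) == 4:
            return tuple(padding)  # already (top, left, bottom, right)
        if isinstance(padding, (tuple, list)) and len(padding) == 2:
            pad_h, pad_w = padding[0] * 2, padding[1] * 2  # symmetric per axis
        elif isinstance(padding, int):
            pad_h = pad_w = padding * 2  # same padding on all sides
        elif padding == "VALID":
            pad_h = pad_w = 0
        elif padding == "SAME":
            pad_h, pad_w = kh - 1, kw - 1  # enough to keep the spatial size at stride 1
        else:
            raise ValueError("Unknown padding option %s" % str(padding))
        pad_top, pad_left = (pad_h + 1) // 2, (pad_w + 1) // 2
        return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left

    assert normalize_padding((1, 2), (3, 3)) == (1, 2, 1, 2)
    assert normalize_padding("SAME", (3, 3)) == (1, 1, 1, 1)
    assert normalize_padding((0, 0, 1, 1), (7, 7)) == (0, 0, 1, 1)

The C++ side only needs the per-axis sums for shape inference, which is what the new GetPaddingHeightWidth helper used in the hunks below provides (pad_h = top + bottom, pad_w = left + right).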
diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index 09c060d02c25..a8f2e8618a75 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -26,6 +26,7 @@ #include #include +#include "../op_common.h" #include "../../pass/infer_layout_util.h" namespace tvm { @@ -134,10 +135,12 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr CHECK(param->channels.defined()); CHECK(param->kernel_size.defined()); Array oshape({dshape_nchw[0], param->channels, 0, 0}); + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); oshape.Set( - 2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1); + 2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); oshape.Set( - 3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1); + 3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 85d8c0fdb02e..2a7d6b3f4d1f 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -166,7 +166,6 @@ with the layer input to produce a tensor of outputs. .add_type_rel("Conv3D", Conv3DRel) .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); @@ -250,18 +249,8 @@ bool Conv2DTransposeRel(const Array& types, } // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); - auto pad_h = param->padding[0]; - auto pad_w = param->padding[1]; - if (param->padding.size() == 2) { - pad_h *= 2; - pad_w *= 2; - } else if (param->padding.size() == 4) { - pad_h += param->padding[2]; - pad_w += param->padding[3]; - } else { - CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got " - << param->padding.size(); - } + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + param->output_padding[0])); oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - @@ -557,14 +546,16 @@ bool Conv2DWinogradRel(const Array& types, // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 + oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 + oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); } else { oshape.Set(3, dshape_nchw[3]); @@ -1015,9 +1006,11 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& // dilation Array oshape({data->shape[0], channels, 0, 0}); - oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y, + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); - oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x, + oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); DataType out_dtype = param->out_dtype; diff --git 
a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 0f4bb05883a0..913c5b01852a 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -117,15 +117,17 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y, + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x, + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py index 8bf99dff7daf..2ded24d646c4 100644 --- a/tests/python/contrib/test_nnpack.py +++ b/tests/python/contrib/test_nnpack.py @@ -17,6 +17,7 @@ import tvm import numpy as np import scipy.signal +from topi.nn.util import get_pad_tuple from tvm.contrib import nnpack import pytest @@ -59,17 +60,9 @@ def np_conv(na, nw, padding, stride=1): else: stride_h, stride_w = stride - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - else: - pad_h, pad_w = padding - pad_h *= 2 - pad_w *= 2 - - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_bottom = pad_h - pad_top - pad_left = int(np.ceil(float(pad_w) / 2)) - pad_right = pad_w - pad_left + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 @@ -78,9 +71,9 @@ def np_conv(na, nw, padding, stride=1): for n in range(batch): for f in range(out_channel): for c in range(in_channel): - if pad_h > 0: + if pad_h > 0 or pad_w > 0: apad = np.zeros((in_height + pad_h, in_width + pad_w)) - apad[pad_top:-pad_bottom, pad_left:-pad_right] = na[n, c] + apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = na[n, c] else: apad = na[n, c] out = scipy.signal.convolve2d( diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 2adb71848400..47179c9969eb 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -197,11 +197,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt CO *= VC KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1 HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) assert layout == 'NCHW' assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") idxd = tvm.indexdiv idxm = tvm.indexmod @@ -214,8 +214,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt K = CO C = CI - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + H = (IH + pt + pb - 3) // HSTR + 1 + W = (IW + pl + pr - 3) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW @@ -387,12 +387,13 @@ def conv2d_arm_cpu_winograd_nnpack( assert len(kernel.shape) == 4 CO, _, KH, KW = 
get_const_tuple(kernel.shape) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1 - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\ + and WSTR == 1 + H = (IH + pt + pb - 3) // HSTR + 1 + W = (IW + pl + pr - 3) // WSTR + 1 cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm]) @@ -407,7 +408,7 @@ def conv2d_arm_cpu_winograd_nnpack( output = tvm.contrib.nnpack.convolution_inference_without_weight_transform( data, transformed_kernel, bias=None, - padding=[HPAD, HPAD, WPAD, WPAD], + padding=[pt, pb, pl, pr], stride=[HSTR, WSTR], algorithm=cfg['winograd_nnpack_algorithm'].val) @@ -467,13 +468,14 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides, assert len(transformed_kernel.shape) == 4 CO, _, _, _ = get_const_tuple(transformed_kernel.shape) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, (3, 3)) KH, KW = 3, 3 + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) assert layout == 'NCHW' - assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1 - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\ + and WSTR == 1 + H = (IH + pt + pb - 3) // HSTR + 1 + W = (IW + pl + pr - 3) // WSTR + 1 assert N == 1 with tvm.tag_scope("winograd_nnpack_conv2d_output"): @@ -481,7 +483,7 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides, data=data, transformed_kernel=transformed_kernel, bias=bias, - padding=[HPAD, HPAD, WPAD, WPAD], + padding=[pt, pb, pl, pr], stride=[HSTR, WSTR], algorithm=cfg['winograd_nnpack_algorithm'].val) diff --git a/topi/python/topi/bifrost/conv2d.py b/topi/python/topi/bifrost/conv2d.py index 1ed3f2c6e300..328139d458d6 100644 --- a/topi/python/topi/bifrost/conv2d.py +++ b/topi/python/topi/bifrost/conv2d.py @@ -276,11 +276,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt H_CAT, W_CAT, CO, CI = get_const_tuple(kernel.shape) KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1 HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) assert layout == 'NCHW' assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") r = KW m = tile_size @@ -289,8 +289,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt K = CO C = CI - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + H = (IH + pt + pb - 3) // HSTR + 1 + W = (IW + pl + pr - 3) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 929937c3ef17..3117a298830f 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -21,6 +21,7 @@ from tvm.contrib import cudnn from .. 
import nn, generic +from ..nn.util import get_pad_tuple from ..util import get_const_tuple, traverse_inline from .conv2d_direct import schedule_direct_cuda @@ -48,8 +49,10 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou strides : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -80,11 +83,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou # handle dilation stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides - pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation - OH = (H + 2 * pad_h - KH) // stride_h + 1 - OW = (W + 2 * pad_w - KW) // stride_w + 1 + if isinstance(padding, (list, tuple)) and len(padding) > 2: + raise ValueError("cuDNN doesn't support asymmetric padding.") + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 + OW = (W + pl + pr - KW) // stride_w + 1 cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\ ((KW - 1) * dilation_w + 1)) @@ -97,7 +102,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou return cudnn.conv_forward(data, kernel, - [pad_h, pad_w], + [pt, pl], # cuDNN applies pt, pl symmetrically on both sides of the input [stride_h, stride_w], [dilation_h, dilation_w], conv_mode=1, diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py index 13bad7d39264..d9b948671729 100644 --- a/topi/python/topi/cuda/conv2d_winograd.py +++ b/topi/python/topi/cuda/conv2d_winograd.py @@ -64,15 +64,15 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty KH = KW = alpha + 1 - tile_size assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 - HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel) - data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") + pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) + data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") r = KW m = tile_size A, B, G = winograd_transform_matrices(m, r, out_dtype) - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index 56b63cb69017..8af78d03534e 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -83,10 +83,10 @@ def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, lay else: raise ValueError("Not support this layout {} with " "schedule template.".format(layout)) - ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) - oh = (h - kh + 2 * ph) // sh + 1 - ow = (w - kw + 2 * pw) // sw + 1
ic_bn_upper = 32 oc_bn_upper = 64 oc_bn_lower = min(oc, 8) diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index 45882b7f41bc..ea4661f7602e 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -226,19 +226,19 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt CO *= VC KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1 HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) assert layout == 'NCHW' assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") r = KW m = tile_size alpha = m + r - 1 A, B, G = winograd_transform_matrices(m, r, out_dtype) - H = (IH + 2 * HPAD - 3) // HSTR + 1 - W = (IW + 2 * WPAD - 3) // WSTR + 1 + H = (IH + pt + pb - 3) // HSTR + 1 + W = (IW + pl + pr - 3) // WSTR + 1 nH, nW = (H + m-1) // m, (W + m-1) // m P = N * nH * nW diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 5af30335a9c5..169878c11a85 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -23,7 +23,7 @@ from .pad import pad from .util import get_pad_tuple -from ..util import simplify, get_const_tuple +from ..util import simplify, get_const_tuple, get_const_int from .winograd_util import winograd_transform_matrices # workload description of conv2d @@ -46,8 +46,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N strides : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -153,7 +155,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): else: KH, KW, CIG, CO = [x.value for x in kernel.shape] - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW))) GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride @@ -179,8 +181,10 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -221,7 +225,6 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): rc = tvm.reduce_axis((0, in_channel), name='rc') ry = tvm.reduce_axis((0, kernel_h), name='ry') rx = tvm.reduce_axis((0, kernel_w), name='rx') - return tvm.compute( (batch, out_channel, out_height, out_width), lambda nn, ff, yy, xx: tvm.sum( @@ -245,8 +248,10 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + 
padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -311,8 +316,10 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -378,8 +385,10 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -425,8 +434,10 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -448,7 +459,6 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload - HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ else (dilation, dilation) @@ -464,15 +474,22 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + # output shape - out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 - out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) # DOPAD DOPAD = (HPAD != 0 or WPAD != 0) if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") + data_pad = pad(data, pad_before, pad_after, name="data_pad") else: data_pad = data @@ -517,8 +534,10 @@ def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layo stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, 
pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -565,8 +584,10 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation: int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] @@ -588,7 +609,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, # layout and out_layout are not used here, # we keep them for debug convenience when dumping autotvm workload - HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \ else (dilation, dilation) @@ -603,15 +623,23 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout, dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + # output shape - out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1 - out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1 + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) # DOPAD DOPAD = (HPAD != 0 or WPAD != 0) if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") + data_pad = pad(data, pad_before, pad_after, name="data_pad") else: data_pad = data @@ -780,8 +808,10 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation : int or a list/tuple of two ints dilation size, or [dilation_height, dilation_width] diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index ce9e57e4061d..0a41838aa50e 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -23,6 +23,7 @@ from .. 
import nn, generic from ..util import get_const_tuple from ..cuda.conv2d import conv2d_cuda, schedule_conv2d_nchw_cuda +from ..nn.util import get_pad_tuple @autotvm.register_topi_compute(nn.conv2d, 'rocm', ['direct', 'winograd']) def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'): @@ -42,8 +43,10 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou strides : int or a list/tuple of two ints stride size, or [stride_height, stride_width] - padding : int or a list/tuple of two ints - padding size, or [pad_height, pad_width] + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints layout : str layout of data @@ -62,7 +65,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou # handle dilation stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides - pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding + pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) + pad_h, pad_w = pt + pb, pl + pr dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation OH = (H + pad_h - KH) // stride_h + 1 diff --git a/topi/python/topi/testing/conv2d_hwcn_python.py b/topi/python/topi/testing/conv2d_hwcn_python.py index 07a771837810..489e7eb683df 100644 --- a/topi/python/topi/testing/conv2d_hwcn_python.py +++ b/topi/python/topi/testing/conv2d_hwcn_python.py @@ -18,6 +18,7 @@ """Convolution in python""" import numpy as np import scipy.signal +from topi.nn.util import get_pad_tuple def conv2d_hwcn_python(a_np, w_np, stride, padding): @@ -34,8 +35,10 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints Returns ------- @@ -48,18 +51,10 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding): stride_h = stride_w = stride else: stride_h, stride_w = stride - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif padding == 'VALID': - pad_h = 0 - pad_w = 0 - else: # 'SAME' - pad_h = kernel_h - 1 - pad_w = kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_bottom = pad_h - pad_top - pad_left = int(np.ceil(float(pad_w) / 2)) - pad_right = pad_w - pad_left + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right # compute the output shape out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 @@ -72,9 +67,9 @@ for n in range(batch): for f in range(out_channel): for c in range(in_channel): - if pad_h > 0: + if pad_h > 0 or pad_w > 0: apad = np.zeros((in_height + pad_h, in_width + pad_w)) - apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c] + apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c] else: apad = at[n, c] out = scipy.signal.convolve2d( diff --git a/topi/python/topi/testing/conv2d_nchw_python.py b/topi/python/topi/testing/conv2d_nchw_python.py index c2cc021c6866..9f7ae7a62df1 100644 +++
b/topi/python/topi/testing/conv2d_nchw_python.py @@ -18,6 +18,7 @@ """Convolution in python""" import numpy as np import scipy.signal +from topi.nn.util import get_pad_tuple def _conv2d_nchw_python(a_np, w_np, stride, padding): @@ -34,8 +35,10 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str or a list/tuple of two ints - Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints Returns ------- @@ -48,17 +51,9 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding): stride_h = stride_w = stride else: stride_h, stride_w = stride - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif isinstance(padding, (list, tuple)): - pad_h, pad_w = padding[0] * 2, padding[1] * 2 - else: - pad_h = 0 if padding == 'VALID' else kernel_h - 1 - pad_w = 0 if padding == 'VALID' else kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_bottom = pad_h - pad_top - pad_left = int(np.ceil(float(pad_w) / 2)) - pad_right = pad_w - pad_left + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right # compute the output shape out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 @@ -70,12 +65,7 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding): for c in range(in_channel): if pad_h > 0 or pad_w > 0: apad = np.zeros((in_height + pad_h, in_width + pad_w)) - if pad_h == 0: - apad[:, pad_left:-pad_right] = a_np[n, c] - elif pad_w == 0: - apad[pad_top:-pad_bottom, :] = a_np[n, c] - else: - apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c] + apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = a_np[n, c] else: apad = a_np[n, c] out = scipy.signal.convolve2d( @@ -98,8 +88,10 @@ def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str or a list/tuple of two ints - Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints groups : int Number of groups diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index 8a6a467a80c4..dc5f915daa22 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -18,6 +18,7 @@ """Convolution in python""" import numpy as np import scipy.signal +from topi.nn.util import get_pad_tuple def conv2d_nhwc_python(a_np, w_np, stride, padding): @@ -34,8 +35,10 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str - Padding size, or ['VALID', 'SAME'] + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints Returns ------- @@ -48,18 +51,11 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): stride_h = stride_w = stride else: stride_h, stride_w = stride - if
isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif padding == 'VALID': - pad_h = 0 - pad_w = 0 - else: # 'SAME' - pad_h = kernel_h - 1 - pad_w = kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_bottom = pad_h - pad_top - pad_left = int(np.ceil(float(pad_w) / 2)) - pad_right = pad_w - pad_left + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right + # compute the output shape out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 @@ -72,9 +68,9 @@ for n in range(batch): for f in range(out_channel): for c in range(in_channel): - if pad_h > 0: + if pad_h > 0 or pad_w > 0: apad = np.zeros((in_height + pad_h, in_width + pad_w)) - apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c] + apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c] else: apad = at[n, c] out = scipy.signal.convolve2d( diff --git a/topi/python/topi/testing/deformable_conv2d_nchw_python.py b/topi/python/topi/testing/deformable_conv2d_nchw_python.py index 7e179db4c87d..80e2a18250ce 100644 --- a/topi/python/topi/testing/deformable_conv2d_nchw_python.py +++ b/topi/python/topi/testing/deformable_conv2d_nchw_python.py @@ -18,7 +18,7 @@ """Deformable convolution in python""" import itertools import numpy as np - +from topi.nn.util import get_pad_tuple def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation, deformable_groups, groups): @@ -39,8 +39,10 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str or a list/tuple of two ints - Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints dilation : int or a list/tuple of two ints Dilation size, or [dilate_height, dilate_width] @@ -67,15 +69,9 @@ stride_h = stride_w = stride else: stride_h, stride_w = stride - if isinstance(padding, int): - pad_h = pad_w = padding * 2 - elif isinstance(padding, (list, tuple)): - pad_h, pad_w = padding[0] * 2, padding[1] * 2 - else: - pad_h = 0 if padding == 'VALID' else kernel_h - 1 - pad_w = 0 if padding == 'VALID' else kernel_w - 1 - pad_top = int(np.ceil(float(pad_h) / 2)) - pad_left = int(np.ceil(float(pad_w) / 2)) + + pad_top, pad_left, _, _ = get_pad_tuple(padding, (kernel_h, kernel_w)) + if isinstance(dilation, int): dilation_h = dilation_w = dilation else: diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 0e284da17ee6..8a6b57eb9e66 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -30,6 +30,7 @@ conv2d_infer_layout, _get_workload as _get_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.pad import pad +from ..nn.util import get_pad_tuple from ..util import get_const_tuple from .
import conv2d_avx_1x1, conv2d_avx_common @@ -84,10 +85,10 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): "schedule template.".format(layout)) is_kernel_1x1 = kh == 1 and kw == 1 - ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) - oh = (h - kh + 2 * ph) // sh + 1 - ow = (w - kw + 2 * pw) // sw + 1 + oh = (h - kh + pt + pb) // sh + 1 + ow = (w - kw + pl + pr) // sw + 1 # Create schedule config cfg.define_split("tile_ic", ic, num_outputs=2) @@ -102,7 +103,6 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): @autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype - padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) @@ -141,24 +141,27 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout else: dilation_h, dilation_w = dilation - HPAD, WPAD = padding HSTR, WSTR = strides - batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape) num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape) - pad_height = in_height + 2 * HPAD - pad_width = in_width + 2 * WPAD + pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_height, kernel_width)) + pad_h = pad_top + pad_down + pad_w = pad_left + pad_right + + pad_height = in_height + pad_h + pad_width = in_width + pad_w dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 - out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1 - out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1 + out_height = (in_height + pad_h - dilated_kernel_h) // HSTR + 1 + out_width = (in_width + pad_w - dilated_kernel_w) // WSTR + 1 # pack data - DOPAD = (HPAD != 0 or WPAD != 0) + DOPAD = (pad_h != 0 or pad_w != 0) if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right), \ + name="data_pad") else: data_pad = data @@ -353,8 +356,9 @@ def _conv2d_infer_layout(workload, cfg): out_channel, _, k_height, k_width = kernel[:-1] idxdiv = tvm.indexdiv - out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1 - out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1 + pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width)) + out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1 + out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1 tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic) in_layout = "NCHW%dc" % tile_ic diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 1332c687a301..60d632b54a27 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -28,6 +28,7 @@ from ..nn import conv2d_legalize from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, conv2d_alter_layout from ..nn.depthwise_conv2d import 
depthwise_conv2d_NCHWc, depthwise_conv2d_nchw +from ..nn.util import get_pad_tuple logger = logging.getLogger('topi') @@ -227,12 +228,14 @@ def _conv2d_legalize(attrs, inputs, arg_types): if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8': is_int8_inputs = True padding = attrs.get_int_tuple("padding") + kh, kw = attrs.get_int_tuple("kernel_size") + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2)) - pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0)) + pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0)) elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': - pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1])) + pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr)) adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3)) adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2) else: diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index df53850ec603..cb23eec0cd48 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -25,6 +25,7 @@ from ..nn.conv2d import _get_workload as _get_conv2d_workload from .. import generic, tag from ..generic import conv2d as conv2d_generic +from ..nn.util import get_pad_tuple from ..util import get_const_tuple from ..nn.conv2d import conv2d_NCHWc_int8 from .. import nn @@ -92,10 +93,10 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay "schedule template.".format(layout)) is_kernel_1x1 = kh == 1 and kw == 1 - ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) - oh = (h - kh + 2 * ph) // sh + 1 - ow = (w - kw + 2 * pw) // sw + 1 + oh = (h - kh + pt + pb) // sh + 1 + ow = (w - kw + pl + pr) // sw + 1 # Create schedule config cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0) diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 8af41dacf293..385537b95e4d 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -204,10 +204,10 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): batch, in_channel, height, width = get_const_tuple(data.shape) filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) - ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) - out_height = (height - kh + 2 * ph) // sh + 1 - out_width = (width - kw + 2 * pw) // sw + 1 + out_height = (height - kh + pt + pb) // sh + 1 + out_width = (width - kw + pl + pr) // sw + 1 out_channel = filter_channel * channel_multiplier # get config here diff --git a/topi/tests/python/test_topi_conv2d_NCHWc.py b/topi/tests/python/test_topi_conv2d_NCHWc.py index 26b4642bd333..7c8595d49cc1 100644 --- a/topi/tests/python/test_topi_conv2d_NCHWc.py +++ b/topi/tests/python/test_topi_conv2d_NCHWc.py @@ -22,6 +22,7 @@ import topi import topi.testing from tvm.contrib.pickle_memoize import memoize +from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple from common import
get_all_backend @@ -49,10 +50,11 @@ def _transform_bias(bias, bn): def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"): - print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % - (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right in_height = in_width = in_size + print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum)) # for testing functionality, # we choose arbitrary block size that can divide the channel, @@ -96,7 +98,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), padding, (dilation, dilation), layout='NCHW%dc'%ic_block, out_layout="NCHW%dc"%oc_block, @@ -114,12 +116,12 @@ def check_device(device): if add_bias: func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % - (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % - (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) @@ -217,5 +219,22 @@ def test_conv2d_NCHWc(): verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1) verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1) + # Asymmetric padding + verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, (0, 0, 1, 1)) + verify_conv2d_NCHWc(1, 64, 56, 128, 3, 1, (3, 3, 2, 2)) + verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, (1, 2, 2, 1)) + verify_conv2d_NCHWc(1, 64, 288, 192, 1, 1, (1, 2)) + verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, (3, 1)) + verify_conv2d_NCHWc(1, 128, 56, 384, 3, 1, (0, 2)) + verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, "VALID") + verify_conv2d_NCHWc(1, 388, 56, 64, 3, 1, "VALID") + verify_conv2d_NCHWc(1, 512, 19, 64, 1, 1, "SAME") + verify_conv2d_NCHWc(1, 64, 2048, 32, 2, 1, "SAME") + verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True) + verify_conv2d_NCHWc(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True) + verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True) + verify_conv2d_NCHWc(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True) + + if __name__ == "__main__": test_conv2d_NCHWc() diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py index 09adbcecefc3..5b8a2eb6e1f1 100644 --- a/topi/tests/python/test_topi_conv2d_int8.py +++ b/topi/tests/python/test_topi_conv2d_int8.py @@ -23,6 +23,7 @@ import topi import topi.testing from tvm.contrib.pickle_memoize import memoize +from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple from common import get_all_backend, Int8Fallback @@ -31,7 +32,9 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): - print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + pad_top, pad_left, pad_bottom, pad_right = 
get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) in_height = in_width = in_size @@ -79,7 +82,7 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): - C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), (dilation, dilation), + C = topi.nn.conv2d(A, W, (stride, stride), padding, (dilation, dilation), layout='NCHW', out_dtype=dtype) if add_bias: C = topi.add(C, bias) @@ -92,11 +95,11 @@ def check_device(device): b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) if add_bias: - tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) - func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) @@ -184,5 +187,22 @@ def test_conv2d_nchw(): verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0) verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0) + # Asymmetric padding + verify_conv2d_NCHWc_int8(1, 32, 224, 64, 7, 2, (0, 0, 1, 1)) + verify_conv2d_NCHWc_int8(1, 64, 56, 128, 3, 1, (3, 3, 2, 2)) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, (1, 2, 2, 1)) + verify_conv2d_NCHWc_int8(1, 64, 288, 192, 1, 1, (1, 2)) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, (3, 1)) + verify_conv2d_NCHWc_int8(1, 128, 56, 384, 3, 1, (0, 2)) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, "VALID") + verify_conv2d_NCHWc_int8(1, 388, 56, 64, 3, 1, "VALID") + verify_conv2d_NCHWc_int8(1, 512, 19, 64, 1, 1, "SAME") + verify_conv2d_NCHWc_int8(1, 64, 2048, 32, 2, 1, "SAME") + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True) + verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True) + + if __name__ == "__main__": test_conv2d_nchw() diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py index d7c39a9cc016..7c8041b87d6f 100644 --- a/topi/tests/python/test_topi_conv2d_nchw.py +++ b/topi/tests/python/test_topi_conv2d_nchw.py @@ -22,12 +22,17 @@ import topi import topi.testing from tvm.contrib.pickle_memoize import memoize +from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple from common import get_all_backend -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): - 
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) +def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,\ + use_cudnn=False): + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) in_height = in_width = in_size @@ -62,7 +67,7 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), + C = topi.nn.conv2d(A, W, (stride, stride), padding, (dilation, dilation), layout='NCHW', out_dtype=dtype) if add_bias: C = topi.add(C, bias) @@ -75,10 +80,10 @@ def check_device(device): b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) if add_bias: - func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) @@ -86,6 +91,9 @@ def check_device(device): with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) + if use_cudnn: + check_device("cuda -model=unknown -libs=cudnn") + def test_conv2d_nchw(): # ResNet18 workloads @@ -176,6 +184,25 @@ def test_conv2d_nchw(): verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1) verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1) + # Asymmetric padding + verify_conv2d_nchw(1, 3, 224, 64, 7, 2, (0, 0, 1, 1)) + verify_conv2d_nchw(1, 64, 56, 128, 3, 1, (3, 3, 2, 2)) + verify_conv2d_nchw(1, 64, 56, 64, 1, 1, (1, 2, 2, 1)) + verify_conv2d_nchw(1, 64, 288, 192, 1, 1, (1, 2)) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (3, 1)) + verify_conv2d_nchw(1, 128, 56, 384, 3, 1, (0, 2)) + verify_conv2d_nchw(1, 64, 384, 64, 3, 1, (1, 2), use_cudnn=True) + verify_conv2d_nchw(1, 64, 56, 64, 1, 1, "VALID") + verify_conv2d_nchw(1, 388, 56, 64, 3, 1, "VALID") + verify_conv2d_nchw(1, 64, 1280, 48, 3, 1, "VALID", use_cudnn=True) + verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME") + verify_conv2d_nchw(1, 64, 2048, 32, 2, 1, "SAME") + verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True) + verify_conv2d_nchw(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True) + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True) + verify_conv2d_nchw(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True) + if __name__ == "__main__": test_conv2d_nchw() diff --git a/topi/tests/python/test_topi_conv2d_nhwc.py b/topi/tests/python/test_topi_conv2d_nhwc.py index d53748c590c7..8c6e0090640c 100644 --- a/topi/tests/python/test_topi_conv2d_nhwc.py +++ b/topi/tests/python/test_topi_conv2d_nhwc.py @@ 
-71,8 +71,13 @@ def test_conv2d_nhwc(): verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID") verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID") + verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 0, 1, 1)) + verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (1, 1, 2, 2)) + verify_conv2d_nhwc(1, 128, 16, 128, 5, 2, (3, 3, 2, 2)) + verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 1, 2, 3)) # dilation = 2 verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2) + verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, (1, 1, 2, 2), dilation=2) if __name__ == "__main__": diff --git a/topi/tests/python/test_topi_conv2d_winograd.py b/topi/tests/python/test_topi_conv2d_winograd.py index 5974dad20f88..548aea909cae 100644 --- a/topi/tests/python/test_topi_conv2d_winograd.py +++ b/topi/tests/python/test_topi_conv2d_winograd.py @@ -23,12 +23,15 @@ import topi import topi.testing from tvm.contrib.pickle_memoize import memoize +from topi.nn.util import get_pad_tuple from topi.util import get_const_tuple def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']): - print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) in_height = in_width = in_size @@ -76,14 +79,13 @@ def check_device(device): b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) if add_bias: - func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, b, c) else: - func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) func(a, w, c) rtol = 1e-3 - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) @@ -133,5 +135,20 @@ def test_conv2d_nchw(): verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1) verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1) + # Asymmetric padding + verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (1, 1, 1, 1)) + verify_conv2d_nchw(1, 128, 28, 128, 3, 1, (1, 1, 1, 1)) + verify_conv2d_nchw(1, 256, 14, 256, 3, 1, (1, 1)) + verify_conv2d_nchw(1, 512, 7, 512, 3, 1, "SAME") + verify_conv2d_nchw(2, 13, 71, 59, 3, 1, (1, 1, 1, 1)) + verify_conv2d_nchw(2, 64, 56, 64, 3, 1, (1, 1, 1, 1), add_bias=True) + verify_conv2d_nchw(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True) + verify_conv2d_nchw(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, add_bias=True) + verify_conv2d_nchw(1, 128, 17, 192, 7, 1, (3, 1), devices=['cuda']) + verify_conv2d_nchw(1, 128, 17, 128, 7, 1, (3, 3, 2, 2), devices=['cuda']) + verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=['cuda']) + verify_conv2d_nchw(1, 48, 35, 64, 5, 1, "VALID", devices=['cuda']) + + if __name__ == "__main__": 
test_conv2d_nchw() diff --git a/topi/tests/python/test_topi_conv3d_ncdhw.py b/topi/tests/python/test_topi_conv3d_ncdhw.py index 681190633d67..92b1068a11ec 100644 --- a/topi/tests/python/test_topi_conv3d_ncdhw.py +++ b/topi/tests/python/test_topi_conv3d_ncdhw.py @@ -90,7 +90,6 @@ def check_device(device): with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) - def test_conv3d_ncdhw(): #3DCNN workloads verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0) @@ -122,6 +121,5 @@ def test_conv3d_ncdhw(): verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID") verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID") - if __name__ == "__main__": test_conv3d_ncdhw()
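The reference-code and output-shape changes above share two small invariants, sketched below in NumPy under stated assumptions (pad_image and out_size are illustrative names, not part of the patch):

    import numpy as np

    def pad_image(img, pad_top, pad_left, pad_bottom, pad_right):
        # Copy into a zero canvas with explicit start offsets; the old
        # apad[pad_top:-pad_bottom, pad_left:-pad_right] slicing mis-slices
        # whenever pad_bottom or pad_right is 0, which asymmetric padding allows.
        in_h, in_w = img.shape
        apad = np.zeros((in_h + pad_top + pad_bottom, in_w + pad_left + pad_right),
                        dtype=img.dtype)
        apad[pad_top:pad_top + in_h, pad_left:pad_left + in_w] = img
        return apad

    def out_size(in_size, kernel, stride, pad_a, pad_b, dilation=1):
        # Same rule the Rel functions above use:
        # (in + pad_sum - dilated_kernel) // stride + 1
        dilated = (kernel - 1) * dilation + 1
        return (in_size + pad_a + pad_b - dilated) // stride + 1

    img = np.arange(9, dtype="float32").reshape(3, 3)
    assert pad_image(img, 0, 0, 1, 1).shape == (4, 4)
    assert out_size(224, 7, 2, 0, 1) == 110  # should match verify_conv2d_nchw(1, 3, 224, 64, 7, 2, (0, 0, 1, 1))

Replacing the negative-index slices with explicit offsets is what lets the test helpers accept pad_bottom = 0 or pad_right = 0 without silently producing an empty slice.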