Skip to content

Commit

Permalink
[CONV] Asymmetric padding (apache#4511)
Browse files Browse the repository at this point in the history
* [CONV] Asymmetic padding

* fix lint error

* update for legalize, rocm and cudnn

* add more test cases

* change more symmetric padding

* change conv2d winograd tests according orginal cases

* remove 'alter_op_layout.h' header in bitserial.cc
  • Loading branch information
optima2005 authored and alexwong committed Feb 28, 2020
1 parent 7cceff6 commit 34da389
Show file tree
Hide file tree
Showing 29 changed files with 338 additions and 234 deletions.
26 changes: 21 additions & 5 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
Expand Down Expand Up @@ -138,7 +141,10 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
Expand Down Expand Up @@ -288,10 +294,17 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("The strides of the convolution.");
TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
.describe("Zero-padding added to one side of the output.");
.describe("Zero-padding added to one side of the output."
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
Expand Down Expand Up @@ -817,7 +830,10 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(deformable_groups).set_default(1)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/pickle_memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def memoize(key, save_at_exit=False):
"""
def _register(f):
"""Registration function"""
allow_types = (string_types, int, float)
allow_types = (string_types, int, float, tuple)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
Expand Down
19 changes: 1 addition & 18 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,24 +372,7 @@ def _impl(inputs, attr, params):
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

if opname != 'conv_transpose':
if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]
else:
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]

attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
'valid.'
Expand Down
7 changes: 5 additions & 2 deletions src/relay/op/nn/bitserial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/relay/attrs/bitserial.h>
#include <tvm/relay/op.h>

#include "../op_common.h"
#include "../../pass/infer_layout_util.h"

namespace tvm {
Expand Down Expand Up @@ -134,10 +135,12 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
CHECK(param->channels.defined());
CHECK(param->kernel_size.defined());
Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(
2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1);
2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
oshape.Set(
3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1);
3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
DataType out_dtype = param->out_dtype;
oshape = trans_in_layout.BackwardShape(oshape);
// assign output type
Expand Down
27 changes: 10 additions & 17 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ with the layer input to produce a tensor of outputs.
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);


// relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);

Expand Down Expand Up @@ -250,18 +249,8 @@ bool Conv2DTransposeRel(const Array<Type>& types,
}
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
auto pad_h = param->padding[0];
auto pad_w = param->padding[1];
if (param->padding.size() == 2) {
pad_h *= 2;
pad_w *= 2;
} else if (param->padding.size() == 4) {
pad_h += param->padding[2];
pad_w += param->padding[3];
} else {
CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
<< param->padding.size();
}
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
Expand Down Expand Up @@ -557,14 +546,16 @@ bool Conv2DWinogradRel(const Array<Type>& types,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<ir::Any>()) {
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2
oshape.Set(2, (dshape_nchw[2] + pad_h
- dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<ir::Any>()) {
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2
oshape.Set(3, (dshape_nchw[3] + pad_w
- dilated_ksize_x) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
Expand Down Expand Up @@ -1015,9 +1006,11 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});

oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y,
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x,
oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;

Expand Down
6 changes: 4 additions & 2 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,17 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<ir::Any>()) {
oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}

if (!dshape_nchw[3].as<ir::Any>()) {
oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
Expand Down
19 changes: 6 additions & 13 deletions tests/python/contrib/test_nnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import numpy as np
import scipy.signal
from topi.nn.util import get_pad_tuple
from tvm.contrib import nnpack
import pytest

Expand Down Expand Up @@ -59,17 +60,9 @@ def np_conv(na, nw, padding, stride=1):
else:
stride_h, stride_w = stride

if isinstance(padding, int):
pad_h = pad_w = padding * 2
else:
pad_h, pad_w = padding
pad_h *= 2
pad_w *= 2

pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
pad_h = pad_top + pad_bottom
pad_w = pad_left + pad_right

out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
Expand All @@ -78,9 +71,9 @@ def np_conv(na, nw, padding, stride=1):
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0:
if pad_h > 0 or pad_w > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad[pad_top:-pad_bottom, pad_left:-pad_right] = na[n, c]
apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = na[n, c]
else:
apad = na[n, c]
out = scipy.signal.convolve2d(
Expand Down
30 changes: 16 additions & 14 deletions topi/python/topi/arm_cpu/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
CO *= VC
KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")

idxd = tvm.indexdiv
idxm = tvm.indexmod
Expand All @@ -214,8 +214,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
K = CO
C = CI

H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
H = (IH + pt + pb - 3) // HSTR + 1
W = (IW + pl + pr - 3) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW

Expand Down Expand Up @@ -387,12 +387,13 @@ def conv2d_arm_cpu_winograd_nnpack(
assert len(kernel.shape) == 4
CO, _, KH, KW = get_const_tuple(kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
and WSTR == 1
H = (IH + pt + pb - 3) // HSTR + 1
W = (IW + pl + pr - 3) // WSTR + 1

cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm])

Expand All @@ -407,7 +408,7 @@ def conv2d_arm_cpu_winograd_nnpack(
output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
data, transformed_kernel,
bias=None,
padding=[HPAD, HPAD, WPAD, WPAD],
padding=[pt, pb, pl, pr],
stride=[HSTR, WSTR],
algorithm=cfg['winograd_nnpack_algorithm'].val)

Expand Down Expand Up @@ -467,21 +468,22 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
assert len(transformed_kernel.shape) == 4
CO, _, _, _ = get_const_tuple(transformed_kernel.shape)
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, (3, 3))
KH, KW = 3, 3
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
and WSTR == 1
H = (IH + pt + pb - 3) // HSTR + 1
W = (IW + pl + pr - 3) // WSTR + 1

assert N == 1
with tvm.tag_scope("winograd_nnpack_conv2d_output"):
output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
data=data,
transformed_kernel=transformed_kernel,
bias=bias,
padding=[HPAD, HPAD, WPAD, WPAD],
padding=[pt, pb, pl, pr],
stride=[HSTR, WSTR],
algorithm=cfg['winograd_nnpack_algorithm'].val)

Expand Down
8 changes: 4 additions & 4 deletions topi/python/topi/bifrost/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
H_CAT, W_CAT, CO, CI = get_const_tuple(kernel.shape)
KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")

r = KW
m = tile_size
Expand All @@ -289,8 +289,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt

K = CO
C = CI
H = (IH + 2 * HPAD - 3) // HSTR + 1
W = (IW + 2 * WPAD - 3) // WSTR + 1
H = (IH + pt + pb - 3) // HSTR + 1
W = (IW + pl + pr - 3) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW

Expand Down
17 changes: 11 additions & 6 deletions topi/python/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm.contrib import cudnn

from .. import nn, generic
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline

from .conv2d_direct import schedule_direct_cuda
Expand Down Expand Up @@ -48,8 +49,10 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
padding : int or a list/tuple of 2 or 4 ints
padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
Expand Down Expand Up @@ -80,11 +83,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou

# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
if isinstance(padding, (list, tuple)) and len(padding) > 2:
raise ValueError("Cudnn doesn't support asymmetric padding.")
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
((KW - 1) * dilation_w + 1))

Expand All @@ -97,7 +102,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou

return cudnn.conv_forward(data,
kernel,
[pad_h, pad_w],
[pt, pl], # cudnn padding pt, pl on both sides of input
[stride_h, stride_w],
[dilation_h, dilation_w],
conv_mode=1,
Expand Down
Loading

0 comments on commit 34da389

Please sign in to comment.