[AVG POOL] Asymmetric padding (SAME) support. (apache#1346)
srkreddy1238 authored and tqchen committed Jul 4, 2018
1 parent 4f184db commit f9cb969
Showing 8 changed files with 156 additions and 85 deletions.
10 changes: 8 additions & 2 deletions nnvm/include/nnvm/top/nn.h
@@ -239,7 +239,10 @@ struct MaxPool2DParam : public dmlc::Parameter<MaxPool2DParam> {
.describe("Specifies the strides of the convolution.");
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
@@ -266,7 +269,10 @@ struct AvgPool2DParam : public dmlc::Parameter<AvgPool2DParam> {
.describe("Specifies the strides of the convolution.");
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
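The new padding attribute accepts one, two, or four integers. A minimal Python sketch of how the shorter forms expand to an explicit (top, left, bottom, right) tuple; the helper name expand_padding is hypothetical and only illustrates the rule described above:

def expand_padding(padding):
    """Expand a 1-, 2- or 4-element padding into (top, left, bottom, right)."""
    if len(padding) == 1:    # one int: same padding on all four sides
        return (padding[0],) * 4
    if len(padding) == 2:    # two ints: (top, left); bottom/right reuse them
        return (padding[0], padding[1], padding[0], padding[1])
    if len(padding) == 4:    # four ints: already (top, left, bottom, right)
        return tuple(padding)
    raise ValueError("padding must have 1, 2 or 4 elements")

# expand_padding([1])    -> (1, 1, 1, 1)
# expand_padding([1, 2]) -> (1, 2, 1, 2)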
17 changes: 4 additions & 13 deletions nnvm/python/nnvm/frontend/tensorflow.py
@@ -137,22 +137,13 @@ def _impl(inputs, attr, params):
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

if attr['data_format'] == 'NHWC':
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _sym.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
else:
raise TypeError("Unsupported padding type : {}".format(attr['padding']))

if name == "avg_pool":
attr['count_include_pad'] = False

return AttrCvt(
op_name=_dimension_picker(name),
transforms={
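The frontend now emits the SAME pads directly as a four-element attr['padding'] = [top, left, bottom, right] instead of inserting an explicit pad operator. For reference, a hedged sketch of the TensorFlow SAME rule that _get_pad_pair applies along one spatial dimension (assumption: the helper follows the standard TF convention; this is not copied from the frontend source):

def get_pad_pair(in_size, kernel, stride):
    # Total padding needed so the output covers the whole input at this stride.
    if in_size % stride == 0:
        total = max(kernel - stride, 0)
    else:
        total = max(kernel - (in_size % stride), 0)
    before = total // 2        # top or left
    after = total - before     # bottom or right receives the odd pixel
    return [before, after]

# e.g. in_size=9, kernel=2, stride=1 -> [0, 1]: pad the bottom only,
# which is exactly the asymmetric case this commit adds support for.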
69 changes: 57 additions & 12 deletions nnvm/src/top/nn/pooling.cc
@@ -1,3 +1,4 @@

/*!
* Copyright (c) 2017 by Contributors
* \file pooling.cc
@@ -44,23 +45,39 @@ inline bool Pool2DInferShape(const nnvm::NodeAttrs& attrs,
const auto hidx = layout.indexof('H');
const auto widx = layout.indexof('W');

dim_t pad_h, pad_w;
if (param.padding.ndim() == 1) {
pad_h = param.padding[0] * 2;
pad_w = param.padding[0] * 2;
} else if (param.padding.ndim() == 2) {
// (top, left)
pad_h = param.padding[0] * 2;
pad_w = param.padding[1] * 2;
} else if (param.padding.ndim() == 4) {
// (top, left, bottom, right)
pad_h = param.padding[0] + param.padding[2];
pad_w = param.padding[1] + param.padding[3];
} else {
return false;
}

TShape oshape = dshape;
CHECK(param.pool_size[0] <= dshape[hidx] + 2 * param.padding[0])
CHECK(param.pool_size[0] <= dshape[hidx] + pad_h)
<< "pool size (" << param.pool_size[0] << ") exceeds input (" << dshape[hidx]
<< " padded to " << (dshape[hidx] + 2*param.padding[0]) << ")";
CHECK(param.pool_size[1] <= dshape[widx] + 2 * param.padding[1])
<< " padded to " << (dshape[hidx] + pad_h) << ")";
CHECK(param.pool_size[1] <= dshape[widx] + pad_w)
<< "pool size (" << param.pool_size[1] << ") exceeds input (" << dshape[widx]
<< " padded to " << (dshape[widx] + 2*param.padding[1]) << ")";
<< " padded to " << (dshape[widx] + pad_w) << ")";

if (!param.ceil_mode) {
oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0]) /
oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0]) /
param.strides[0]) + 1;
oshape[widx] = ((dshape[widx] + 2 * param.padding[1] - param.pool_size[1]) /
oshape[widx] = ((dshape[widx] + pad_w - param.pool_size[1]) /
param.strides[1]) + 1;
} else {
oshape[hidx] = ((dshape[hidx] + 2 * param.padding[0] - param.pool_size[0] +
oshape[hidx] = ((dshape[hidx] + pad_h - param.pool_size[0] +
param.strides[0] - 1) / param.strides[0]) + 1;
oshape[widx] = ((dshape[3] + 2 * param.padding[1] - param.pool_size[1] +
oshape[widx] = ((dshape[3] + pad_w - param.pool_size[1] +
param.strides[1] - 1) / param.strides[1]) + 1;
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
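With padding reduced to per-side totals, the output extent along each spatial axis follows the usual pooling formula. A small Python sketch of the rule implemented above (floor by default, ceil when ceil_mode is set):

import math

def pool_out_size(in_size, pool, pad_before, pad_after, stride, ceil_mode=False):
    padded = in_size + pad_before + pad_after
    if ceil_mode:
        return int(math.ceil(float(padded - pool) / stride)) + 1
    return (padded - pool) // stride + 1

# pool_out_size(9, 2, 0, 1, 1) -> 9, i.e. the SAME output size for a height-9 input.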
@@ -108,8 +125,13 @@ NNVM_REGISTER_OP(max_pool2d)
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
@@ -143,6 +165,15 @@ NNVM_REGISTER_OP(max_pool2d)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";

if (param.padding.ndim() == 1) {
padding.push_back(padding[0]);
padding.push_back(padding[0]);
padding.push_back(padding[0]);
} else if (param.padding.ndim() == 2) {
padding.push_back(padding[0]);
padding.push_back(padding[1]);
}

return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding,
topi::nn::kMaxPool, ceil_mode, layout.name())};
@@ -182,8 +213,13 @@ NNVM_REGISTER_OP(avg_pool2d)
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
where padding will be an expanded array based on number of values passed as::
one int : all sides same padding used.
two int : bottom, right use same as top and left.
four int: padding width in the order of (top, left, bottom, right).
When `ceil_mode` is `True`, ceil will be used instead of floor in this
equation.
@@ -216,6 +252,15 @@ NNVM_REGISTER_OP(avg_pool2d)
<< "Pool2D only support 4-D input (e.g., NCHW)"
<< " or 5-D input (last dimension is a split of channel)";

if (param.padding.ndim() == 1) {
padding.push_back(padding[0]);
padding.push_back(padding[0]);
padding.push_back(padding[0]);
} else if (param.padding.ndim() == 2) {
padding.push_back(padding[0]);
padding.push_back(padding[1]);
}

return Array<Tensor>{
topi::nn::pool(inputs[0], pool_size, strides, padding,
topi::nn::kAvgPool, ceil_mode, layout.name(), count_include_pad)};
36 changes: 29 additions & 7 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -116,13 +116,13 @@ def test_forward_pooling():
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])

_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[1, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 1])

_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[1, 1],
padding='SAME',
@@ -136,6 +136,33 @@
dilation_rate=[1, 1],
strides=[1, 1])

_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[2, 1],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[1, 1])
_test_pooling(input_shape=[2, 9, 10, 2],
window_shape=[2, 1],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[2, 1])

_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[2, 3],
padding='SAME',
pooling_type='MAX',
dilation_rate=[1, 1],
strides=[2, 1])
_test_pooling(input_shape=[2, 10, 9, 2],
window_shape=[2, 3],
padding='SAME',
pooling_type='AVG',
dilation_rate=[1, 1],
strides=[1, 2])
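The added cases use odd input extents and window/stride combinations for which SAME padding becomes asymmetric (all of it on the bottom/right). A hedged sketch of the TensorFlow reference such a case builds; the actual comparison against the NNVM result is done inside _test_pooling:

import numpy as np
import tensorflow as tf

x = np.random.uniform(size=(2, 9, 10, 2)).astype("float32")
# With height 9, window 2, stride 2, SAME pads one row on the bottom only.
out = tf.nn.pool(tf.constant(x), window_shape=[2, 1], pooling_type="AVG",
                 padding="SAME", dilation_rate=[1, 1], strides=[2, 1])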


#######################################################################
# Convolution
# -----------
@@ -419,12 +446,7 @@ def test_forward_inception_v3():
top_tvm = np.squeeze(tvm_output).argsort()[-3:][::-1]
top_tf = np.squeeze(tf_output).argsort()[-3:][::-1]

# TVM implementation of SAME padding some times make a slight deviation.
# Hence check for top predictions.
top_tvm = np.sort(top_tvm)
top_tf = np.sort(top_tf)

np.testing.assert_allclose(top_tf, top_tvm)
np.testing.assert_allclose(top_tf, top_tvm, rtol=1e-5, atol=1e-5)

#######################################################################
# Inception V1
32 changes: 16 additions & 16 deletions topi/include/topi/nn/pooling.h
@@ -52,28 +52,25 @@ inline Tensor pool_impl(const Tensor& x,
CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 2) << "Pooling padding_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto padding_height = padding_size[0];
auto padding_width = padding_size[1];

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto pad_tuple = detail::GetPadTuple(padding_height, padding_width);
auto pad_top = pad_tuple[0];
auto pad_left = pad_tuple[1];
auto pad_down = pad_tuple[2];
auto pad_right = pad_tuple[3];
auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];

if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_down += stride_height - 1;
pad_bottom += stride_height - 1;
pad_right += stride_width - 1;
}

@@ -82,11 +79,11 @@
pad_before.Set(width_axis, pad_left);

Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_down);
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);

auto out_height = tvm::ir::Simplify(
(height - kernel_height + pad_top + pad_down) / stride_height + 1);
(height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
auto out_width = tvm::ir::Simplify(
(width - kernel_width + pad_left + pad_right) / stride_width + 1);

@@ -97,9 +94,12 @@
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);

const int64_t *padding_h = HalideIR::Internal::as_const_int(padding_height);
const int64_t *padding_w = HalideIR::Internal::as_const_int(padding_width);
const bool do_pad = ((padding_h && *padding_h) || (padding_w && *padding_w));
const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top);
const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left);
const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom);
const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x;
@@ -125,8 +125,8 @@ inline Tensor pool_impl(const Tensor& x,
if (count_include_pad) {
return tsum(output) / (kernel_height * kernel_width);
} else {
Expr h_start = output[height_axis] * stride_height - padding_height;
Expr w_start = output[width_axis] * stride_width - padding_width;
Expr h_start = output[height_axis] * stride_height - pad_top;
Expr w_start = output[width_axis] * stride_width - pad_left;
Expr h_end = ir::Min::make(h_start + kernel_height, height);
Expr w_end = ir::Min::make(w_start + kernel_width, width);
h_start = ir::Max::make(h_start, make_const(Int(32), 0));
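With count_include_pad=False, the divisor for each output element is the number of real input points the window overlaps, so the start offsets now subtract pad_top/pad_left and the extents are clamped to the input size. A scalar Python sketch of the count computed by the expressions above (hypothetical helper, for illustration only):

def avg_window_count(out_y, out_x, stride_h, stride_w, kernel_h, kernel_w,
                     pad_top, pad_left, height, width):
    # Window origin in un-padded input coordinates (may be negative).
    h_start = out_y * stride_h - pad_top
    w_start = out_x * stride_w - pad_left
    # Clamp the window to the real input extent.
    h_end = min(h_start + kernel_h, height)
    w_end = min(w_start + kernel_w, width)
    h_start = max(h_start, 0)
    w_start = max(w_start, 0)
    return (h_end - h_start) * (w_end - w_start)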
4 changes: 2 additions & 2 deletions topi/python/topi/nn/pooling.py
@@ -69,8 +69,8 @@ def pool(data,
stride : list/tuple of two ints
Stride size, [stride_height, stride_width]
padding : list/tuple of two ints
Pad size, [pad_height, pad_width]
padding : list/tuple of four ints
Pad size, [pad_top, pad_left, pad_bottom, pad_right]
pool_type : str
Pool type, 'max' or 'avg'
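A minimal usage sketch of the updated signature with an explicit four-element padding; shapes and values are illustrative only and assume tvm/topi are importable as in the tests below:

import tvm
import topi

A = tvm.placeholder((1, 16, 32, 32), name="A")       # NCHW input
B = topi.nn.pool(A, kernel=[3, 3], stride=[2, 2],
                 padding=[1, 0, 2, 1],                # (top, left, bottom, right)
                 pool_type="avg", ceil_mode=False,
                 count_include_pad=False)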
37 changes: 20 additions & 17 deletions topi/tests/python/test_topi_pooling.py
@@ -9,7 +9,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
iw = ih
kw = kh
sw = sh
ph, pw = padding
pt, pl, pb, pr = padding
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
@@ -19,16 +19,15 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1)
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)

assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1)

a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
pad_np = np.zeros(shape=(n, ic, ih+pt+pb, iw+pl+pr)).astype(dtype)
no_zero = (range(n), range(ic), (range(pt, ih+pt)), (range(pl, iw+pl)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
@@ -67,15 +66,19 @@ def check_device(device):
check_device(device)

def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [1, 2], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [3, 3], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [0, 0], 'avg', False, False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)

verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [3, 3, 3, 3], 'avg', False, False)
verify_pool(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False)
verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], 'max', True)

verify_pool(1, 256, 31, 3, 3, [2, 1, 0, 3], 'avg', False, True)
verify_pool(1, 256, 32, 2, 2, [0, 3, 2, 1], 'avg', False, False)
verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)


def verify_global_pool(n, c, h, w, pool_type):