pool2d grads #3

Status: Open. Wants to merge 2 commits into base: master.
9 changes: 5 additions & 4 deletions include/tvm/ir_operator.h
@@ -50,15 +50,16 @@ inline Expr const_false(int lanes = 1) {
/*!
* \brief Get x as constant int expression.
* \param x The expression
* \param default_ Pointer to the default value. Defaults to nullptr.
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
* return default_, if x is not IntImm.
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
inline const int64_t* as_const_int(const Expr& x, const int64_t* default_ = nullptr) {
if (!x.defined()) return default_;
if (const ir::IntImm* op = x.as<ir::IntImm>()) {
return &(op->value);
} else {
return nullptr;
return default_;
}
}

3 changes: 2 additions & 1 deletion topi/include/topi/detail/ravel_unravel.h
@@ -22,7 +22,8 @@ using namespace tvm;
*
* \return The index after flattening
*/
inline Expr RavelIndex(Array<Var> indices, Array<Expr> shape) {
template<typename T>
inline Expr RavelIndex(Array<T> indices, Array<Expr> shape) {
CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
CHECK_GT(indices.size(), 0) << "indices must not be empty";
Expr idx;
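Note on the RavelIndex change: templating the index type lets pool_grad_impl pass an Array<Expr> (window indices offset by reduce axes) where previously only Array<Var> was accepted. The flattening itself is the usual row-major rule; a minimal Python sketch, for illustration only:

```python
# Row-major flattening, mirroring what RavelIndex builds as a symbolic Expr.
def ravel_index(indices, shape):
    assert len(indices) == len(shape) and indices, "indices and shape must match and be non-empty"
    idx = indices[0]
    for i, s in zip(indices[1:], shape[1:]):
        idx = idx * s + i
    return idx

assert ravel_index([1, 2, 3], [4, 5, 6]) == 1 * (5 * 6) + 2 * 6 + 3  # 45
```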
146 changes: 140 additions & 6 deletions topi/include/topi/nn/pooling.h
@@ -12,8 +12,10 @@
#include "tvm/tvm.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
#include "topi/detail/constant_utils.h"
#include "topi/detail/pad_utils.h"
#include "topi/nn.h"
#include "topi/reduction.h"

namespace topi {
namespace nn {
@@ -94,12 +96,12 @@ inline Tensor pool_impl(const Tensor& x,
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);

const int64_t *padding_h0 = as_const_int(pad_top);
const int64_t *padding_w0 = as_const_int(pad_left);
const int64_t *padding_h1 = as_const_int(pad_bottom);
const int64_t *padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
const int64_t default_pad = 0;
const int64_t *padding_h0 = as_const_int(pad_top, &default_pad);
const int64_t *padding_w0 = as_const_int(pad_left, &default_pad);
const int64_t *padding_h1 = as_const_int(pad_bottom, &default_pad);
const int64_t *padding_w1 = as_const_int(pad_right, &default_pad);
const bool do_pad = *padding_h0 + *padding_w0 + *padding_h1 + *padding_w1 > 0;

if (pool_type == kMaxPool) {
auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x;
@@ -142,6 +144,121 @@
}
}

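/*!
* \brief Calculate gradient of pooling given the output gradient and the input
*  (implementation; the caller resolves height_axis/width_axis from the layout)
*
* \param ograd The gradient of the pooled output
* \param x The input tensor of the forward pooling
* \param kernel_size Vector of two ints: {kernel_height, kernel_width}
* \param stride_size Vector of two ints: {stride_height, stride_width}
* \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
* \param height_axis The index of the height dimension
* \param width_axis The index of the width dimension
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
* \return The gradient with respect to x
*/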
inline Tensor pool_grad_impl(const Tensor& ograd,
const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const size_t height_axis,
const size_t width_axis,
bool count_include_pad) {
CHECK(ograd->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
CHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];

if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_bottom += stride_height - 1;
pad_right += stride_width - 1;
}

Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
pad_before.Set(height_axis, pad_top);
pad_before.Set(width_axis, pad_left);

Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);

auto out_height = tvm::ir::Simplify(
(height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
auto out_width = tvm::ir::Simplify(
(width - kernel_width + pad_left + pad_right) / stride_width + 1);

auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));

Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);

const int64_t default_pad = 0;
const int64_t *padding_h0 = as_const_int(pad_top, &default_pad);
const int64_t *padding_w0 = as_const_int(pad_left, &default_pad);
const int64_t *padding_h1 = as_const_int(pad_bottom, &default_pad);
const int64_t *padding_w1 = as_const_int(pad_right, &default_pad);
const bool do_pad = *padding_h0 + *padding_w0 + *padding_h1 + *padding_w1 > 0;

if (pool_type == kMaxPool) {
auto argmax = MakeArgmaxReducer();
auto pad_x = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x;

Array<Expr> ravel_shape;
for (const auto& sh : x->shape) ravel_shape.push_back(sh);
ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

auto mp_argmax = tvm::compute(out_shape, [&](const Array<Var>& inds) {
Array<Expr> window_inds;
for (const Var& ind : inds) window_inds.push_back(ind);
window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
auto idx = detail::RavelIndex(window_inds, ravel_shape);
return argmax({ idx, pad_x(window_inds) }, { dheight, dwidth }, nullptr);
}, "maxpool_grad_argmax", kCommReduceIdx);

auto mp_inds = tvm::compute(out_shape, [&](const Array<Var>& inds) {
return mp_argmax[0](inds);
}, "maxpool_grad_inds", kCommReduceIdx);

auto windowh = tvm::reduce_axis(
Range(0, (kernel_height + stride_height - 1) / stride_height));
auto windoww = tvm::reduce_axis(
Range(0, (kernel_width + stride_width - 1) / stride_width));

return tvm::compute(x->shape, [&](const Array<Var>& inds) {
Array<Expr> pad_inds;
for (const Var& ind : inds) pad_inds.push_back(ind);
pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
auto idx = detail::RavelIndex(pad_inds, ravel_shape);

Array<Expr> out_idx;
for (const Var& ind : inds) out_idx.push_back(ind);
out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
return tvm::sum(
tvm::select(mp_inds(out_idx) == idx, ograd(out_idx), make_const(x->dtype, 0))
, { windowh, windoww });
}, "tensor", "pool_grad_max");
} else if (pool_type == kAvgPool) {
LOG(FATAL) << "pool_grad for avg pooling is implemented in Python (see topi.nn.pool_grad)";
return x;
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}
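A note on the max-pool gradient above, to make it easier to review: pool_grad_max records, for each output window, the raveled (flat) index of the padded-input element that won the argmax, and then each input element sums ograd over the nearby windows whose recorded winner equals its own flat index. A small NumPy sketch of the same computation, illustrative only (no padding, and written as a scatter rather than the gather formulation used here):

```python
import numpy as np

def maxpool2d_grad(x, ograd, k, s):
    """Scatter-by-argmax max-pool gradient for a single (H, W) plane, no padding."""
    dx = np.zeros_like(x)
    out_h, out_w = ograd.shape
    for oh in range(out_h):
        for ow in range(out_w):
            window = x[oh * s:oh * s + k, ow * s:ow * s + k]
            ih, iw = np.unravel_index(np.argmax(window), window.shape)
            dx[oh * s + ih, ow * s + iw] += ograd[oh, ow]
    return dx

x = np.array([[1., 3.],
              [2., 0.]])
print(maxpool2d_grad(x, np.array([[5.]]), k=2, s=2))
# [[0. 5.]
#  [0. 0.]]  the whole output gradient flows to the max element (3.)
```

The TVM compute avoids the scatter by iterating, for every input position, over the at most ceil(kernel/stride) output windows that can cover it (the windowh/windoww reduce axes) and comparing raveled indices.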

inline bool find_height_width(const std::string& layout,
int* height_axis,
int* width_axis) {
@@ -212,6 +329,23 @@ inline Tensor pool(const Tensor& x,
count_include_pad);
}

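/*!
* \brief Calculate gradient of pooling on height and width dimension of data.
*  It decides the height and width dimension according to the layout string,
*  in which 'W' and 'H' means width and height respectively.
*
* \param ograd The gradient of the pooled output
* \param x The input tensor of the forward pooling
* \param kernel_size Vector of two ints: {kernel_height, kernel_width}
* \param stride_size Vector of two ints: {stride_height, stride_width}
* \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
* \param layout The input layout
* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
*
* \return The gradient with respect to x, in the same shape and layout as x
*/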
inline Tensor pool_grad(const Tensor& ograd,
const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCHW",
bool count_include_pad = true) {
int height_axis = -1, width_axis = -1;
CHECK(find_height_width(layout, &height_axis, &width_axis))
<< "Unsupported layout " << layout;
return pool_grad_impl(ograd, x, kernel_size, stride_size, padding_size,
pool_type, ceil_mode, height_axis, width_axis,
count_include_pad);
}

/*!
* \brief Perform global pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
35 changes: 21 additions & 14 deletions topi/include/topi/reduction.h
@@ -425,6 +425,26 @@ inline Tensor argmin(const Tensor& data,
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
}

/*!
* \brief Creates a function that returns the argmax as (index, value)
* \return The argmax reducer function
*/
inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
Array<Expr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(types[1].min()); // val
return result;
};
return MakeCommReducer(fcombine, fidentity, "argmax");
}
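The refactor above just lifts the combiner that argmax already used into a reusable MakeArgmaxReducer so pooling.h can reduce (index, value) pairs. A rough scalar Python analogue of the combine/identity semantics, for intuition only:

```python
from functools import reduce

IDENTITY = (-1, float("-inf"))  # (index, value), matching fidentity above

def combine(lhs, rhs):
    # Keep the pair with the larger value; ties keep lhs, as in the >= selects above.
    return lhs if lhs[1] >= rhs[1] else rhs

values = [4.0, 7.0, 7.0, 1.0]
print(reduce(combine, enumerate(values), IDENTITY))  # (1, 7.0): the first maximum wins
```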

/*!
* \brief Creates an operation that finds the indices of the maximum
* values over a given axis.
@@ -443,20 +463,7 @@ inline Tensor argmax(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<Type> types) {
Array<Expr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(types[1].min()); // val
return result;
};
auto func = MakeCommReducer(fcombine, fidentity, "argmax");
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
return CommReduceIdx(data, axis, MakeArgmaxReducer(), keepdims, atleast1d);
}

/*!
13 changes: 9 additions & 4 deletions topi/python/topi/nn/conv2d_transpose.py
@@ -10,7 +10,8 @@


@tvm.target.generic_func
def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype,
_filter_shape=None, tag="conv2d_transpose_nchw"):
"""Transposed 2D convolution nchw forward operator.

Parameters
@@ -35,8 +36,9 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert (Filter is None) ^ (_filter_shape is None), "must specify exactly one of Filter and _filter_shape"
batch, in_c, in_h, in_w = Input.shape
_, out_c, filter_h, filter_w = Filter.shape
_, out_c, filter_h, filter_w = Filter.shape if Filter is not None else _filter_shape
stride_h, stride_w = strides
# dilate stage
DilatedInput = dilate(Input, [1, 1, stride_h, stride_w], name='DilatedInput')
@@ -58,11 +60,14 @@ def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')

def _filt(*inds):
return Filter(*inds).astype(out_dtype) if Filter is not None else 1 / (filter_h * filter_w)

Output = tvm.compute(
(batch, out_c, out_h, out_w),
lambda b, c, h, w: tvm.sum(
PaddedInput[b, dc, h+dh, w+dw].astype(out_dtype) *
Filter[dc, c, filter_h-1-dh, filter_w-1-dw].astype(out_dtype),
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
_filt(dc, c, filter_h-1-dh, filter_w-1-dw),
axis=[dc, dh, dw]), tag=tag)

return Output
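After this change conv2d_transpose_nchw accepts either a real Filter tensor (unchanged behaviour) or Filter=None plus an explicit _filter_shape, in which case _filt substitutes the constant 1 / (filter_h * filter_w), i.e. a uniform averaging kernel. A hedged sketch of the two call forms; the placeholder shapes are made up and this has not been checked against a TVM build:

```python
import tvm
from topi.nn.conv2d_transpose import conv2d_transpose_nchw

data = tvm.placeholder((1, 8, 16, 16), name="data")
weight = tvm.placeholder((8, 8, 3, 3), name="weight")

# 1) Existing form: explicit filter tensor.
out1 = conv2d_transpose_nchw(data, weight, (1, 1), (0, 0), "float32")

# 2) New form (used by the avg-pool gradient): no filter tensor, only its shape;
#    every tap then contributes the constant 1 / (filter_h * filter_w).
out2 = conv2d_transpose_nchw(data, None, (2, 2), (0, 0), "float32",
                             _filter_shape=(1, 8, 2, 2), tag="pool_avg_grad")
```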
68 changes: 68 additions & 0 deletions topi/python/topi/nn/pooling.py
@@ -1,6 +1,7 @@
"""TVM operator pooling compute."""
from __future__ import absolute_import
from .. import cpp
from ..nn.conv2d_transpose import conv2d_transpose_nchw

POOL_TYPE_CODE = {
"avg": 0,
@@ -97,3 +98,70 @@ def pool(data,
"""
return cpp.nn.pool(data, kernel, stride, padding,
POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad)


def pool_grad(grads,
data,
kernel,
stride,
padding,
pool_type,
ceil_mode=False,
layout="NCHW",
count_include_pad=True):
"""Gradient of pooling on height and width dimension of data.
It decides the height and width dimension according to the layout string,
in which 'W' and 'H' means width and height respectively.
Width and height dimension cannot be split.
For example, NCHW, NCHW16c, etc. are valid for pool,
while NCHW16w, NCHW16h are not.
See parameter `layout` for more information of the layout string convention.

Parameters
----------
grads : tvm.Tensor
n-D with shape of layout, gradient of the pooled output

data : tvm.Tensor
n-D with shape of layout, input of the forward pooling

kernel : list/tuple of two ints
Kernel size, [kernel_height, kernel_width]

stride : list/tuple of two ints
Stride size, [stride_height, stride_width]

padding : list/tuple of four ints
Pad size, [pad_top, pad_left, pad_bottom, pad_right]

pool_type : str
Pool type, 'max' or 'avg'

ceil_mode : bool
Whether to use ceil when calculating output size.

layout: string
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and numbers,
where upper case indicates a dimension and
the corresponding lower case with factor size indicates the split dimension.
For example, NCHW16c can describe a 5-D tensor of
[batch_size, channel, height, width, channel_block],
in which channel_block=16 is a split of dimension channel.

count_include_pad: bool
Whether include padding in the calculation when pool_type is 'avg'

Returns
-------
output : tvm.Tensor
n-D in the same layout, the gradient with respect to data (same shape as data)
"""
if pool_type == "max":
return cpp.nn.pool_grad(grads, data, kernel,
stride, padding, POOL_TYPE_CODE[pool_type],
ceil_mode, layout, count_include_pad)
else:
assert layout == 'NCHW', 'avg_pool2d_grad does not yet support %s layout' % layout
return conv2d_transpose_nchw(grads, None, stride, padding, data.dtype,
(1, data.shape[1]) + tuple(kernel), tag='pool_avg_grad')
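As a sanity check on the avg branch above: the gradient of average pooling spreads each output gradient uniformly, scaled by 1/(kernel_h*kernel_w), over the window it came from, which is exactly a transposed convolution of grads with a constant 1/(kh*kw) kernel. A small NumPy sketch (single channel, no padding), illustrative only:

```python
import numpy as np

def avgpool2d_grad(ograd, k, s, in_shape):
    """Gradient of average pooling w.r.t. a single (H, W) input plane, no padding."""
    dx = np.zeros(in_shape)
    out_h, out_w = ograd.shape
    for oh in range(out_h):
        for ow in range(out_w):
            dx[oh * s:oh * s + k, ow * s:ow * s + k] += ograd[oh, ow] / (k * k)
    return dx

print(avgpool2d_grad(np.array([[4.0]]), k=2, s=2, in_shape=(2, 2)))
# [[1. 1.]
#  [1. 1.]]  each input element of the window receives 1/4 of the output gradient
```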
7 changes: 7 additions & 0 deletions topi/src/topi.cc
@@ -388,6 +388,13 @@ TVM_REGISTER_GLOBAL("topi.nn.pool")
args[5], args[6], args[7]);
});

TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
static_cast<nn::PoolType>(static_cast<int>(args[5])),
args[6], args[7], args[8]);
});

TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0],
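Finally, a hedged end-to-end sketch of how the new wrapper might be exercised from Python once this lands; the shapes are made up and the schedule/build steps are omitted:

```python
import tvm
import topi

data = tvm.placeholder((1, 3, 32, 32), name="data")
pooled = topi.nn.pool(data, kernel=(2, 2), stride=(2, 2),
                      padding=(0, 0, 0, 0), pool_type="max")
ograd = tvm.placeholder(pooled.shape, name="ograd")
dx = topi.nn.pool_grad(ograd, data, kernel=(2, 2), stride=(2, 2),
                       padding=(0, 0, 0, 0), pool_type="max")
# dx has the same shape and layout as data; scheduling and building are left to
# the usual topi generic schedules and are not shown here.
```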