[Relay][Op] Adaptive pooling #3085

Merged
merged 13 commits into from
May 9, 2019
Merged
Changes from 1 commit
Use adaptive_pool to compute global_pool
Wang authored and kevinthesun committed May 6, 2019
commit 5bcc61a479d7c10b2463d7cb248cbf4ca8cb3f8c
103 changes: 31 additions & 72 deletions topi/include/topi/nn/pooling.h
@@ -231,78 +231,6 @@ inline Tensor pool(const Tensor& x,
count_include_pad);
}

/*!
* \brief Perform global pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
* in which 'W' and 'H' means width and height respectively.
* Width and height dimension cannot be split.
* For example, NCHW, NCHW16c, ... are valid for global_pool,
* while NCHW16w, NCHW16h are not.
* See \a layout for more information of the layout string convention.
*
* \param x The input tensor represent as layout
* \param pool_type The type of pooling operator
* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
* where upper case indicates a dimension and
* the corresponding lower case (with factor size) indicates the sub-dimension.
* For example, `NCHW16c` can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* (in which factor size `16` will not be used in pooling but for other operators,
* it can be used to decide the output shape).
* Since pooling does not care about the factor size of
* dimensions other than `H` and `W`, one can pass `NCHWc` as well.
*
* \return The output tensor in same layout with height and width dimension size of 1.
* e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
*/
inline Tensor global_pool(const Tensor& x,
PoolType pool_type,
const std::string& layout = "NCHW") {
CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";

int height_axis = -1, width_axis = -1;
CHECK(find_height_width(layout, &height_axis, &width_axis))
<< "Unsupported layout " << layout;

Array<Expr> out_shape = x->shape;
out_shape.Set(height_axis, 1);
out_shape.Set(width_axis, 1);

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];

auto dheight = tvm::reduce_axis(Range(0, height), "rv1");
auto dwidth = tvm::reduce_axis(Range(0, width), "rv2");

if (pool_type == kMaxPool) {
return tvm::compute(out_shape,
[&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
indices.Set(height_axis, dheight);
indices.Set(width_axis, dwidth);
return tvm::max(x(indices), { dheight, dwidth }); // NOLINT(*)
}, "tensor", "global_pool_max");
} else if (pool_type == kAvgPool) {
auto tsum = tvm::compute(out_shape,
[&](const Array<Var>& output) {
Array<Expr> indices;
for (const Var& var : output) indices.push_back(var);
indices.Set(height_axis, dheight);
indices.Set(width_axis, dwidth);
return tvm::sum(x(indices), { dheight, dwidth });
}, "tensor", "global_pool_sum");

return tvm::compute(out_shape,
[&](const Array<Var>& output) {
return tsum(output) / tvm::cast(x->dtype, height * width);
}, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}

inline Expr start_index(const Var& out_index,
const Expr& odim,
@@ -417,6 +345,37 @@ inline Tensor adaptive_pool(const Tensor& x,
return adaptive_pool_impl(x, output_size, pool_type, height_axis, width_axis);
}

/*!
* \brief Perform global pooling on height and width dimension of data.
* It decides the height and width dimension according to the layout string,
* in which 'W' and 'H' means width and height respectively.
* Width and height dimension cannot be split.
* For example, NCHW, NCHW16c, ... are valid for global_pool,
* while NCHW16w, NCHW16h are not.
* See \a layout for more information of the layout string convention.
*
* \param x The input tensor represent as layout
* \param pool_type The type of pooling operator
* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
* where upper case indicates a dimension and
* the corresponding lower case (with factor size) indicates the sub-dimension.
* For example, `NCHW16c` can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* (in which factor size `16` will not be used in pooling but for other operators,
* it can be used to decide the output shape).
* Since pooling does not care about the factor size of
* dimensions other than `H` and `W`, one can pass `NCHWc` as well.
*
* \return The output tensor in same layout with height and width dimension size of 1.
* e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
*/
inline Tensor global_pool(const Tensor& x,
PoolType pool_type,
const std::string& layout = "NCHW") {
return adaptive_pool(x, Array<Expr>{1, 1}, pool_type, layout);
}

} // namespace nn
} // namespace topi
#endif // TOPI_NN_POOLING_H_
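
For context, the change above reduces global_pool to a thin wrapper over adaptive_pool with a fixed (1, 1) output size. Below is a minimal Python sketch of that equivalence, assuming the TOPI Python wrappers topi.nn.global_pool and topi.nn.adaptive_pool mirror the C++ signatures shown here; the shapes and names are illustrative only.

import tvm
import topi

# 5-D NCHW16c input: [batch, channel, height, width, channel_block]
x = tvm.placeholder((1, 4, 56, 56, 16), name="x")

# After this commit, global_pool is adaptive_pool with output size (1, 1),
# so both declarations should yield a [1, 4, 1, 1, 16] output tensor.
g = topi.nn.global_pool(x, pool_type="avg", layout="NCHW16c")
a = topi.nn.adaptive_pool(x, output_size=(1, 1), pool_type="avg", layout="NCHW16c")
print(g.shape, a.shape)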
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/__init__.py
@@ -12,7 +12,7 @@
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .pooling import schedule_pool, schedule_global_pool, schedule_adaptive_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
34 changes: 28 additions & 6 deletions topi/python/topi/cuda/pooling.py
@@ -20,23 +20,26 @@
from .. import tag
from .. import generic

@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
"""Schedule for global_pool.


@generic.schedule_adaptive_pool.register(["cuda", "gpu"])
def schedule_adaptive_pool(outs):
"""Schedule for adaptive_pool.

Parameters
----------
outs: Array of Tensor
The computation graph description of global_pool
The computation graph description of adaptive_poo
Review comment (Contributor): adaptive_poo -> adaptive_pool

in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for global_pool.
The computation schedule for adaptive_pool.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])

def _schedule(Pool):
num_thread = 8
block_x = tvm.thread_axis("blockIdx.x")
@@ -73,7 +76,7 @@ def traverse(OP):
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
elif OP.tag.startswith('adaptive_pool'):
Pool = OP.output(0)
_schedule(Pool)
else:
@@ -84,6 +87,23 @@ def traverse(OP):
traverse(outs[0].op)
return s

@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
"""Schedule for global_pool.

Parameters
----------
outs: Array of Tensor
The computation graph description of global_pool
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for global_pool.
"""
return schedule_adaptive_pool(outs)


@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs, layout):
@@ -147,3 +167,5 @@ def traverse(OP):

traverse(outs[0].op)
return s


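Usage note: with this change, schedule_global_pool on the "cuda"/"gpu" targets simply forwards to schedule_adaptive_pool, which matches any op whose tag starts with 'adaptive_pool'. A rough sketch of driving it through the generic dispatcher, assuming a CUDA-enabled TVM build and that the generic functions dispatch on the current target context:

import tvm
import topi
from topi import generic

x = tvm.placeholder((1, 64, 7, 7), name="x")
pool = topi.nn.adaptive_pool(x, (1, 1), "max", "NCHW")

# Inside a CUDA target context, the generic function dispatches to the
# schedule_adaptive_pool registered above; schedule_global_pool would
# yield the same schedule, since it now delegates to it.
with tvm.target.create("cuda"):
    s = generic.schedule_adaptive_pool([pool])
print(tvm.lower(s, [x, pool], simple_mode=True))
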
44 changes: 11 additions & 33 deletions topi/python/topi/x86/pooling.py
@@ -110,14 +110,14 @@ def traverse(OP):
return s


@generic.schedule_global_pool.register(["cpu"])
def schedule_global_pool(outs):
"""Schedule for global pool
@generic.schedule_adaptive_pool.register(["cpu"])
def schedule_adaptive_pool(outs):
"""Schedule for adaptive pool

Parameters
----------
outs: Array of Tensor
The computation graph description of pool
The computation graph description of adaptive pool
in the format of an array of tensors.

Returns
@@ -139,7 +139,7 @@ def traverse(OP):
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('global_pool'):
elif OP.tag.startswith('adaptive_pool'):
Pool = OP.output(0)
_parallel_sch(s[Pool], outs[0].shape)
else:
@@ -150,42 +150,20 @@ def traverse(OP):
traverse(outs[0].op)
return s

@generic.schedule_adaptive_pool.register(["cpu"])
def schedule_adaptive_pool(outs):
"""Schedule for adaptive pool

@generic.schedule_global_pool.register(["cpu"])
def schedule_global_pool(outs):
"""Schedule for global pool

Parameters
----------
outs: Array of Tensor
The computation graph description of adaptive pool
The computation graph description of pool
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('adaptive_pool'):
Pool = OP.output(0)
_parallel_sch(s[Pool], outs[0].shape)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s
return schedule_adaptive_pool(outs)
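
Similarly on the CPU side, schedule_global_pool now just returns schedule_adaptive_pool(outs). A hedged end-to-end sketch that checks the delegated schedule numerically against NumPy, assuming an LLVM-enabled TVM build; the tensor shapes are illustrative:

import numpy as np
import tvm
import topi
from topi import generic

x = tvm.placeholder((1, 16, 7, 7), name="x", dtype="float32")
pool = topi.nn.global_pool(x, pool_type="avg", layout="NCHW")

# The "cpu" registration of schedule_global_pool delegates to
# schedule_adaptive_pool, which parallelizes the adaptive_pool stage.
with tvm.target.create("llvm"):
    s = generic.schedule_global_pool([pool])

f = tvm.build(s, [x, pool], "llvm")
ctx = tvm.cpu(0)
x_np = np.random.uniform(size=(1, 16, 7, 7)).astype("float32")
out = tvm.nd.empty((1, 16, 1, 1), "float32", ctx)
f(tvm.nd.array(x_np, ctx), out)
np.testing.assert_allclose(out.asnumpy(),
                           x_np.mean(axis=(2, 3), keepdims=True), rtol=1e-5)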
5 changes: 0 additions & 5 deletions topi/src/topi.cc
@@ -676,11 +676,6 @@ TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_adaptive_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))