From 2823be8c9efd03cf6b5938c5f609cfb000ced1ac Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Thu, 12 Dec 2019 14:06:20 +0800 Subject: [PATCH] [TOPI] implement pool3d op (#4478) * [TOPI] implement pool3d op * use PoolInferCorrectLayout for both 2d and 3d pooling * unify MakeMaxPool and MakeAvgPool --- include/tvm/relay/attrs/nn.h | 62 +++++ src/relay/op/nn/pooling.cc | 327 +++++++++++++++++++++---- topi/include/topi/nn/pooling.h | 207 +++++++++++++++- topi/python/topi/nn/pooling.py | 57 +++++ topi/src/topi.cc | 16 +- topi/tests/python/test_topi_pooling.py | 87 +++++++ 6 files changed, 699 insertions(+), 57 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index de404f49c6aa..4422fce250c2 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -406,6 +406,68 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { }; +/*! \brief Attributes for 3D max pool operator */ +struct MaxPool3DAttrs : public tvm::AttrsNode { + Array pool_size; + Array strides; + Array padding; + std::string layout; + bool ceil_mode; + + TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { + TVM_ATTR_FIELD(pool_size) + .describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW") + .describe("Dimension ordering of data and weight. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false) + .describe("When true, will use ceil instead of floor to compute the output shape."); + } +}; + +/*! \brief Attributes for 3D avg pool operator */ +struct AvgPool3DAttrs : public tvm::AttrsNode { + Array pool_size; + Array strides; + Array padding; + std::string layout; + bool ceil_mode; + bool count_include_pad; + + TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") { + TVM_ATTR_FIELD(pool_size) + .describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW") + .describe("Dimension ordering of data and weight. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false) + .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad).set_default(false) + .describe("When true, will include padding to compute the average"); + } +}; + + /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 8915fc5b6a76..aa3755921037 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -37,7 +37,7 @@ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template -Array > Pool2DInferCorrectLayout( +Array > PoolInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, @@ -55,6 +55,44 @@ Array > Pool2DInferCorrectLayout( return Array >{{inferred_layout}, {inferred_layout}}; } +template +Expr MakeMaxPool(Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + std::string op_name) { + auto attrs = make_node(); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->layout = std::move(layout); + attrs->ceil_mode = ceil_mode; + static const Op& op = Op::Get(op_name); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +template +Expr MakeAvgPool(Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + bool count_include_pad, + std::string op_name) { + auto attrs = make_node(); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->layout = std::move(layout); + attrs->ceil_mode = ceil_mode; + attrs->count_include_pad = count_include_pad; + static const Op& op = Op::Get(op_name); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + template bool Pool2DRel(const Array& types, int num_inputs, @@ -127,23 +165,6 @@ bool Pool2DRel(const Array& types, return true; } -// MaxPool2D -Expr MakeMaxPool2D(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - auto attrs = make_node(); - attrs->pool_size = std::move(pool_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->layout = std::move(layout); - attrs->ceil_mode = ceil_mode; - static const Op& op = Op::Get("nn.max_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - template Array Pool2DCompute(const Attrs& attrs, const Array& inputs, @@ -193,7 +214,16 @@ Array Pool2DCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op.nn._make.max_pool2d") -.set_body_typed(MakeMaxPool2D); +.set_body_typed, Array, Array, + std::string, bool)>([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool2d"); +}); RELAY_REGISTER_OP("nn.max_pool2d") @@ -222,33 +252,23 @@ RELAY_REGISTER_OP("nn.max_pool2d") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("MaxPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) +.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); // AvgPool2D -Expr MakeAvgPool2D(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - auto attrs = make_node(); - attrs->pool_size = std::move(pool_size); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->layout = std::move(layout); - attrs->ceil_mode = ceil_mode; - attrs->count_include_pad = count_include_pad; - static const Op& op = Op::Get("nn.avg_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); -} - - TVM_REGISTER_API("relay.op.nn._make.avg_pool2d") -.set_body_typed(MakeAvgPool2D); - +.set_body_typed, Array, Array, + std::string, bool, bool)>([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool2d"); +}); RELAY_REGISTER_OP("nn.avg_pool2d") .describe(R"code( @@ -277,7 +297,7 @@ Average pooling operation for one dimensional data. .add_argument("data", "Tensor", "The input tensor.") .set_support_level(2) .add_type_rel("AvgPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", Pool2DInferCorrectLayout) +.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) .set_attr("FTVMCompute", Pool2DCompute); // relay.nn.global_pool_2d & relay.nn.max_pool_2d @@ -365,7 +385,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") .set_support_level(2) .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) .set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) + PoolInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool @@ -396,7 +416,7 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") .set_support_level(2) .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) .set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) + PoolInferCorrectLayout) .set_attr("FTVMCompute", GlobalPool2DCompute); @@ -522,7 +542,7 @@ RELAY_REGISTER_OP("contrib.adaptive_avg_pool2d") .set_support_level(10) .add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) .set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) + PoolInferCorrectLayout) .set_attr("FTVMCompute", AdaptivePool2DCompute); @@ -561,7 +581,7 @@ RELAY_REGISTER_OP("contrib.adaptive_max_pool2d") .set_support_level(10) .add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) .set_attr("FInferCorrectLayout", - Pool2DInferCorrectLayout) + PoolInferCorrectLayout) .set_attr("FTVMCompute", AdaptivePool2DCompute); @@ -720,5 +740,220 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad") .set_attr("FTVMCompute", Pool2DGradCompute); +// relay.nn.max_pool3d & relay.nn.avg_pool3d +TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs); +TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs); + +template +bool Pool3DRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + + if (data == nullptr) return false; + + const auto dshape = data->shape; + CHECK_GE(dshape.size(), 3U) + << "Pool3D only support input >= 3-D: input must have depth, height and width"; + const auto param = attrs.as(); + CHECK(param != nullptr); + + Layout layout(param->layout); + CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && + layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && + !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; + + const auto didx = layout.IndexOf(LayoutAxis::Get('D')); + const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); + const auto widx = layout.IndexOf(LayoutAxis::Get('W')); + + IndexExpr pad_d, pad_h, pad_w; + if (param->padding.size() == 1) { + pad_d = param->padding[0] * 2; + pad_h = param->padding[0] * 2; + pad_w = param->padding[0] * 2; + } else if (param->padding.size() == 3) { + // (front, top, left) + pad_d = param->padding[0] * 2; + pad_h = param->padding[1] * 2; + pad_w = param->padding[2] * 2; + } else if (param->padding.size() == 6) { + // (front, top, left, back, bottom, right) + pad_d = param->padding[0] + param->padding[3]; + pad_h = param->padding[1] + param->padding[4]; + pad_w = param->padding[2] + param->padding[5]; + } else { + return false; + } + + std::vector oshape; + for (const auto& e : dshape) { + oshape.push_back(e); + } + + std::vector idxes = {didx, hidx, widx}; + for (int i = 0; i < 3; i++) { + int ii = idxes[i]; + if (dshape[ii].as()) { + oshape[ii] = dshape[ii]; + } else { + if (param->ceil_mode) { + oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i] + + param->strides[i] - 1) / param->strides[i]) + 1; + } else { + oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i]) / param->strides[i]) + 1; + } + } + } + + // assign output type + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + + +template +Array Pool3DCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + static const Layout kNCDHW("NCDHW"); + const auto* param = attrs.as(); + CHECK(param != nullptr); + auto pool_size = param->pool_size; + auto strides = param->strides; + auto padding = param->padding; + auto ceil_mode = param->ceil_mode; + Layout layout(param->layout); + + CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined()) + << "max_pool3d currently only supports layouts that are convertible from NCDHW"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) + << "max_pool3d does not support input split on depth"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) + << "max_pool3d does not support input split on height"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) + << "max_pool3d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || + inputs[0].ndim() == 5U || + inputs[0].ndim() == 6U) + << "Pool3D only support 5-D input (e.g., NCDHW)" + << " or 6-D input (e.g. NCDHWc on for vector instructions)" + << " or 7-D input (e.g. NCDHWnc for tensor accelerators)"; + + if (param->padding.size() == 1) { + padding.push_back(padding[0]); + padding.push_back(padding[0]); + padding.push_back(padding[0]); + } else if (param->padding.size() == 3) { + padding.push_back(padding[0]); + padding.push_back(padding[1]); + padding.push_back(padding[2]); + } + if (mode == topi::nn::kAvgPool) { + bool count_include_pad = reinterpret_cast(param)->count_include_pad; + return Array{ + topi::nn::pool3d(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; + } else { + return Array{ + topi::nn::pool3d(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name())}; + } +} + +TVM_REGISTER_API("relay.op.nn._make.max_pool3d") +.set_body_typed, Array, Array, + std::string, bool)>([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool3d"); +}); + +RELAY_REGISTER_OP("nn.max_pool3d") +.describe(R"code(Max pooling operation for three dimensional data. + +- **data**: This depends on the `layout` parameter. Input is 5D array of shape + (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. +- **out**: This depends on the `layout` parameter. Output is 5D array of shape + (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. + out_depth, out_height and out_width are calculated as:: + + out_depth = floor((depth+padding[0]+padding[3]-pool_size[0])/strides[0])+1 + out_height = floor((height+padding[1]+padding[4]-pool_size[1])/strides[1])+1 + out_width = floor((width+padding[2]+padding[5]-pool_size[2])/strides[2])+1 + + where padding will be an expanded array based on number of values passed as:: + one int : all sides same padding used. + three int : front, bottom, right use same as back, top and left. + six int: padding width in the order of (front, top, left, back, bottom, right). + + When `ceil_mode` is `True`, ceil will be used instead of floor in this + equation. + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("MaxPool3D", Pool3DRel) +.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) +.set_attr("FTVMCompute", Pool3DCompute); + + +// AvgPool3D +TVM_REGISTER_API("relay.op.nn._make.avg_pool3d") +.set_body_typed, Array, Array, + std::string, bool, bool)>([](Expr data, + Array pool_size, + Array strides, + Array padding, + std::string layout, + bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool3d"); +}); + +RELAY_REGISTER_OP("nn.avg_pool3d") +.describe(R"code( +Average pooling operation for three dimensional data. + +- **data**: This depends on the `layout` parameter. Input is 5D array of shape + (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. +- **out**: This depends on the `layout` parameter. Output is 5D array of shape + (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. + out_depth, out_height and out_width are calculated as:: + + out_depth = floor((depth+padding[0]+padding[3]-pool_size[0])/strides[0])+1 + out_height = floor((height+padding[1]+padding[4]-pool_size[1])/strides[1])+1 + out_width = floor((width+padding[2]+padding[5]-pool_size[2])/strides[2])+1 + + where padding will be an expanded array based on number of values passed as:: + one int : all sides same padding used. + three int : front, bottom, right use same as back, top and left. + six int: padding width in the order of (front, top, left, back, bottom, right). + + When `ceil_mode` is `True`, ceil will be used instead of floor in this + equation. + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("AvgPool3D", Pool3DRel) +.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) +.set_attr("FTVMCompute", Pool3DCompute); + } // namespace relay } // namespace tvm diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index f92db65e31fe..72ea2b86c280 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -24,6 +24,7 @@ #ifndef TOPI_NN_POOLING_H_ #define TOPI_NN_POOLING_H_ +#include #include #include @@ -43,6 +44,7 @@ enum PoolType : int { kMaxPool, }; + /*! * \brief Perform pooling on height and width dimension of data. * @@ -325,31 +327,46 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, } } -inline bool find_height_width(const std::string& layout, - int* height_axis, - int* width_axis) { - *height_axis = -1, *width_axis = -1; +inline bool find_depth_height_width(const std::string& layout, + int* depth_axis, + int* height_axis, + int* width_axis) { + *depth_axis = -1, *height_axis = -1, *width_axis = -1; int curr_idx = 0; for (size_t i = 0; i < layout.size(); ++i) { if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) { - if (layout[i] == 'H') { + if (layout[i] == 'D') { + if (*depth_axis != -1) return false; + *depth_axis = curr_idx; + } else if (layout[i] == 'H') { if (*height_axis != -1) return false; *height_axis = curr_idx; } else if (layout[i] == 'W') { if (*width_axis != -1) return false; *width_axis = curr_idx; - } else if (layout[i] == 'h' || layout[i] == 'w') { + } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') { // do not support split on height or width, e.g., NCHW16w return false; } ++curr_idx; } } - if (*height_axis == -1 || *width_axis == -1) return false; + if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false; return true; } +inline bool find_height_width(const std::string& layout, + int* height_axis, + int* width_axis) { + int dummy; + CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); + if (*height_axis != -1 && *width_axis != -1) { + return true; + } + return false; +} + /*! * \brief Perform pooling on height and width dimension of data. * It decides the height and width dimension according to the layout string, @@ -591,6 +608,182 @@ inline Tensor global_pool(const Tensor& x, return adaptive_pool(x, Array{1, 1}, pool_type, layout); } +/*! +* \brief Perform pooling on N-dimension of data. +* +* \param x The input tensor +* \param kernel_size Vector of N ints +* \param stride_size Vector of N ints +* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., +* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] +* \param pool_type The type of pooling operator +* \param ceil_mode Whether to use ceil when calculating the output size +* \param axis Vector of indices for the N dimensions +* \param count_include_pad Whether include padding in the calculation +* +* \return The output tensor in same layout order +*/ +inline Tensor pool_impl_nd(const Tensor& x, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, + PoolType pool_type, + bool ceil_mode, + const std::vector& axis, + bool count_include_pad) { + int k_size = kernel_size.size(); + int x_size = x->shape.size(); + CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; + CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of" + " kernel"; + CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; + + Array daxis; + std::vector kernel(k_size); + std::vector stride(k_size); + std::vector pad_head(k_size); + std::vector pad_tail(k_size); + Array pad_before(std::vector(x_size, 0)); + Array pad_after(std::vector(x_size, 0)); + Array out_shape = x->shape; + + bool do_pad = false; + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + kernel[i] = cast(Int(32), kernel_size[i]); + stride[i] = cast(Int(32), stride_size[i]); + pad_head[i] = cast(Int(32), padding_size[i]); + pad_tail[i] = cast(Int(32), padding_size[i + k_size]); + const int64_t *padding0 = as_const_int(pad_head[i]); + const int64_t *padding1 = as_const_int(pad_tail[i]); + do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1)); + + if (ceil_mode) { + // Additional padding to ensure we do ceil instead of floor when + // dividing by stride. + pad_tail[i] += stride[i] - 1; + } + + daxis.push_back(tvm::reduce_axis(Range(0, kernel[i]))); + + pad_before.Set(ii, pad_head[i]); + pad_after.Set(ii, pad_tail[i]); + + auto out_dim = tvm::ir::Simplify( + indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + + out_shape.Set(ii, out_dim); + } + + if (pool_type == kMaxPool) { + auto temp = do_pad ? pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp") : x; + return tvm::compute(out_shape, [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } + + return tvm::max(temp(indices), daxis); + }, "tensor", "pool_max"); + } else if (pool_type == kAvgPool) { + // Pad the inputs + auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; + + // TVM compute for summing the pooling window. + auto pool_sum = tvm::compute(out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } + return tvm::sum(temp(indices), daxis); + }, "tensor", "pool_sum"); + + // TVM compute for dividing the reduced window sum by kernel size. + return tvm::compute(out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + auto kernel_size = make_const(Int(32), 1); + for (int i = 0; i < k_size; i++) { + kernel_size *= kernel[i]; + } + return div(pool_sum(indices), kernel_size); + } else { + std::vector start(k_size); + std::vector end(k_size); + auto kernel_size = make_const(Int(32), 1); + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + start[i] = output[ii] * stride[i] - pad_head[i]; + end[i] = ir::Min::make(start[i] + kernel[i], x->shape[ii]); + start[i] = ir::Max::make(start[i], make_const(Int(32), 0)); + kernel_size *= (end[i] - start[i]); + } + + Expr divide_factor = ir::Max::make(kernel_size, make_const(Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, "tensor", kElementWise); + } else { + LOG(ERROR) << "Unrecognized pool_type: " << pool_type; + return x; + } +} + +/*! +* \brief Perform pooling on depth, height and width dimension of data. +* It decides the depth, height and width dimension according to the layout string, +* in which 'D', 'W' and 'H' means depth, width and height respectively. +* Depth, Width and height dimension cannot be split. +* For example, NCDHW, NCDHW16c, etc. are valid for pool, +* while NCDHW16d, NCDHW16w or NCDHW16h are not. +* See \a layout for more information of the layout string convention. +* \param x The input tensor. +* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} +* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} +* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, +* tail_pad_depth, tail_pad_height, tail_pad_width} +* \param pool_type The type of pooling operator +* \param ceil_mode Whether to use ceil when calculating the output size +* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. +* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, +* where upper case indicates a dimension and +* the corresponding lower case (with factor size) indicates the split dimension. +* For example, NCDHW16c can describe a 6-D tensor of +* [batch_size, channel, depth, height, width, channel_block]. +* (in which factor size `16` will not be used in pooling but for other operators, +* it can be used to decide the output shape). +* Since pooling does not care about the factor size of dimensions +* other than `D`, `H` and `W`, one can pass `NCDHWc` as well. +* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' +* +* +* \return The output tensor in the same layout +*/ +inline Tensor pool3d(const Tensor& x, + const Array& kernel_size, + const Array& stride_size, + const Array& padding_size, + PoolType pool_type, + bool ceil_mode, + const std::string& layout = "NCDHW", + bool count_include_pad = true) { + int depth_axis = -1, height_axis = -1, width_axis = -1; + CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) + << "Unsupported layout " << layout; + std::vector axis = {depth_axis, height_axis, width_axis}; + return pool_impl_nd(x, kernel_size, stride_size, padding_size, + pool_type, ceil_mode, axis, count_include_pad); +} + } // namespace nn } // namespace topi #endif // TOPI_NN_POOLING_H_ diff --git a/topi/python/topi/nn/pooling.py b/topi/python/topi/nn/pooling.py index cc75ea6ef908..a8a8215cb49f 100644 --- a/topi/python/topi/nn/pooling.py +++ b/topi/python/topi/nn/pooling.py @@ -289,3 +289,60 @@ def adaptive_pool(data, n-D in the same layout """ return cpp.nn.adaptive_pool(data, output_size, POOL_TYPE_CODE[pool_type], layout) + + +def pool3d(data, + kernel, + stride, + padding, + pool_type, + ceil_mode=False, + layout="NCDHW", + count_include_pad=True): + """Perform pooling on depth, height and width dimension of data. + It decides the depth, height and width dimension according to the layout string, + in which 'D', 'W' and 'H' means depth, width and height respectively. + Depth, width and height dimension cannot be split. + For example, NCDHW, NCDHW16c, etc. are valid for pool, + while NCDHW16d, NCDHW16w, NCDHW16h are not. + See parameter `layout` for more information of the layout string convention. + + Parameters + ---------- + data : tvm.Tensor + n-D with shape of layout + + kernel : list/tuple of three ints + Kernel size, [kernel_depth, kernel_height, kernel_width] + + stride : list/tuple of three ints + Stride size, [stride_depth, stride_height, stride_width] + + padding : list/tuple of six ints + Pad size, [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] + + pool_type : str + Pool type, 'max' or 'avg' + + ceil_mode : bool + Whether to use ceil when calculating output size. + + layout: string + Layout of the input data. + The layout is supposed to be composed of upper cases, lower cases and numbers, + where upper case indicates a dimension and + the corresponding lower case with factor size indicates the split dimension. + For example, NCDHW16c can describe a 6-D tensor of + [batch_size, channel, depth, height, width, channel_block], + in which channel_block=16 is a split of dimension channel. + + count_include_pad: bool + Whether include padding in the calculation when pool_type is 'avg' + + Returns + ------- + output : tvm.Tensor + n-D in the same layout + """ + return cpp.nn.pool3d(data, kernel, stride, padding, + POOL_TYPE_CODE[pool_type], ceil_mode, layout, count_include_pad) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index f3e6b89a43a8..bf3d9518c509 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -535,6 +535,13 @@ TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") args[3]); }); +TVM_REGISTER_GLOBAL("topi.nn.pool3d") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::pool3d(args[0], args[1], args[2], args[3], + static_cast(static_cast(args[4])), + args[5], args[6], args[7]); + }); + /* Ops from nn/softmax.h */ TVM_REGISTER_GLOBAL("topi.nn.softmax") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -599,7 +606,7 @@ TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); - }); + }); /* x86 schedules */ TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") @@ -629,7 +636,7 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); - }); + }); /* ROCm schedules */ TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") @@ -701,7 +708,7 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); - }); + }); TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -824,7 +831,8 @@ inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder build TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) .register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) -.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing)); +.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( + topi::cuda::schedule_injective_from_existing)); /*! \brief Builder function for instantiating dense ops. */ using FTVMDenseOpBuilder = std::function 0, axis=(2,3,4)) + b_np[:,:,k,i,j] = np.sum(pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], \ + axis=(2,3, 4)) / np.maximum(pad_count, 1) + + elif pool_type =='max': + for k in range(oz): + for i in range(oh): + for j in range(ow): + b_np[:,:,k,i,j] = np.max( \ + pad_np[:, :, k*sz:k*sz+kz, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3,4)) + b_np = np.maximum(b_np, 0.0) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_pool(B, layout) + + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) + f = tvm.build(s, [A, B], device) + f(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + + +def test_pool3d(): + verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True) + verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True) + verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False) + verify_pool3d(1, 256, 31, 4, 4, [3, 3, 3, 3, 3, 3], 'avg', False, False) + verify_pool3d(1, 256, 31, 4, 4, [0, 0, 0, 0, 0, 0], 'avg', False, False) + verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'max', False) + verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', False) + verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], 'max', True) + + verify_pool3d(1, 256, 31, 3, 3, [2, 1, 0, 5, 4, 3], 'avg', False, True) + verify_pool3d(1, 256, 32, 2, 2, [0, 5, 4, 3, 2, 1], 'avg', False, False) + verify_pool3d(1, 256, 31, 3, 3, [1, 0, 5, 4, 3, 2], 'max', False) + verify_pool3d(1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], 'max', True) + if __name__ == "__main__": test_pool() test_pool_grad() test_global_pool() test_adaptive_pool() + test_pool3d()