diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index f4c47c2dae8f..c7b175837d3d 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; + Array dilation; tvm::String layout; bool ceil_mode; @@ -694,6 +695,9 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0, 0})) .describe( @@ -717,6 +721,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; + Array dilation; tvm::String layout; bool ceil_mode; bool count_include_pad; @@ -726,6 +731,9 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0, 0})) .describe( @@ -813,6 +821,7 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { struct MaxPool1DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; + Array dilation; Array padding; std::string layout; bool ceil_mode; @@ -822,6 +831,9 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0})) .describe( @@ -843,6 +855,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { struct AvgPool1DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; + Array dilation; Array padding; std::string layout; bool ceil_mode; @@ -853,6 +866,9 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0})) .describe( @@ -877,6 +893,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { struct MaxPool3DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; + Array dilation; Array padding; std::string layout; bool ceil_mode; @@ -886,6 +903,9 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0, 0, 0})) .describe( @@ -908,6 +928,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { struct AvgPool3DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; + Array dilation; Array padding; std::string layout; bool ceil_mode; @@ -918,6 +939,9 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(strides) .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) + .describe("Specifies the dilation of the convolution."); TVM_ATTR_FIELD(padding) .set_default(Array({0, 0, 0})) .describe( diff 
--git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index e40759907e6b..db3e74a3da5c 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -46,136 +46,6 @@ enum PoolType : int { kMaxPool, }; -/*! - * \brief Perform pooling on height and width dimension of data. - * - * \param x The input tensor - * \param kernel_size Vector of two ints: {kernel_height, kernel_width} - * \param stride_size Vector of two ints: {stride_height, stride_width} - * \param padding_size Vector of two ints: {padding_height, padding_width} - * \param pool_type The type of pooling operator - * \param ceil_mode Whether to use ceil when calculating the output size - * \param height_axis index of the height dimension - * \param width_axis index of the width dimension - * \param count_include_pad Whether include padding in the calculation - * - * \return The output tensor in same layout order - */ -inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const size_t height_axis, - const size_t width_axis, bool count_include_pad) { - ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; - ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; - ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; - ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; - - auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]); - auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]); - auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]); - auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]); - - auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]); - auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]); - - auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]); - auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]); - auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]); - auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]); - - if (ceil_mode) { - // Additional padding to ensure we do ceil instead of floor when - // dividing by stride. 
- pad_bottom += stride_height - 1; - pad_right += stride_width - 1; - } - - Array pad_before(std::vector(x->shape.size(), 0)); - pad_before.Set(height_axis, pad_top); - pad_before.Set(width_axis, pad_left); - - Array pad_after(std::vector(x->shape.size(), 0)); - pad_after.Set(height_axis, pad_bottom); - pad_after.Set(width_axis, pad_right); - arith::Analyzer analyzer; - auto out_height = - analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = - analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); - - auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh"); - auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw"); - - Array out_shape = x->shape; - for (size_t i = 0; i < out_shape.size(); ++i) { - out_shape.Set(i, cast(DataType::DataType::Int(32), out_shape[i])); - } - out_shape.Set(height_axis, out_height); - out_shape.Set(width_axis, out_width); - - const int64_t* padding_h0 = as_const_int(pad_top); - const int64_t* padding_w0 = as_const_int(pad_left); - const int64_t* padding_h1 = as_const_int(pad_bottom); - const int64_t* padding_w1 = as_const_int(pad_right); - const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || - ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); - - if (pool_type == kMaxPool) { - auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute( - out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::max(temp(indices), {dheight, dwidth}); - }, - "tensor", "pool_max"); - } else if (pool_type == kAvgPool) { - // Pad the inputs - auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; - - // TVM compute for summing the pooling window. - auto pool_sum = tvm::te::compute( - out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::sum(temp(indices), {dheight, dwidth}); - }, - "tensor", "pool_sum"); - - // TVM compute for dividing the reduced window sum by kernel size. 
- return tvm::te::compute( - out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - return div(pool_sum(indices), (kernel_height * kernel_width)); - } else { - PrimExpr h_start = output[height_axis] * stride_height - pad_top; - PrimExpr w_start = output[width_axis] * stride_width - pad_left; - - PrimExpr h_end = min(h_start + kernel_height, height); - PrimExpr w_end = min(w_start + kernel_width, width); - h_start = max(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = max(w_start, make_const(DataType::DataType::Int(32), 0)); - PrimExpr divide_factor = max((h_end - h_start) * (w_end - w_start), - make_const(DataType::DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, - "tensor", kElementWise); - } else { - LOG(ERROR) << "Unrecognized pool_type: " << pool_type; - return x; - } -} - inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, const Array& kernel_size, const Array& stride_size, const Array& padding_size, PoolType pool_type, @@ -390,45 +260,6 @@ inline bool find_width(const std::string& layout, int* width_axis) { return false; } -/*! - * \brief Perform pooling on height and width dimension of data. - * It decides the height and width dimension according to the layout string, - * in which 'W' and 'H' means width and height respectively. - * Width and height dimension cannot be split. - * For example, NCHW, NCHW16c, etc. are valid for pool, - * while NCHW16w, NCHW16h are not. - * See \a layout for more information of the layout string convention. - * \param x The input tensor. - * \param kernel_size Vector of two ints: {kernel_height, kernel_width} - * \param stride_size Vector of two ints: {stride_height, stride_width} - * \param padding_size Vector of two ints: {padding_height, padding_width} - * \param pool_type The type of pooling operator - * \param ceil_mode Whether to use ceil when calculating the output size - * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. - * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, - * where upper case indicates a dimension and - * the corresponding lower case (with factor size) indicates the split dimension. - * For example, NCHW16c can describe a 5-D tensor of - * [batch_size, channel, height, width, channel_block]. - * (in which factor size `16` will not be used in pooling but for other operators, - * it can be used to decide the output shape). - * Since pooling does not care about the factor size of dimensions - * other than `H` and `W`, one can pass `NCHWc` as well. - * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' - * - * - * \return The output tensor in the same layout - */ -inline Tensor pool(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", - bool count_include_pad = true) { - int height_axis = -1, width_axis = -1; - ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; - return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis, - width_axis, count_include_pad); -} - /*! * \brief Calculate gradient of pooling on height and width dimension of data. 
* It decides the height and width dimension according to the layout string, @@ -663,6 +494,7 @@ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string * \param x The input tensor * \param kernel_size Vector of N ints * \param stride_size Vector of N ints + * \param dilation_size Vector of N ints * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] * \param pool_type The type of pooling operator @@ -673,9 +505,9 @@ inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string * \return The output tensor in same layout order */ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::vector& axis, - bool count_include_pad) { + const Array& stride_size, const Array& dilation_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, + const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; @@ -686,6 +518,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, Array daxis; std::vector kernel(k_size); std::vector stride(k_size); + std::vector dilation(k_size); std::vector pad_head(k_size); std::vector pad_tail(k_size); Array pad_before(std::vector(x_size, 0)); @@ -701,11 +534,9 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, int ii = axis[i]; kernel[i] = cast(DataType::Int(32), kernel_size[i]); stride[i] = cast(DataType::Int(32), stride_size[i]); + dilation[i] = cast(DataType::Int(32), dilation_size[i]); pad_head[i] = cast(DataType::Int(32), padding_size[i]); pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]); - const int64_t* padding0 = as_const_int(pad_head[i]); - const int64_t* padding1 = as_const_int(pad_tail[i]); - do_pad = (do_pad) ? 
do_pad : ((padding0 && *padding0) || (padding1 && *padding1)); if (ceil_mode) { // Additional padding to ensure we do ceil instead of floor when @@ -713,15 +544,20 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, pad_tail[i] += stride[i] - 1; } + const int64_t* padding0 = as_const_int(pad_head[i]); + const int64_t* padding1 = as_const_int(pad_tail[i]); + do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1); + daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i))); pad_before.Set(ii, pad_head[i]); pad_after.Set(ii, pad_tail[i]); arith::Analyzer analyzer; - auto out_dim = analyzer.Simplify( - indexdiv(data_shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + PrimExpr numerator = + data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i]; + auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1); out_shape.Set(ii, out_dim); } @@ -735,9 +571,8 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, for (int i = 0; i < k_size; i++) { int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); + indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]); } - return tvm::max(temp(indices), daxis); }, "tensor", "pool_max"); @@ -754,7 +589,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, for (int i = 0; i < k_size; i++) { int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); + indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]); } return tvm::sum(temp(indices), daxis); }, @@ -767,24 +602,36 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, Array indices; for (const Var& var : output) indices.push_back(var); if (count_include_pad) { - auto kernel_size = make_const(DataType::Int(32), 1); + auto num_el = make_const(DataType::Int(32), 1); for (int i = 0; i < k_size; i++) { - kernel_size *= kernel[i]; + num_el *= kernel[i]; } - return div(pool_sum(indices), kernel_size); + return div(pool_sum(indices), num_el); } else { std::vector start(k_size); std::vector end(k_size); - auto kernel_size = make_const(DataType::Int(32), 1); + auto num_el = make_const(DataType::Int(32), 1); for (int i = 0; i < k_size; i++) { int ii = axis[i]; + + // Let start and end contain the first and last index of our Tensor + // along the relevant dimension we use in our calculation. + // Assume indices -1, -2 represent the padding before (tail) and + // len(arr), len(arr) + 1 represent the padding after (head). start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = min(start[i] + kernel[i], data_shape[ii]); - start[i] = max(start[i], make_const(DataType::Int(32), 0)); - kernel_size *= (end[i] - start[i]); + end[i] = start[i] + (kernel[i] - 1) * dilation[i]; + + // if start[i] < 0, e.g. we start on a tail padded number this will be a positive + // number that represents the number of steps along the dilated kernel to reach a + // non-padded value. Otherwise this should be 0. 
+ PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i]; + jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0)); + + end[i] = min(end[i], data_shape[ii] - 1); + num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1; } - PrimExpr divide_factor = max(kernel_size, make_const(DataType::Int(32), 1)); + PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } }, @@ -804,9 +651,10 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, * while NCW16w is not. * See \a layout for more information of the layout string convention. * \param x The input tensor. - * \param kernel_size Vector of three ints: {kernel_width} - * \param stride_size Vector of three ints: {stride_width} - * \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} + * \param kernel_size Vector of one int: {kernel_width} + * \param stride_size Vector of one int: {stride_width} + * \param dilation_size Vector of one int: {dilation_width} + * \param padding_size Vector of two ints: {head_pad_width, tail_pad_width} * \param pool_type The type of pooling operator * \param ceil_mode Whether to use ceil when calculating the output size * \param layout The input layout. Pooling supports any layout as long as 'W' appears. @@ -825,14 +673,55 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, * \return The output tensor in the same layout */ inline Tensor pool1d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", - bool count_include_pad = true) { + const Array& stride_size, const Array& dilation_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, + const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, - count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, + ceil_mode, axis, count_include_pad); +} + +/*! + * \brief Perform pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param dilation_size Vector of two ints: {dilation_height, dilation_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. 
+ * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool2d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& dilation_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, + const std::string& layout = "NCHW", bool count_include_pad = true) { + int height_axis = -1, width_axis = -1; + ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + std::vector axis = {height_axis, width_axis}; + return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, + ceil_mode, axis, count_include_pad); } /*! @@ -846,6 +735,7 @@ inline Tensor pool1d(const Tensor& x, const Array& kernel_size, * \param x The input tensor. * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} + * \param dilation_size Vector of three ints: {dilation_depth, dilation_height, dilation_width} * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, * tail_pad_depth, tail_pad_height, tail_pad_width} * \param pool_type The type of pooling operator @@ -866,15 +756,15 @@ inline Tensor pool1d(const Tensor& x, const Array& kernel_size, * \return The output tensor in the same layout */ inline Tensor pool3d(const Tensor& x, const Array& kernel_size, - const Array& stride_size, const Array& padding_size, - PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", - bool count_include_pad = true) { + const Array& stride_size, const Array& dilation_size, + const Array& padding_size, PoolType pool_type, bool ceil_mode, + const std::string& layout = "NCDHW", bool count_include_pad = true) { int depth_axis = -1, height_axis = -1, width_axis = -1; ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {depth_axis, height_axis, width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, - count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, + ceil_mode, axis, count_include_pad); } } // namespace nn diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ff0328275604..a62e505b287a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -19,6 +19,7 @@ """ONNX: Open Neural Network Exchange frontend for Relay.""" import copy import warnings + import numpy as np import tvm from tvm.ir import IRModule @@ -28,16 +29,23 @@ from .. import analysis from .. import expr as _expr from .. import function as _function +from .. import loops as _loops from .. import op as _op from .. import qnn as _qnn -from .. import vision as _vision -from .. import loops as _loops from .. 
import ty as _ty - -from .common import AttrCvt, Renamer -from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant -from .common import infer_type, get_name - +from .. import vision as _vision +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + infer_channels, + infer_shape, + infer_type, + infer_value, + new_var, +) __all__ = ["from_onnx"] @@ -328,8 +336,12 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt( op_name=dimension_picker(cls.name), - transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)}, - ignores=["dilations", "storage_order"], + transforms={ + "kernel_shape": "pool_size", + "pads": ("padding", 0), + "dilations": ("dilation", 1), + }, + ignores=["storage_order"], custom_check=dimension_constraint(), )([data], attr, params) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 025942bcfa22..1ac43750e6b6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -21,31 +21,28 @@ """PT: PyTorch frontend.""" import itertools import logging -import sys import math +import sys import numpy as np - import tvm -from tvm.topi.utils import get_const_tuple from tvm.ir import IRModule +from tvm.topi.utils import get_const_tuple from .. import analysis as _analysis from .. import expr as _expr from .. import function as _function from .. import op as _op -from .. import qnn -from ..ty import TupleType, TensorType, Any +from .. import qnn, transform +from ..expr_functor import ExprMutator from ..loops import while_loop -from .. import transform +from ..prelude import Prelude, StaticTensorArrayOps +from ..ty import Any, TensorType, TupleType +from . import qnn_torch from .common import AttrCvt, get_relay_op from .common import infer_value as _infer_value -from .common import try_infer_value from .common import infer_value_simulated as _infer_value_simulated -from ..prelude import Prelude, StaticTensorArrayOps -from ..expr_functor import ExprMutator - -from . 
import qnn_torch +from .common import try_infer_value from .pytorch_utils import is_version_greater_than __all__ = ["from_pytorch"] @@ -883,11 +880,15 @@ def maxpool_2d(self, inputs, input_types): dilation = inputs[4] ceil_mode = int(inputs[5]) - if dilation != [1, 1]: - msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation)) - raise NotImplementedError(msg) - - return _op.nn.max_pool2d(data, pool_size, strides, padding, "NCHW", ceil_mode) + return _op.nn.max_pool2d( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout="NCHW", + ceil_mode=ceil_mode, + ) def maxpool_2d_with_indices(self, inputs, input_types): # returns dummy indices too @@ -902,11 +903,15 @@ def maxpool_1d(self, inputs, input_types): dilation = inputs[4] ceil_mode = int(inputs[5]) - if dilation != [1]: - msg = "MaxPool1d with dilation %s is not implemented" % (str(dilation)) - raise NotImplementedError(msg) - - return _op.nn.max_pool1d(data, pool_size, strides, padding, "NCW", ceil_mode) + return _op.nn.max_pool1d( + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + layout="NCW", + ceil_mode=ceil_mode, + ) def maxpool_3d(self, inputs, input_types): data = inputs[0] @@ -916,12 +921,14 @@ def maxpool_3d(self, inputs, input_types): padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) - if dilation != [1, 1, 1]: - msg = "MaxPool3d with dilation %s is not implemented" % (str(dilation)) - raise NotImplementedError(msg) return _op.nn.max_pool3d( - data, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode + data, + pool_size=pool_size, + strides=strides, + dilation=dilation, + padding=padding, + ceil_mode=ceil_mode, ) def hardtanh(self, inputs, input_types): @@ -1370,6 +1377,7 @@ def func(x): pool_size=pool_size, strides=strides, padding=padding, + dilation=(1,), ceil_mode=ceil_mode, count_include_pad=count_include_pad, ) @@ -1379,6 +1387,7 @@ def func(x): pool_size=pool_size, strides=strides, padding=padding, + dilation=(1, 1), ceil_mode=ceil_mode, count_include_pad=count_include_pad, ) @@ -1388,6 +1397,7 @@ def func(x): pool_size=pool_size, strides=strides, padding=padding, + dilation=(1, 1, 1), ceil_mode=ceil_mode, count_include_pad=count_include_pad, ) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index c449651f1130..ba491954ac63 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -747,7 +747,9 @@ def log_softmax(data, axis=-1): return _make.log_softmax(data, axis) -def max_pool1d(data, pool_size=(1,), strides=(1,), padding=(0,), layout="NCW", ceil_mode=False): +def max_pool1d( + data, pool_size=(1,), strides=(1,), dilation=(1,), padding=(0,), layout="NCW", ceil_mode=False +): r"""1D maximum pooling operator. This operator takes data as input and does 1D max value calculation @@ -772,6 +774,9 @@ def max_pool1d(data, pool_size=(1,), strides=(1,), padding=(0,), layout="NCW", c strides : int or tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : int or tuple of int, optional The padding for pooling. 
@@ -790,12 +795,20 @@ def max_pool1d(data, pool_size=(1,), strides=(1,), padding=(0,), layout="NCW", c pool_size = (pool_size,) if isinstance(strides, int): strides = (strides,) + if isinstance(dilation, int): + dilation = (dilation,) padding = get_pad_tuple1d(padding) - return _make.max_pool1d(data, pool_size, strides, padding, layout, ceil_mode) + return _make.max_pool1d(data, pool_size, strides, dilation, padding, layout, ceil_mode) def max_pool2d( - data, pool_size=(1, 1), strides=(1, 1), padding=(0, 0), layout="NCHW", ceil_mode=False + data, + pool_size=(1, 1), + strides=(1, 1), + dilation=(1, 1), + padding=(0, 0), + layout="NCHW", + ceil_mode=False, ): r"""2D maximum pooling operator. @@ -829,6 +842,9 @@ def max_pool2d( strides : tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : tuple of int, optional The padding for pooling. @@ -847,12 +863,20 @@ def max_pool2d( pool_size = (pool_size, pool_size) if isinstance(strides, int): strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) - return _make.max_pool2d(data, pool_size, strides, padding, layout, ceil_mode) + return _make.max_pool2d(data, pool_size, strides, dilation, padding, layout, ceil_mode) def max_pool3d( - data, pool_size=(1, 1, 1), strides=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", ceil_mode=False + data, + pool_size=(1, 1, 1), + strides=(1, 1, 1), + dilation=(1, 1, 1), + padding=(0, 0, 0), + layout="NCDHW", + ceil_mode=False, ): r"""3D maximum pooling operator. @@ -879,6 +903,9 @@ def max_pool3d( strides : tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : tuple of int, optional The padding for pooling. @@ -897,14 +924,17 @@ def max_pool3d( pool_size = (pool_size, pool_size, pool_size) if isinstance(strides, int): strides = (strides, strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) - return _make.max_pool3d(data, pool_size, strides, padding, layout, ceil_mode) + return _make.max_pool3d(data, pool_size, strides, dilation, padding, layout, ceil_mode) def avg_pool1d( data, pool_size=(1,), strides=(1,), + dilation=(1,), padding=(0,), layout="NCW", ceil_mode=False, @@ -934,6 +964,9 @@ def avg_pool1d( strides : int or tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : int or tuple of int, optional The padding for pooling. @@ -955,14 +988,19 @@ def avg_pool1d( pool_size = (pool_size,) if isinstance(strides, int): strides = (strides,) + if isinstance(dilation, int): + dilation = (dilation,) padding = get_pad_tuple1d(padding) - return _make.avg_pool1d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) + return _make.avg_pool1d( + data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + ) def avg_pool2d( data, pool_size=(1, 1), strides=(1, 1), + dilation=(1, 1), padding=(0, 0), layout="NCHW", ceil_mode=False, @@ -1001,6 +1039,9 @@ def avg_pool2d( strides : tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : tuple of int, optional The padding for pooling. 
@@ -1022,14 +1063,19 @@ def avg_pool2d( pool_size = (pool_size, pool_size) if isinstance(strides, int): strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) padding = get_pad_tuple2d(padding) - return _make.avg_pool2d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) + return _make.avg_pool2d( + data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + ) def avg_pool3d( data, pool_size=(1, 1, 1), strides=(1, 1, 1), + dilation=(1, 1, 1), padding=(0, 0, 0), layout="NCDHW", ceil_mode=False, @@ -1060,6 +1106,9 @@ def avg_pool3d( strides : tuple of int, optional The strides of pooling. + dilation : int or tuple of int, optional + The dilation of pooling. + padding : tuple of int, optional The padding for pooling. @@ -1081,12 +1130,22 @@ def avg_pool3d( pool_size = (pool_size, pool_size, pool_size) if isinstance(strides, int): strides = (strides, strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) padding = get_pad_tuple3d(padding) - return _make.avg_pool3d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) + return _make.avg_pool3d( + data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad + ) def max_pool2d_grad( - out_grad, data, pool_size=(1, 1), strides=(1, 1), padding=(0, 0), layout="NCHW", ceil_mode=False + out_grad, + data, + pool_size=(1, 1), + strides=(1, 1), + padding=(0, 0), + layout="NCHW", + ceil_mode=False, ): r"""Gradient of 2D maximum pooling operator. diff --git a/python/tvm/topi/nn/pooling.py b/python/tvm/topi/nn/pooling.py index 8c4be5a5cb35..df3888980773 100644 --- a/python/tvm/topi/nn/pooling.py +++ b/python/tvm/topi/nn/pooling.py @@ -16,6 +16,7 @@ # under the License. """TVM operator pooling compute.""" from __future__ import absolute_import + from .. import cpp POOL_TYPE_CODE = {"avg": 0, "max": 1} @@ -56,66 +57,6 @@ def global_pool(data, pool_type, layout="NCHW"): return cpp.nn.global_pool(data, POOL_TYPE_CODE[pool_type], layout) -def pool( - data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCHW", count_include_pad=True -): - """Perform pooling on height and width dimension of data. - It decides the height and width dimension according to the layout string, - in which 'W' and 'H' means width and height respectively. - Width and height dimension cannot be split. - For example, NCHW, NCHW16c, etc. are valid for pool, - while NCHW16w, NCHW16h are not. - See parameter `layout` for more information of the layout string convention. - - Parameters - ---------- - data : tvm.te.Tensor - n-D with shape of layout - - kernel : list/tuple of two ints - Kernel size, [kernel_height, kernel_width] - - stride : list/tuple of two ints - Stride size, [stride_height, stride_width] - - padding : list/tuple of four ints - Pad size, [pad_top, pad_left, pad_bottom, pad_right]] - - pool_type : str - Pool type, 'max' or 'avg' - - ceil_mode : bool - Whether to use ceil when calculating output size. - - layout: string - Layout of the input data. - The layout is supposed to be composed of upper cases, lower cases and numbers, - where upper case indicates a dimension and - the corresponding lower case with factor size indicates the split dimension. - For example, NCHW16c can describe a 5-D tensor of - [batch_size, channel, height, width, channel_block], - in which channel_block=16 is a split of dimension channel. 
- - count_include_pad: bool - Whether include padding in the calculation when pool_type is 'avg' - - Returns - ------- - output : tvm.te.Tensor - n-D in the same layout - """ - return cpp.nn.pool( - data, - kernel, - stride, - padding, - POOL_TYPE_CODE[pool_type], - ceil_mode, - layout, - count_include_pad, - ) - - def pool_grad( grads, data, @@ -235,7 +176,15 @@ def adaptive_pool3d(data, output_size, pool_type, layout="NCDHW"): def pool1d( - data, kernel, stride, padding, pool_type, ceil_mode=False, layout="NCW", count_include_pad=True + data, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode=False, + layout="NCW", + count_include_pad=True, ): """Perform pooling on width dimension of data. Width axis is determined according to the layout string. @@ -294,6 +243,76 @@ def pool1d( data, kernel, stride, + dilation, + padding, + POOL_TYPE_CODE[pool_type], + ceil_mode, + layout, + count_include_pad, + ) + + +def pool2d( + data, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode=False, + layout="NCHW", + count_include_pad=True, +): + """Perform pooling on height and width dimension of data. + It decides the height and width dimension according to the layout string, + in which 'W' and 'H' means width and height respectively. + Width and height dimension cannot be split. + For example, NCHW, NCHW16c, etc. are valid for pool, + while NCHW16w, NCHW16h are not. + See parameter `layout` for more information of the layout string convention. + + Parameters + ---------- + data : tvm.te.Tensor + n-D with shape of layout + + kernel : list/tuple of two ints + Kernel size, [kernel_height, kernel_width] + + stride : list/tuple of two ints + Stride size, [stride_height, stride_width] + + padding : list/tuple of four ints + Pad size, [pad_top, pad_left, pad_bottom, pad_right]] + + pool_type : str + Pool type, 'max' or 'avg' + + ceil_mode : bool + Whether to use ceil when calculating output size. + + layout: string + Layout of the input data. + The layout is supposed to be composed of upper cases, lower cases and numbers, + where upper case indicates a dimension and + the corresponding lower case with factor size indicates the split dimension. + For example, NCHW16c can describe a 5-D tensor of + [batch_size, channel, height, width, channel_block], + in which channel_block=16 is a split of dimension channel. 
+ + count_include_pad: bool + Whether include padding in the calculation when pool_type is 'avg' + + Returns + ------- + output : tvm.te.Tensor + n-D in the same layout + """ + return cpp.nn.pool2d( + data, + kernel, + stride, + dilation, padding, POOL_TYPE_CODE[pool_type], ceil_mode, @@ -306,6 +325,7 @@ def pool3d( data, kernel, stride, + dilation, padding, pool_type, ceil_mode=False, @@ -361,6 +381,7 @@ def pool3d( data, kernel, stride, + dilation, padding, POOL_TYPE_CODE[pool_type], ceil_mode, diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index ef36b9e73446..ef7d86322be7 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -49,8 +49,7 @@ from .batch_matmul import batch_matmul from .slice_axis_python import slice_axis_python from .sequence_mask_python import sequence_mask -from .pool1d_python import pool1d_ncw_python -from .pool3d_python import pool3d_ncdhw_python +from .poolnd_python import poolnd_python from .pool_grad_python import pool_grad_nchw from .one_hot import one_hot from .depth_to_space import depth_to_space_python diff --git a/python/tvm/topi/testing/pool1d_python.py b/python/tvm/topi/testing/pool1d_python.py deleted file mode 100644 index d83b7224434f..000000000000 --- a/python/tvm/topi/testing/pool1d_python.py +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, unused-argument, unused-variable -"""max_pool1d and avg_pool1d in python""" -import math -import numpy as np - - -def pool1d_ncw_python( - np_data, - kernel, - strides, - padding, - out_shape, - pool_type, - count_include_pad=True, - ceil_mode=False, - dtype="float32", -): - """Baseline for max_pool1d and avg_pool1d, default layout is NCW""" - in_n, in_c, in_w = in_shape = np_data.shape - k_w = kernel[0] - s_w = strides[0] - pl, pr = padding - - if ceil_mode: - assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_w + pl + pr) / s_w) + 1) - else: - assert out_shape[2] == int(math.floor(float(in_shape[2] - k_w + pl + pr) / s_w) + 1) - - pad_np = np.zeros(shape=(in_n, in_c, in_w + pl + pr)).astype(dtype) - - no_zero = (range(in_n), range(in_c), range(pl, in_w + pl)) - pad_np[np.ix_(*no_zero)] = np_data - ret_np = np.zeros(shape=out_shape).astype(dtype) - - if pool_type == "avg": - for k in range(out_shape[2]): - if count_include_pad: - ret_np[:, :, k] = np.mean(pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,)) - else: - pad_count = np.sum(pad_np[:, :, k * s_w : k * s_w + k_w] > 0, axis=(2,)) - ret_np[:, :, k] = np.sum( - pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,) - ) / np.maximum(pad_count, 1) - - elif pool_type == "max": - for k in range(out_shape[2]): - ret_np[:, :, k] = np.max(pad_np[:, :, k * s_w : k * s_w + k_w], axis=(2,)) - - else: - raise ValueError("Pool type {} is not supported".format(pool_type)) - - ret_np = np.maximum(ret_np, 0.0) - return ret_np diff --git a/python/tvm/topi/testing/pool3d_python.py b/python/tvm/topi/testing/pool3d_python.py deleted file mode 100644 index 8c687f737166..000000000000 --- a/python/tvm/topi/testing/pool3d_python.py +++ /dev/null @@ -1,111 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, unused-argument, unused-variable -"""max_pool3d and avg_pool3d in python""" -import math -import numpy as np -import tvm - - -def pool3d_ncdhw_python( - np_data, - kernel, - strides, - padding, - out_shape, - pool_type, - count_include_pad=True, - ceil_mode=False, - dtype="float32", -): - """baseline for max_pool3d and avg_pool3d, default layout is "NCDHW""" - # fmt: off - in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape - if isinstance(kernel, int): - k_d = k_h = k_w = kernel - else: - k_d, k_h, k_w = kernel - if isinstance(strides, int): - s_d = s_h = s_w = strides - else: - s_d, s_h, s_w = strides - if isinstance(padding, int): - pf = pt = pl = pk = pb = pr = padding - else: - pf, pt, pl, pk, pb, pr = padding - - if ceil_mode: - assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1) - assert out_shape[3] == int(math.ceil(float(in_shape[3] - k_h + pt + pb) / s_h) + 1) - assert out_shape[4] == int(math.ceil(float(in_shape[4] - k_w + pl + pr) / s_w) + 1) - else: - assert out_shape[2] == int(math.floor(float(in_shape[2] - k_d + pf + pk) / s_d) + 1) - assert out_shape[3] == int(math.floor(float(in_shape[3] - k_h + pt + pb) / s_h) + 1) - assert out_shape[4] == int(math.floor(float(in_shape[4] - k_w + pl + pr) / s_w) + 1) - - fill_value = tvm.tir.const(0.0, dtype).value - if not(count_include_pad) and pool_type == 'max': - fill_value = tvm.te.min_value(dtype).value - - pad_np = np.full(shape=(in_n, in_c, - in_d + pf + pk, - in_h + pt + pb, - in_w + pl + pr), - fill_value=fill_value, - dtype=dtype) - - no_zero = (range(in_n), - range(in_c), - (range(pf, in_d + pf)), - (range(pt, in_h + pt)), - (range(pl, in_w + pl))) - pad_np[np.ix_(*no_zero)] = np_data - ret_np = np.zeros(shape=out_shape).astype(dtype) - - if pool_type == 'avg': - for k in range(out_shape[2]): - for i in range(out_shape[3]): - for j in range(out_shape[4]): - if count_include_pad: - ret_np[:, :, k, i, j] = \ - np.mean(pad_np[:, :, k * s_d: k * s_d + k_d, - i * s_h: i * s_h + k_h, - j * s_w: j * s_w + k_w], axis=(2, 3, 4)) - else: - pad_count = np.sum(pad_np[:, :, - k * s_d: k * s_d + k_d, - i * s_h: i * s_h + k_h, - j * s_w: j * s_w + k_w] > 0, axis=(2, 3, 4)) - ret_np[:, :, k, i, j] = np.sum(pad_np[:, :, - k * s_d: k * s_d + k_d, - i * s_h: i * s_h + k_h, - j * s_w: j * s_w + k_w], - axis=(2, 3, 4)) / np.maximum(pad_count, 1) - elif pool_type == 'max': - for k in range(out_shape[2]): - for i in range(out_shape[3]): - for j in range(out_shape[4]): - ret_np[:, :, k, i, j] = np.max( - pad_np[:, :, k * s_d: k * s_d + k_d, - i * s_h: i * s_h + k_h, - j * s_w: j * s_w + k_w], axis=(2, 3, 4)) - else: - raise ValueError("pool type {} is not supported".format(pool_type)) - - ret_np = np.maximum(ret_np, fill_value) - # fmt: on - return ret_np diff --git a/python/tvm/topi/testing/poolnd_python.py b/python/tvm/topi/testing/poolnd_python.py new file mode 100644 index 000000000000..43440d32f44e --- /dev/null +++ b/python/tvm/topi/testing/poolnd_python.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, unused-variable +"""Ground truth max and average pooling operators in python.""" +import itertools +import math +from typing import List, Tuple + +import numpy as np +import tvm + + +def get_slice( + spatial_dimensions: int, + pad_np: np.array, + dim_coord: Tuple[int], + kernel: Tuple[int], + strides: Tuple[int], + dilation: Tuple[int], +) -> List[slice]: + """ + Programmatically create a slice object of the right dimensions for pad_np. + + We assume pad_np's first two dimensions are not spatial and are not touched by the pad. + + pad_np[slice] should give the elements of the data that a pool operation will use for the + step given in dim_coord. + """ + slices = [slice(None)] * spatial_dimensions + + for nd in range(spatial_dimensions): + slices[nd] = slice( + dim_coord[nd] * strides[nd], + dim_coord[nd] * strides[nd] + (kernel[nd] - 1) * dilation[nd] + 1, + dilation[nd], + ) + + # Add back batch and channel dimensions + slices = [slice(None), slice(None)] + slices + + return slices + + +def pad_tensor( + np_arr: np.array, + pad_value: float, + padding_before: List[int], + padding_after: List[int], + dtype: str, +) -> np.array: + """Pad the spatial dimensions of the given array.""" + orig_shape = list(np_arr.shape) + padded_shape = list(np_arr.shape) + n = len(orig_shape) + for dim in range(2, n): + i = dim - 2 + padded_shape[dim] += padding_after[i] + padding_before[i] + + pad_np = (np.zeros(shape=padded_shape) + pad_value).astype(dtype) + ranges_it = [range(padded_shape[0]), range(padded_shape[1])] + for dim in range(2, n): + i = dim - 2 + ranges_it.append(range(padding_before[i], padding_before[i] + orig_shape[dim])) + pad_np[np.ix_(*ranges_it)] = np_arr + return pad_np + + +def poolnd_python( + np_data: np.array, + kernel: Tuple[int], + strides: Tuple[int], + dilation: Tuple[int], + padding_before: Tuple[int], + padding_after: Tuple[int], + pool_type: str, + count_include_pad: bool = True, + ceil_mode: bool = False, + dtype: str = "float32", +) -> np.array: + """Ground truth pooling operator impelmented in numpy.""" + out_shape = [np_data.shape[0], np_data.shape[1]] + for dim in range(2, len(np_data.shape)): + i = dim - 2 + val = ( + float( + np_data.shape[dim] + - (kernel[i] - 1) * dilation[i] + - 1 + + padding_before[i] + + padding_after[i] + ) + / strides[i] + ) + + if ceil_mode: + out_shape.append(int(math.ceil(val) + 1)) + else: + out_shape.append(int(math.floor(val) + 1)) + out_shape = tuple(out_shape) + + # Create a padded array, and a boolean mask showing which values are padded values + pad_value = 0 + if pool_type == "max" and not count_include_pad: + pad_value = tvm.te.min_value(dtype).value + pad_data = pad_tensor(np_data, pad_value, padding_before, padding_after, dtype) + pad_map = pad_tensor(np.ones_like(np_data), 0, padding_before, padding_after, "bool") + + # Create iterator which gives all indices for output array + dim_iterators = [] + for spatial_dimension in range(2, len(np_data.shape)): + dim_iterators.append(range(out_shape[spatial_dimension])) + coord_iterator = itertools.product(*dim_iterators) + + ret_np 
= np.zeros(shape=out_shape).astype(dtype) + for coordinate in coord_iterator: + # Get index into the values that any pool operation will use for given coordinate + np_index = get_slice( + spatial_dimensions=len(out_shape) - 2, + pad_np=pad_data, + dim_coord=coordinate, + kernel=kernel, + strides=strides, + dilation=dilation, + ) + + output_slice = [slice(None), slice(None)] + list(coordinate) + reduction_axis = tuple(range(2, len(np_data.shape))) + if pool_type == "avg": + count_non_padded = ( + pad_data[np_index].size if count_include_pad else np.sum(pad_map[np_index]) + ) + # We summed over the non spatial dimensions too so divide by them + count_non_padded /= out_shape[0] * out_shape[1] + if count_non_padded == 0: + ret_np[output_slice] = 0 + else: + ret_np[output_slice] = ( + np.sum(pad_data[np_index], axis=reduction_axis) / count_non_padded + ) + elif pool_type == "max": + count_non_padded = np.sum(pad_map[np_index]) + # All padded values, default to 0 + ret_np[output_slice] = np.max(pad_data[np_index], axis=reduction_axis) + else: + raise ValueError("Pool type {} is not supported".format(pool_type)) + + return ret_np diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index cd7f6808845b..08d29f25c32b 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -58,6 +58,20 @@ Array > PoolInferCorrectLayout(const Attrs& attrs, return Array >{{inferred_layout}, {inferred_layout}}; } +IndexExpr calculate_pool_dimension(IndexExpr in_dimension, IndexExpr pad_amount, + IndexExpr pool_size, IndexExpr dilation, IndexExpr stride_size, + bool ceil_mode) { + IndexExpr numerator = in_dimension + pad_amount - ((pool_size - 1) * dilation + 1); + IndexExpr denominator = stride_size; + + // Emulate the behavior of running ceil on numerator / denominator rather than floor + if (ceil_mode) { + numerator += denominator - 1; + } + + return numerator / denominator + 1; +} + template bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -101,24 +115,16 @@ bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, if (dshape[hidx].as()) { oshape[hidx] = dshape[hidx]; } else { - if (param->ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + param->strides[0] - 1) / - param->strides[0]) + - 1; - } else { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; - } + oshape[hidx] = + calculate_pool_dimension(dshape[hidx], pad_h, param->pool_size[0], param->dilation[0], + param->strides[0], param->ceil_mode); } if (dshape[widx].as()) { oshape[widx] = dshape[widx]; } else { - if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + param->strides[1] - 1) / - param->strides[1]) + - 1; - } else { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; - } + oshape[widx] = + calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[1], param->dilation[1], + param->strides[1], param->ceil_mode); } // assign output type @@ -134,6 +140,7 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp ICHECK(param != nullptr); auto pool_size = param->pool_size; auto strides = param->strides; + auto dilation = param->dilation; auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); @@ -160,19 +167,20 @@ Array Pool2DCompute(const Attrs& attrs, const Array& inp } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - 
return Array{topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, - layout.name(), count_include_pad)}; + return Array{topi::nn::pool2d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; } else { - return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; + return Array{topi::nn::pool2d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool2d"); + Array dilation, Array padding, String layout, + bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, + ceil_mode, "nn.max_pool2d"); }); RELAY_REGISTER_OP("nn.max_pool2d") @@ -207,10 +215,10 @@ RELAY_REGISTER_OP("nn.max_pool2d") // AvgPool2D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool2d"); + Array dilation, Array padding, String layout, + bool ceil_mode, bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, + ceil_mode, count_include_pad, "nn.avg_pool2d"); }); RELAY_REGISTER_OP("nn.avg_pool2d") @@ -988,13 +996,9 @@ bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, if (dshape[widx].as()) { oshape[widx] = dshape[widx]; } else { - if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + param->strides[0] - 1) / - param->strides[0]) + - 1; - } else { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1; - } + oshape[widx] = + calculate_pool_dimension(dshape[widx], pad_w, param->pool_size[0], param->dilation[0], + param->strides[0], param->ceil_mode); } // assign output type @@ -1010,6 +1014,7 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp ICHECK(param != nullptr); auto pool_size = param->pool_size; auto strides = param->strides; + auto dilation = param->dilation; auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); @@ -1030,19 +1035,20 @@ Array Pool1DCompute(const Attrs& attrs, const Array& inp if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, - ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool1d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; } else { - return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; + return Array{topi::nn::pool1d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool1d"); + Array dilation, Array padding, String layout, + bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, dilation, padding, 
layout, + ceil_mode, "nn.max_pool1d"); }); RELAY_REGISTER_OP("nn.max_pool1d") @@ -1075,10 +1081,10 @@ RELAY_REGISTER_OP("nn.max_pool1d") // AvgPool1D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool1d"); + Array dilation, Array padding, String layout, + bool ceil_mode, bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, + ceil_mode, count_include_pad, "nn.avg_pool1d"); }); RELAY_REGISTER_OP("nn.avg_pool1d") @@ -1165,13 +1171,9 @@ bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, if (dshape[ii].as()) { oshape[ii] = dshape[ii]; } else { - if (param->ceil_mode) { - oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + param->strides[i] - 1) / - param->strides[i]) + - 1; - } else { - oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i]) / param->strides[i]) + 1; - } + oshape[ii] = + calculate_pool_dimension(dshape[ii], pad[i], param->pool_size[i], param->dilation[i], + param->strides[i], param->ceil_mode); } } @@ -1188,6 +1190,7 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp ICHECK(param != nullptr); auto pool_size = param->pool_size; auto strides = param->strides; + auto dilation = param->dilation; auto padding = param->padding; auto ceil_mode = param->ceil_mode; Layout layout(param->layout); @@ -1217,19 +1220,20 @@ Array Pool3DCompute(const Attrs& attrs, const Array& inp } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, - ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool3d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; } else { - return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; + return Array{topi::nn::pool3d(inputs[0], pool_size, strides, dilation, padding, + mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool3d"); + Array dilation, Array padding, String layout, + bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, dilation, padding, layout, + ceil_mode, "nn.max_pool3d"); }); RELAY_REGISTER_OP("nn.max_pool3d") @@ -1265,10 +1269,10 @@ RELAY_REGISTER_OP("nn.max_pool3d") // AvgPool3D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") .set_body_typed([](Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool3d"); + Array dilation, Array padding, String layout, + bool ceil_mode, bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, + ceil_mode, count_include_pad, "nn.avg_pool3d"); }); RELAY_REGISTER_OP("nn.avg_pool3d") diff --git a/src/relay/op/nn/pooling.h b/src/relay/op/nn/pooling.h index a803698b93eb..9b7eab25fe9a 100644 --- a/src/relay/op/nn/pooling.h +++ b/src/relay/op/nn/pooling.h @@ -34,10 +34,12 @@ namespace relay { 
template inline Expr MakeMaxPool(Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, String op_name) { + Array dilation, Array padding, String layout, + bool ceil_mode, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); + attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); attrs->ceil_mode = ceil_mode; @@ -47,11 +49,12 @@ inline Expr MakeMaxPool(Expr data, Array pool_size, Array template inline Expr MakeAvgPool(Expr data, Array pool_size, Array strides, - Array padding, String layout, bool ceil_mode, - bool count_include_pad, String op_name) { + Array dilation, Array padding, String layout, + bool ceil_mode, bool count_include_pad, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); + attrs->dilation = std::move(dilation); attrs->padding = std::move(padding); attrs->layout = std::move(layout); attrs->ceil_mode = ceil_mode; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 1a81d10c2583..755f12848777 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -270,19 +270,19 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ auto scaled_hw_t2 = Multiply(casted_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w)); Array padding({0, 0}); - reduced_t2 = - AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, padding, param->data_layout, - false, // ceil_mode - false); // count_include_pad + reduced_t2 = AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, param->dilation, + padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad } else { int stride1 = get_const_int(param->strides[0]); int stride2 = get_const_int(param->strides[1]); if (stride1 * stride2 != 1) { Array padding({0, 0}); - reduced_t2 = - AvgPool2D(reduced_t2, param->kernel_size, param->strides, padding, param->data_layout, - false, // ceil_mode - false); // count_include_pad + reduced_t2 = AvgPool2D(reduced_t2, param->kernel_size, param->strides, param->dilation, + padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad } } @@ -435,18 +435,18 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, if (kernel_h * kernel_w != 1) { reduced_c_t2 = Multiply(reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w)); - reduced_t2 = - AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout, - false, // ceil_mode - false); // count_include_pad + reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, + padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad } else { int stride1 = get_const_int(param->strides[0]); int stride2 = get_const_int(param->strides[1]); if (stride1 * stride2 != 1) { - reduced_t2 = - AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, padding, param->data_layout, - false, // ceil_mode - false); // count_include_pad + reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation, + padding, param->data_layout, + false, // ceil_mode + false); // count_include_pad } } diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 975c07e563be..50a695bf1d84 100644 --- a/src/relay/transforms/pattern_utils.h +++ 
b/src/relay/transforms/pattern_utils.h @@ -670,9 +670,9 @@ static inline Expr Reshape(Expr data, Array newshape) { } static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, - Array padding, std::string layout, bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + Array dilation, Array padding, + std::string layout, bool ceil_mode, bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad, "nn.avg_pool2d"); } diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 092fe65e19dc..356f3d2ea18f 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -95,11 +95,6 @@ TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body([](TVMArgs args, TVMRet }); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::pool(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), args[5], args[6], args[7]); -}); - TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], static_cast(static_cast(args[5])), args[6], args[7], @@ -121,13 +116,18 @@ TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body([](TVMArgs args, TVMRetV }); TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::pool1d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), args[5], args[6], args[7]); + *rv = nn::pool1d(args[0], args[1], args[2], args[3], args[4], + static_cast(static_cast(args[5])), args[6], args[7], args[8]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.pool2d").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::pool2d(args[0], args[1], args[2], args[3], args[4], + static_cast(static_cast(args[5])), args[6], args[7], args[8]); }); TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::pool3d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), args[5], args[6], args[7]); + *rv = nn::pool3d(args[0], args[1], args[2], args[3], args[4], + static_cast(static_cast(args[5])), args[6], args[7], args[8]); }); /* Ops from nn/softmax.h */ diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py index 7ab4b42f95c1..137484330db8 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py +++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py @@ -17,30 +17,30 @@ """Arm Compute Library integration pooling tests.""" import numpy as np - import tvm -from tvm import relay -from tvm import testing +from tvm import relay, testing from test_arm_compute_lib.infrastructure import ( - skip_runtime_test, - skip_codegen_test, + Device, build_and_run, + skip_codegen_test, + skip_runtime_test, verify, verify_codegen, ) -from test_arm_compute_lib.infrastructure import Device -def _calculate_output_shape(shape, sizes, padding, strides): +def _calculate_output_shape(shape, sizes, padding, strides, dilation): """Calculate pooling output shape.""" - output_height = ((shape[1] - sizes[0] + padding[0] + padding[2]) / strides[0]) + 1 - output_width = ((shape[2] - sizes[1] + padding[1] + padding[3]) / strides[1]) + 1 + height_receptive_field = (sizes[0] - 1) * dilation[0] + 1 + width_receptive_field = (sizes[1] - 1) * dilation[1] + 1 + output_height = ((shape[1] - height_receptive_field + padding[0] + padding[2]) / strides[0]) + 1 + 
output_width = ((shape[2] - width_receptive_field + padding[1] + padding[3]) / strides[1]) + 1 return 1, int(output_height), int(output_width), shape[3] def _get_pooling_model( - shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad, var_names + shape, dtype, typef, sizes, strides, dilation, padding, ceil_mode, count_include_pad, var_names ): """Return a model and any parameters it may have.""" if len(padding) == 2: @@ -52,6 +52,7 @@ def _get_pooling_model( out, pool_size=sizes, strides=strides, + dilation=dilation, padding=padding, ceil_mode=ceil_mode, layout="NHWC", @@ -63,6 +64,7 @@ def _get_pooling_model( out, pool_size=sizes, strides=strides, + dilation=dilation, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad, @@ -107,11 +109,11 @@ def _get_global_pooling_model(shape, dtype, typef, var_names): def _get_expected_pooling_codegen( - shape, dtype, typef, sizes, strides, padding, ceil_mode, count_include_pad + shape, dtype, typef, sizes, strides, dilation, padding, ceil_mode, count_include_pad ): if len(padding) == 2: padding = (padding[0], padding[1], padding[0], padding[1]) - output_shape = _calculate_output_shape(shape, sizes, padding, strides) + output_shape = _calculate_output_shape(shape, sizes, padding, strides, dilation) node = { "op": "kernel", @@ -125,6 +127,7 @@ def _get_expected_pooling_codegen( "dtype": [[dtype]], "padding": [[str(p) for p in padding]], "strides": [[str(s) for s in strides]], + "dilation": [[str(d) for d in dilation]], "pool_size": [[str(s) for s in sizes]], "ceil_mode": [[str(1 if ceil_mode else 0)]], }, @@ -282,19 +285,21 @@ def test_codegen_pooling(): uint8_dtype = ("uint8", 0, 255) trials = [ - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), True, True, (15, 15, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, False, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 0), False, True, (16, 16, 16)], - ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 1), True, False, (15, 15, 16)], - ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), False, True, (16, 16, 16)], - ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (0, 1), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (0, 1), True, False, (15, 15, 16)], - ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (0, 0), False, False, (16, 16, 16)], - ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), False, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], + ["nn.max_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), True, True, (15, 15, 16)], + ["nn.max_pool2d", uint8_dtype, (2, 2), (2, 2), (3, 2), (1, 1), True, True, (15, 15, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, False, (16, 16, 16)], + 
["nn.avg_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 0), False, True, (16, 16, 16)], + ["nn.avg_pool2d", fp32_dtype, (3, 3), (2, 2), (3, 2), (0, 1), True, False, (15, 15, 16)], + ["nn.avg_pool2d", uint8_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (16, 16, 16)], + ["nn.avg_pool2d", uint8_dtype, (3, 3), (2, 2), (1, 1), (0, 1), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (0, 1), True, False, (15, 15, 16)], + ["nn.l2_pool2d", fp32_dtype, (3, 3), (2, 2), (1, 1), (0, 0), False, False, (16, 16, 16)], + ["nn.l2_pool2d", fp32_dtype, (2, 2), (2, 2), (1, 1), (1, 1), False, True, (15, 15, 16)], ] for ( @@ -302,6 +307,7 @@ def test_codegen_pooling(): (dtype, low, high), size, stride, + dilation, pad, ceil_mode, count_include_pad, @@ -309,9 +315,10 @@ def test_codegen_pooling(): ) in trials: shape = (1, *input_shape) inputs = {"a"} - args = (shape, dtype, typef, size, stride, pad, False, False) + args = (shape, dtype, typef, size, stride, dilation, pad, False, False) func = _get_pooling_model(*args, iter(inputs)) exp_codegen = _get_expected_pooling_codegen(*args) + verify_codegen(func, exp_codegen, 1) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f81606c6ae50..f878fa939fe2 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -15,17 +15,18 @@ # specific language governing permissions and limitations # under the License. import numpy as np -import onnx -from onnx import helper, TensorProto, mapping, numpy_helper +import pytest +import scipy import torch import torchvision -import pytest -import tvm.topi.testing import tvm +import tvm.testing +import tvm.topi.testing from tvm import relay from tvm.contrib import graph_executor -import scipy -import tvm.testing + +import onnx +from onnx import TensorProto, helper, mapping, numpy_helper def get_input_data_shape_dict(graph_def, input_data): @@ -2696,7 +2697,7 @@ def repeat(N, D): @tvm.testing.uses_gpu def test_unsqueeze_constant(): - from torch.nn import Linear, Sequential, Module + from torch.nn import Linear, Module, Sequential class Flatten(Module): def forward(self, input): @@ -4210,7 +4211,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_eyelike_with_dtype/", "test_eyelike_without_dtype/", "test_matmulinteger/", - "test_maxpool_2d_dilations/", + "test_maxpool_2d_same_lower/", + "test_maxpool_2d_same_upper/", "test_maxpool_with_argmax_2d_precomputed_pads/", "test_maxpool_with_argmax_2d_precomputed_strides/", "test_maxunpool_export_with_output_shape/", diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index bff5bb60e24f..f9f3bba25937 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -16,21 +16,22 @@ # under the License. 
# pylint: disable=import-self, invalid-name, unused-argument """Unit tests for various models and operators""" -from time import time import os import sys -from scipy.stats import t as tdistr +from time import time + import numpy as np import torch import torchvision +import tvm +import tvm.testing +from packaging import version as package_version +from scipy.stats import t as tdistr from torch.nn import Module from torch.nn import functional as F -import tvm from tvm import relay from tvm.contrib import graph_executor from tvm.contrib.nvcc import have_fp16 -import tvm.testing -from packaging import version as package_version sys.setrecursionlimit(10000) @@ -736,6 +737,7 @@ def test_forward_maxpool2d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool2d(kernel_size=[1, 1]).eval(), input_data) + verify_model(torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]).eval(), input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[10, 10]).eval(), input_data) verify_model(torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2).eval(), input_data) @@ -774,6 +776,7 @@ def test_forward_maxpool1d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool1d(kernel_size=1).eval(), input_data) + verify_model(torch.nn.MaxPool1d(kernel_size=2, dilation=[1]).eval(), input_data) verify_model(torch.nn.MaxPool1d(kernel_size=10).eval(), input_data) verify_model(torch.nn.MaxPool1d(kernel_size=4, padding=2, stride=2).eval(), input_data) @@ -792,6 +795,7 @@ def test_forward_maxpool3d(): input_data = torch.rand(input_shape).float() verify_model(torch.nn.MaxPool3d(kernel_size=[1, 1, 1]).eval(), input_data) + verify_model(torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 3]).eval(), input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[10, 10, 10]).eval(), input_data) verify_model(torch.nn.MaxPool3d(kernel_size=[4, 4, 4], padding=2, stride=2).eval(), input_data) @@ -3523,7 +3527,7 @@ def test_forward_pretrained_bert_base_uncased(): """ try: - from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM + from pytorch_pretrained_bert import BertForMaskedLM, BertTokenizer except: print("Torch pretrained bert package must be installed to run this script.") return diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 95f54d0bd2aa..fe5e04844bb3 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -14,18 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import os + import numpy as np import pytest - import tvm -from tvm import te -from tvm import relay +import tvm.topi.testing +from tvm import relay, te from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type -from utils.assert_diagnostic import DiagnosticTesting -import tvm.topi.testing -import os +from utils.assert_diagnostic import DiagnosticTesting def int32(val): @@ -669,13 +668,21 @@ def test_any_conv2d_transpose_nchw(): def verify_any_pool2d( - pool_type, data_shape, pool_size, strides, padding, layout, static_data_shape, ref_out_shape + pool_type, + data_shape, + pool_size, + strides, + dilation, + padding, + layout, + static_data_shape, + ref_out_shape, ): mod = tvm.IRModule() dtype = "float32" pool_func = relay.nn.max_pool2d if pool_type == "max" else relay.nn.avg_pool2d data = relay.var("data", shape=data_shape, dtype=dtype) - y = pool_func(data, pool_size, strides, padding, layout) + y = pool_func(data, pool_size, strides, dilation, padding, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) check_result([data_np], mod, ref_out_shape, assert_shape=True) @@ -689,6 +696,7 @@ def test_any_pool2d(): (3, 3), (1, 1), (1, 1), + (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220), @@ -698,6 +706,7 @@ def test_any_pool2d(): (relay.Any(), relay.Any(), relay.Any(), 4), (1, 1), (2, 2), + (1, 1), (0, 0), "NHWC", (3, 220, 220, 4), @@ -709,6 +718,7 @@ def test_any_pool2d(): (3, 3), (2, 2), (1, 1), + (1, 1), "NCHW4c", (2, 3, 220, 220, 4), (2, 3, 110, 110, 4), diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 755618cccf06..d0ff86bffcde 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -926,49 +926,6 @@ def test_upsampling3d_infer_type(): assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32") -def _test_pool2d(opfunc, reffunc, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)): - n, c, h, w = te.size_var("n"), 10, 224, 224 - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - y = opfunc(x, pool_size=(1, 1)) - assert "pool_size=" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32") - # test execution - dtype = "float32" - dshape = (1, 3, 28, 28) - x = relay.var("x", shape=dshape) - y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) - func = relay.Function([x], y) - data = np.random.uniform(size=dshape).astype(dtype) - ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)) - for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) - - -def _test_pool2d_int(opfunc, reffunc, dtype): - n, c, h, w = te.size_var("n"), 10, 224, 224 - x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) - y = opfunc(x, pool_size=(1, 1)) - assert "pool_size=" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, 10, 224, 224), dtype) - # test execution - dtype = "int32" - dshape = (1, 3, 28, 28) - for shape_dtype in ["int32", "int64"]: - x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) - y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) - func = relay.Function([x], y) - data = np.random.randint(low=-128, high=128, size=dshape) - ref_res = 
reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) - for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) - - def _test_global_pool2d(opfunc, reffunc): n, c, h, w = te.size_var("n"), te.size_var("c"), 224, 224 x = relay.var("x", relay.TensorType((n, h, w, c), "float32")) @@ -997,39 +954,69 @@ def _test_global_pool2d(opfunc, reffunc): @tvm.testing.uses_gpu def test_pool2d(): - _test_pool2d(relay.nn.max_pool2d, np.max) - _test_pool2d(relay.nn.max_pool2d, np.max, pool_size=2, strides=2, padding=0) - _test_pool2d(relay.nn.avg_pool2d, np.mean) - _test_pool2d(relay.nn.avg_pool2d, np.mean, pool_size=2, strides=2, padding=0) - _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int32") - _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "uint16") - _test_global_pool2d(relay.nn.global_max_pool2d, np.max) - _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) - - -def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"): - n, c, w = te.var("n"), 10, 224 - x = relay.var("x", relay.TensorType((n, c, w), "float32")) - y = opfunc(x, pool_size=(1,)) - assert "pool_size=" in y.astext() - yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, 10, 224), "float32") - # test execution - dshape = (1, 3, 32) - for shape_dtype in ["int32", "int64"]: - x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) - pool_type = "max" if "max" in str(opfunc) else "avg" - y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) + def _test_pool2d(opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding=0): + n, c, h, w = te.size_var("n"), 10, 224, 224 + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = opfunc(x, pool_size=(1, 1)) + assert "pool_size=" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32") + # test execution + dtype = "float32" + dshape = (1, 3, 28, 28) + x = relay.var("x", shape=dshape) + y = opfunc(x, pool_size=pool_size, strides=strides, dilation=dilation, padding=padding) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) - ref_res = tvm.topi.testing.pool1d_ncw_python( - data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False + ref_res = tvm.topi.testing.poolnd_python( + data, + [pool_size, pool_size], + [strides, strides], + [dilation, dilation], + [padding, padding], + [padding, padding], + pool_type, + count_include_pad=False, + ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", device=dev, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + def _test_pool2d_int(opfunc, reffunc, dtype): + n, c, h, w = te.size_var("n"), 10, 224, 224 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + y = opfunc(x, pool_size=(1, 1)) + assert "pool_size=" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, 10, 224, 224), dtype) + # test execution + dtype = "int32" + dshape = (1, 3, 28, 28) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + func = relay.Function([x], y) + data = 
np.random.randint(low=-128, high=128, size=dshape) + ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) + for target, dev in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", device=dev, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + _test_pool2d(relay.nn.max_pool2d, "max") + _test_pool2d(relay.nn.max_pool2d, "max", pool_size=2, strides=2, padding=0) + _test_pool2d(relay.nn.max_pool2d, "max", pool_size=2, strides=2, padding=0, dilation=2) + _test_pool2d(relay.nn.avg_pool2d, "avg") + _test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, padding=0) + _test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, padding=0, dilation=2) + + _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int32") + _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "uint16") + _test_global_pool2d(relay.nn.global_max_pool2d, np.max) + _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) + def _test_global_pool1d(opfunc, reffunc): n, c, w = te.size_var("n"), te.size_var("c"), 224 @@ -1059,12 +1046,47 @@ def _test_global_pool1d(opfunc, reffunc): @tvm.testing.uses_gpu def test_pool1d(): - _test_pool1d(relay.nn.max_pool1d) - _test_pool1d(relay.nn.max_pool1d, dtype="int32") - _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0) - _test_pool1d(relay.nn.avg_pool1d) - _test_pool1d(relay.nn.avg_pool1d, dtype="int32") - _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0) + def _test_pool1d( + opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding=0, dtype="float32" + ): + n, c, w = te.var("n"), 10, 224 + x = relay.var("x", relay.TensorType((n, c, w), "float32")) + y = opfunc(x, pool_size=(1,)) + assert "pool_size=" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, 10, 224), "float32") + # test execution + dshape = (1, 3, 32) + for shape_dtype in ["int32", "int64"]: + x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) + pool_type = "max" if "max" in str(opfunc) else "avg" + y = opfunc(x, pool_size=pool_size, strides=strides, dilation=dilation, padding=padding) + func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = tvm.topi.testing.poolnd_python( + data, + [pool_size], + [strides], + [dilation], + [padding], + [padding], + pool_type, + count_include_pad=False, + ceil_mode=False, + ) + for target, dev in tvm.testing.enabled_targets(): + intrp1 = relay.create_executor("graph", device=dev, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + _test_pool1d(relay.nn.max_pool1d, "max") + _test_pool1d(relay.nn.max_pool1d, "max", dtype="int32") + _test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, padding=0) + _test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, padding=0, dilation=2) + _test_pool1d(relay.nn.avg_pool1d, "avg") + _test_pool1d(relay.nn.avg_pool1d, "avg", dtype="int32") + _test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, padding=0) + _test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, padding=0, dilation=2) _test_global_pool1d(relay.nn.global_max_pool1d, np.max) _test_global_pool1d(relay.nn.global_avg_pool1d, np.mean) @@ -1073,10 +1095,11 @@ def test_pool1d(): def test_pool3d(): def _test_pool3d( opfunc, - pool_size=(2, 2, 2), - strides=(2, 2, 2), - padding=(0, 0, 0, 0, 0, 
0), - out_shape=(1, 3, 16, 16, 16), + pool_type, + pool_size=2, + strides=2, + dilation=1, + padding=[0, 0, 0, 0, 0, 0], dtype="float32", ): n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224 @@ -1091,34 +1114,45 @@ def _test_pool3d( for shape_dtype in ["int32", "int64"]: x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype) pool_type = "max" if "max" in str(opfunc) else "avg" - y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) - func = relay.Function([x], y) - # check output shape - f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape)) - assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format( - out_shape, f_out_shape + y = opfunc( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + dilation=dilation, ) + func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) - ref_res = tvm.topi.testing.pool3d_ncdhw_python( - data, pool_size, strides, padding, out_shape, pool_type, False + ref_res = tvm.topi.testing.poolnd_python( + data, + [pool_size, pool_size, pool_size], + [strides, strides, strides], + [dilation, dilation, dilation], + padding[:3], + padding[3:], + pool_type, + count_include_pad=False, + ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): intrp1 = relay.create_executor("graph", device=dev, target=target) op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) - _test_pool3d(relay.nn.max_pool3d) - _test_pool3d(relay.nn.max_pool3d, dtype="int32") - _test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) - _test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) - _test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) - _test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2) - _test_pool3d(relay.nn.avg_pool3d) - _test_pool3d(relay.nn.avg_pool3d, dtype="int32") - _test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) - _test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) - _test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) - _test_pool3d(relay.nn.avg_pool3d, pool_size=2, padding=0, strides=2) + _test_pool3d(relay.nn.max_pool3d, "max") + _test_pool3d(relay.nn.max_pool3d, "max", dtype="int32") + _test_pool3d(relay.nn.max_pool3d, "max", padding=(2, 0, 0, 2, 0, 0)) + _test_pool3d(relay.nn.max_pool3d, "max", padding=(0, 3, 0, 0, 3, 0)) + _test_pool3d(relay.nn.max_pool3d, "max", padding=(0, 0, 4, 0, 0, 4)) + _test_pool3d(relay.nn.max_pool3d, "max", pool_size=2, strides=2) + _test_pool3d(relay.nn.max_pool3d, "max", pool_size=2, strides=2, dilation=2) + _test_pool3d(relay.nn.avg_pool3d, "avg") + _test_pool3d(relay.nn.avg_pool3d, "avg", dtype="int32") + _test_pool3d(relay.nn.avg_pool3d, "avg", padding=(2, 0, 0, 2, 0, 0)) + _test_pool3d(relay.nn.avg_pool3d, "avg", padding=(0, 3, 0, 0, 3, 0)) + _test_pool3d(relay.nn.avg_pool3d, "avg", padding=(0, 0, 4, 0, 0, 4)) + _test_pool3d(relay.nn.avg_pool3d, "avg", pool_size=2, strides=2) + _test_pool3d(relay.nn.avg_pool3d, "avg", pool_size=2, strides=2, dilation=2) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index aa92e10c4d06..64ec7e3345a1 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py 
@@ -17,13 +17,13 @@ """ Support level5 operator test cases. """ import math + import numpy as np import tvm -from tvm import te -from tvm import relay -from tvm.relay.testing import run_infer_type -import tvm.topi.testing import tvm.testing +import tvm.topi.testing +from tvm import relay, te +from tvm.relay.testing import run_infer_type def test_resize_infer_type(): diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py index 1451d18e42dd..6d4bd71642b6 100644 --- a/tests/python/topi/python/test_topi_pooling.py +++ b/tests/python/topi/python/test_topi_pooling.py @@ -17,14 +17,13 @@ # pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument """Test code for pooling""" import math + import numpy as np import tvm -from tvm import te -from tvm import topi import tvm.testing import tvm.topi.testing +from tvm import te, topi from tvm.topi.utils import get_const_tuple -import tvm.testing _pool_schedule = { "generic": topi.generic.schedule_pool, @@ -46,82 +45,6 @@ } -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): - """verify function of pool""" - iw = ih - kw = kh - sw = sh - pt, pl, pb, pr = padding - layout = "NCHW" - A = te.placeholder((n, ic, ih, iw), name="A") - B = topi.nn.pool( - A, - kernel=[kh, kw], - stride=[sh, sw], - padding=padding, - pool_type=pool_type, - ceil_mode=ceil_mode, - layout="NCHW", - count_include_pad=count_include_pad, - ) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + pt + pb) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pl + pr) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + pt + pb) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pl + pr) / sw) + 1) - - a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih + pt + pb, iw + pl + pr)).astype(dtype) - no_zero = (range(n), range(ic), (range(pt, ih + pt)), (range(pl, iw + pl))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == "avg": - for i in range(oh): - for j in range(ow): - if count_include_pad: - b_np[:, :, i, j] = np.mean( - pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw], axis=(2, 3) - ) - else: - pad_count = np.sum( - pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw] > 0, axis=(2, 3) - ) - b_np[:, :, i, j] = np.sum( - pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw], axis=(2, 3) - ) / np.maximum(pad_count, 1) - - elif pool_type == "max": - for i in range(oh): - for j in range(ow): - b_np[:, :, i, j] = np.max( - pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw], axis=(2, 3) - ) - b_np = np.maximum(b_np, 0.0) - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s_func = tvm.topi.testing.dispatch(target, _pool_schedule) - s = s_func(B, layout) - - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) - f = tvm.build(s, [A, B], target) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=2e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - def verify_pool_grad( n, ic, ih, kh, sh, padding, pool_type, ceil_mode, 
count_include_pad=True, add_relu=False ): @@ -131,10 +54,11 @@ def verify_pool_grad( sw = sh pt, pl, pb, pr = padding A = te.placeholder((n, ic, ih, iw), name="A") - B = topi.nn.pool( + B = topi.nn.pool2d( A, kernel=[kh, kw], stride=[sh, sw], + dilation=[1, 1], padding=padding, pool_type=pool_type, ceil_mode=ceil_mode, @@ -198,24 +122,6 @@ def check_target(target, dev): check_target(target, dev) -@tvm.testing.uses_gpu -def test_pool(): - """test cases of pool""" - verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], "avg", False, True) - verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], "avg", False, True) - verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], "avg", False, False) - verify_pool(1, 256, 31, 4, 4, [3, 3, 3, 3], "avg", False, False) - verify_pool(1, 256, 31, 4, 4, [0, 0, 0, 0], "avg", False, False) - verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], "max", False) - verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], "max", False) - verify_pool(1, 256, 31, 3, 3, [2, 1, 2, 1], "max", True) - - verify_pool(1, 256, 31, 3, 3, [2, 1, 0, 3], "avg", False, True) - verify_pool(1, 256, 32, 2, 2, [0, 3, 2, 1], "avg", False, False) - verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], "max", False) - verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], "max", True) - - @tvm.testing.uses_gpu def test_pool_grad(): """test cases of pool_grad""" @@ -337,36 +243,82 @@ def test_adaptive_pool(): verify_adaptive_pool((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC") -def verify_pool3d( - n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, layout="NCDHW" +def verify_poolnd( + n, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + layout, + count_include_pad=True, ): - """verify function of pool3d""" - id = iw = ih - kd = kw = kh - sd = sw = sh - input_shape = (n, ic, id, ih, iw) - kernel = [kd, kh, kw] - stride = [sd, sh, sw] + """verify function of pool1d""" A = te.placeholder(input_shape, name="A") - B = topi.nn.pool3d( - A, - kernel=kernel, - stride=stride, - padding=padding, - pool_type=pool_type, - ceil_mode=ceil_mode, - layout=layout, - count_include_pad=count_include_pad, - ) + + if n == 1: + B = topi.nn.pool1d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + elif n == 2: + B = topi.nn.pool2d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + elif n == 3: + B = topi.nn.pool3d( + A, + kernel=kernel, + stride=stride, + dilation=dilation, + padding=padding, + pool_type=pool_type, + ceil_mode=ceil_mode, + layout=layout, + count_include_pad=count_include_pad, + ) + else: + raise ValueError(f"PoolND only supports n=1, 2, 3 got n={n}") + B = topi.nn.relu(B) dtype = A.dtype output_shape = [int(i) for i in B.shape] input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype) - ref_np = tvm.topi.testing.pool3d_ncdhw_python( - input_np, kernel, stride, padding, output_shape, pool_type, count_include_pad, ceil_mode + + padding_before = padding[:n] + padding_after = padding[n:] + ref_np = tvm.topi.testing.poolnd_python( + input_np, + kernel, + stride, + dilation, + padding_before, + padding_after, + pool_type, + count_include_pad, + ceil_mode, ) + np.testing.assert_equal(tuple(output_shape), tuple(ref_np.shape)) + def check_target(target, dev): print("Running on target: %s" % target) with tvm.target.Target(target): 
@@ -383,88 +335,559 @@ def check_target(target, dev): check_target(target, dev) +def verify_pool3d( + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad=True, + layout="NCDHW", +): + verify_poolnd( + 3, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + layout="NCDHW", + count_include_pad=count_include_pad, + ) + + @tvm.testing.uses_gpu def test_pool3d(): """test cases of pool3d""" - verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], "avg", False, True) - verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], "avg", False, True) - verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], "avg", False, False) - verify_pool3d(1, 256, 31, 4, 4, [3, 3, 3, 3, 3, 3], "avg", False, False) - verify_pool3d(1, 256, 31, 4, 4, [0, 0, 0, 0, 0, 0], "avg", False, False) - verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], "max", False) - verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], "max", False) - verify_pool3d(1, 256, 31, 3, 3, [2, 2, 1, 1, 1, 2], "max", True) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + True, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + True, + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [1, 1, 2, 2, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [3, 3, 3, 3, 3, 3], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [4, 4, 4], + [4, 4, 4], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 0, 0, 0, 0, 0], + "max", + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 2, 1, 1, 1, 2], + "max", + True, + ) + + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [1, 1, 1], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [1, 0, 5, 4, 3, 2], + "max", + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [1, 1, 1], + [3, 2, 1, 0, 5, 4], + "max", + True, + ) - verify_pool3d(1, 256, 31, 3, 3, [2, 1, 0, 5, 4, 3], "avg", False, True) - verify_pool3d(1, 256, 32, 2, 2, [0, 5, 4, 3, 2, 1], "avg", False, False) - verify_pool3d(1, 256, 31, 3, 3, [1, 0, 5, 4, 3, 2], "max", False) - verify_pool3d(1, 256, 31, 3, 3, [3, 2, 1, 0, 5, 4], "max", True) + # Test non-1 dilation + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [2, 1, 0, 5, 4, 3], + "avg", + False, + True, + ) + verify_pool3d( + [1, 16, 32, 32, 32], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [0, 5, 4, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 1, 3], + [1, 0, 5, 4, 3, 2], + "max", + False, + ) + verify_pool3d( + [1, 16, 31, 31, 31], + [3, 3, 3], + [3, 3, 3], + [2, 2, 3], + [3, 2, 1, 0, 5, 4], + "max", + True, + ) -def verify_pool1d( - n, ic, iw, kw, sw, padding, pool_type, ceil_mode, count_include_pad=True, layout="NCW" +def verify_pool2d( + input_shape, kernel, 
stride, dilation, padding, pool_type, ceil_mode, count_include_pad=True ): - """verify function of pool1d""" - input_shape = (n, ic, iw) - kernel = [kw] - stride = [sw] - A = te.placeholder(input_shape, name="A") - B = topi.nn.pool1d( - A, - kernel=kernel, - stride=stride, - padding=padding, - pool_type=pool_type, - ceil_mode=ceil_mode, - layout=layout, + verify_poolnd( + 2, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + layout="NCHW", count_include_pad=count_include_pad, ) - B = topi.nn.relu(B) - dtype = A.dtype - output_shape = [int(i) for i in B.shape] - input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype) - ref_np = tvm.topi.testing.pool1d_ncw_python( - input_np, kernel, stride, padding, output_shape, pool_type, count_include_pad, ceil_mode + +@tvm.testing.uses_gpu +def test_pool2d(): + """test cases of pool""" + verify_pool2d( + [1, 16, 32, 32], + [2, 2], + [2, 2], + [1, 1], + [0, 0, 0, 0], + "avg", + False, + True, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [1, 2, 1, 2], + "avg", + False, + True, + ) + verify_pool2d( + [1, 16, 32, 32], + [2, 2], + [2, 2], + [1, 1], + [1, 2, 1, 2], + "avg", + False, + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [4, 4], + [4, 4], + [1, 1], + [3, 3, 3, 3], + "avg", + False, + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [4, 4], + [4, 4], + [1, 1], + [0, 0, 0, 0], + "avg", + False, + False, + ) + verify_pool2d( + [1, 16, 32, 32], + [2, 3], + [2, 2], + [1, 1], + [0, 0, 0, 0], + "max", + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [2, 1, 2, 1], + "max", + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [2, 1, 2, 1], + "max", + True, ) - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s_func = tvm.topi.testing.dispatch(target, _pool_schedule) - s = s_func(B, layout) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [2, 1, 0, 3], + "avg", + False, + True, + ) + verify_pool2d( + [1, 16, 32, 32], + [2, 3], + [2, 2], + [1, 1], + [0, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [1, 0, 3, 2], + "max", + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [1, 1], + [3, 2, 1, 0], + "max", + True, + ) - a = tvm.nd.array(input_np, dev) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) - f = tvm.build(s, [A, B], target) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), ref_np, rtol=1e-5) + # Test non-1 dilations + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [2, 1], + [2, 1, 0, 3], + "avg", + False, + True, + ) + verify_pool2d( + [1, 16, 32, 32], + [2, 3], + [2, 2], + [2, 3], + [0, 3, 2, 1], + "avg", + False, + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [3, 3], + [1, 0, 3, 2], + "max", + False, + ) + verify_pool2d( + [1, 16, 31, 31], + [3, 3], + [3, 3], + [2, 2], + [3, 2, 1, 0], + "max", + True, + ) - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) + +def verify_pool1d( + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + count_include_pad=True, + layout="NCW", +): + verify_poolnd( + 1, + input_shape, + kernel, + stride, + dilation, + padding, + pool_type, + ceil_mode, + layout="NCW", + count_include_pad=count_include_pad, + ) @tvm.testing.uses_gpu def test_pool1d(): """test cases of pool1d""" - verify_pool1d(1, 256, 32, 2, 2, 
[0, 0], "avg", False, True) - verify_pool1d(1, 256, 31, 3, 3, [1, 2], "avg", False, True) - verify_pool1d(1, 256, 32, 2, 2, [1, 2], "avg", False, False) - verify_pool1d(1, 256, 31, 4, 4, [3, 3], "avg", False, False) - verify_pool1d(1, 256, 31, 4, 4, [0, 0], "avg", False, False) - verify_pool1d(1, 256, 32, 2, 2, [0, 0], "max", False) - verify_pool1d(1, 256, 31, 3, 3, [2, 1], "max", False) - verify_pool1d(1, 256, 31, 3, 3, [2, 1], "max", True) + verify_pool1d( + [1, 16, 32], + [2], + [2], + [1], + [0, 0], + "avg", + False, + True, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [1, 2], + "avg", + False, + True, + ) + verify_pool1d( + [1, 16, 32], + [2], + [2], + [1], + [1, 2], + "avg", + False, + False, + ) + verify_pool1d( + [1, 16, 31], + [4], + [4], + [1], + [3, 3], + "avg", + False, + False, + ) + verify_pool1d( + [1, 16, 31], + [4], + [4], + [1], + [0, 0], + "avg", + False, + False, + ) + verify_pool1d( + [1, 16, 32], + [2], + [2], + [1], + [0, 0], + "max", + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [2, 1], + "max", + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [2, 1], + "max", + True, + ) - verify_pool1d(1, 256, 31, 3, 3, [2, 5], "avg", False, True) - verify_pool1d(1, 256, 32, 2, 2, [0, 3], "avg", False, False) - verify_pool1d(1, 256, 31, 3, 3, [1, 4], "max", False) - verify_pool1d(1, 256, 31, 3, 3, [3, 0], "max", True) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [2, 5], + "avg", + False, + True, + ) + verify_pool1d( + [1, 16, 32], + [2], + [2], + [1], + [0, 3], + "avg", + False, + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [1, 4], + "max", + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [1], + [3, 0], + "max", + True, + ) + + # Test non-1 dilations + verify_pool1d( + [1, 16, 31], + [3], + [3], + [2], + [2, 5], + "avg", + False, + True, + ) + verify_pool1d( + [1, 16, 32], + [2], + [2], + [3], + [0, 3], + "avg", + False, + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [2], + [1, 4], + "max", + False, + ) + verify_pool1d( + [1, 16, 31], + [3], + [3], + [3], + [3, 0], + "max", + True, + ) if __name__ == "__main__": - test_pool() test_pool1d() + test_pool2d() test_pool3d() test_pool_grad() test_global_pool() diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index 2f9423104a68..4890268c907b 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -17,8 +17,7 @@ """Common functions for auto_scheduler test cases""" import tvm -from tvm import te, auto_scheduler -from tvm import topi +from tvm import auto_scheduler, te, topi from tvm.topi.nn.winograd_util import winograd_transform_matrices from tvm.topi.utils import get_const_tuple @@ -105,7 +104,7 @@ def conv2d_nchw_bn_relu_auto_scheduler_test( @auto_scheduler.register_workload def max_pool2d_auto_scheduler_test(N, H, W, CI, padding): data = te.placeholder((N, CI, H, W), name="Data") - out = topi.nn.pool(data, [2, 2], [1, 1], [padding, padding, padding, padding], "max") + out = topi.nn.pool2d(data, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") return [data, out] diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 75e807456981..494973e8b573 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -17,23 +17,24 @@ """Unit tests for the Bring Your Own 
Datatype framework. TODO(@gussmith23 @hypercubestart) link to documentation""" -import tvm -import tvm.topi.testing import numpy as np import pytest +import tvm +import tvm.topi.testing from tvm import relay from tvm.relay.testing.layers import batch_norm_infer from tvm.target.datatype import ( + create_lower_func, + create_min_lower_func, + lower_call_pure_extern, + lower_ite, register, register_min_func, register_op, - create_lower_func, - lower_ite, - lower_call_pure_extern, - create_min_lower_func, ) from tvm.tir.op import call_pure_extern + # note: we can't use relay.testing models because params are randomly initialized, # which lead the output to have the same values # get mobilenet model from Gluon CV @@ -50,8 +51,8 @@ def get_mobilenet(): # use real image instead of random data for end-to-end model training # or else output would all be around the same value def get_cat_image(dimensions): - from tvm.contrib.download import download_testdata from PIL import Image + from tvm.contrib.download import download_testdata url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png" dst = "cat.png" diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 59b20bd11e75..7471e8a1eee4 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. +import numpy as np +import pytest import tvm -from tvm import te +from tvm import te, topi from tvm.testing import assert_allclose -from tvm import topi from tvm.topi.utils import get_const_tuple -import pytest - -import numpy as np def check_grad( @@ -193,10 +191,10 @@ def test_topi(): R = topi.nn.conv2d(X, topi.broadcast_to(W2, (5, 2, 3, 3)), 1, 1, 1) check_grad(R, [X, W2]) - R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "avg") + R = topi.nn.pool2d(X, [2, 2], [1, 1], [2, 2], [0, 0, 0, 0], "avg") check_grad(R, X) - R = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "max") + R = topi.nn.pool2d(X, [2, 2], [1, 1], [2, 2], [0, 0, 0, 0], "max") check_grad(R, X) X = te.placeholder((1, 2, 5, 5), name="X") @@ -316,23 +314,23 @@ def test_stride_dilation(): Y = topi.nn.conv2d(X, W, 3, 0, 3) check_grad(Y, [X, W]) - Y = topi.nn.pool(X, [1, 1], [1, 1], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [1, 1], [1, 1], [1, 1], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [1, 1], [2, 2], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [1, 1], [1, 1], [2, 2], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [1, 1], [3, 3], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [1, 1], [1, 1], [3, 3], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [2, 2], [1, 1], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [2, 2], [1, 1], [1, 1], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [2, 2], [2, 2], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [2, 2], [1, 1], [2, 2], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [2, 2], [3, 3], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [2, 2], [1, 1], [3, 3], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [3, 3], [1, 1], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [3, 3], [1, 1], [1, 1], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [3, 3], [2, 2], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [3, 3], [1, 1], [2, 2], [0, 0, 0, 0], "max") check_grad(Y, [X]) - Y = topi.nn.pool(X, [3, 3], 
[3, 3], [0, 0, 0, 0], "max") + Y = topi.nn.pool2d(X, [3, 3], [1, 1], [3, 3], [0, 0, 0, 0], "max") check_grad(Y, [X]) diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 81fa46e0e9cc..9f1400c41a15 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import tvm import tvm.testing -import numpy as np from tvm import te -from tvm.topi.nn.pooling import pool +from tvm.topi.nn.pooling import pool2d def test_tensor(): @@ -337,7 +337,9 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin(P.op, intrin_func, default_buffer_params={"offset_factor": 1}) A = te.placeholder((1, 64, 16, 16), name="A") - P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0), pool_type="max") + P = pool2d( + data=A, kernel=(3, 3), stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), pool_type="max" + ) s = te.create_schedule(P.op) _, oh, _, _ = P.op.axis intrin = intrin_pool()