diff --git a/topi/python/topi/testing/pool_grad_python.py b/topi/python/topi/testing/pool_grad_python.py index adb2d05e2adf..d916b2edb181 100644 --- a/topi/python/topi/testing/pool_grad_python.py +++ b/topi/python/topi/testing/pool_grad_python.py @@ -14,11 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, unused-argument, unused-variable """Gradient of pooling in python""" import numpy as np -def pool_grad_nchw(a_np, out_grad_np, pool_size, strides, padding, pool_type, ceil_mode, +def pool_grad_nchw(a_np, out_grad_np, + pool_size, + strides, + padding, + pool_type, + ceil_mode, count_include_pad=True): """pool_grad for NCHW layout in python""" dtype = a_np.dtype @@ -47,8 +53,8 @@ def pool_grad_nchw(a_np, out_grad_np, pool_size, strides, padding, pool_type, ce # take the first element, as they are the same across batch and channel pad_count = pad_count.ravel()[0] pad_pool_grad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] += \ - out_grad_np[:, :, i, j].reshape(n,ic,1,1) / np.maximum(pad_count, 1) - elif pool_type =='max': + out_grad_np[:, :, i, j].reshape(n, ic, 1, 1) / np.maximum(pad_count, 1) + elif pool_type == 'max': for i in range(oh): for j in range(ow): a_patch = pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw]