From 3aa2eaed5ca93f0e06079b75e34fcee03c62417a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 24 Jul 2019 18:06:36 -0700 Subject: [PATCH] [TOPI] Average Pool2D Bug. (#3607) * [TOPI] Average Pool2D Bug. Issue - https://github.com/dmlc/tvm/issues/3581 * Add uint16 test. --- tests/python/relay/test_op_level2.py | 21 +++++++++++++++++++++ topi/include/topi/nn/pooling.h | 19 +++++++++++++------ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index d2aca1890f85..4e8fe2cbedab 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -264,6 +264,25 @@ def _test_pool2d(opfunc, reffunc): op_res1 = intrp1.evaluate(func)(data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) +def _test_pool2d_int(opfunc, reffunc, dtype): + n, c, h, w = tvm.var("n"), 10, 224, 224 + x = relay.var("x", relay.TensorType((n, c, h, w), dtype)) + y = opfunc(x, pool_size=(1, 1)) + assert "pool_size=" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, 10, 224, 224), dtype) + # test execution + dtype = "int32" + dshape = (1, 3, 28, 28) + x = relay.var("x", shape=dshape, dtype=dtype) + y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + func = relay.Function([x], y) + data = np.random.random_integers(low=-128, high=128, size=dshape) + ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)).astype(dtype) + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) def _test_global_pool2d(opfunc, reffunc): n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 @@ -294,6 +313,8 @@ def _test_global_pool2d(opfunc, reffunc): def test_pool2d(): _test_pool2d(relay.nn.max_pool2d, np.max) _test_pool2d(relay.nn.avg_pool2d, np.mean) + _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'int32') + _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'uint16') _test_global_pool2d(relay.nn.global_max_pool2d, np.max) _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index f7aa1ddfd773..2eff244179ba 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -130,19 +130,26 @@ inline Tensor pool_impl(const Tensor& x, return tvm::max(temp(indices), { dheight, dwidth }); }, "tensor", "pool_max"); } else if (pool_type == kAvgPool) { + // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; - auto tavg = [&](const Array& output, Expr divide_factor) { + + // TVM compute for summing the pooling window. + auto pool_sum = tvm::compute(out_shape, + [&](const Array& output) { Array indices; for (const Var& var : output) indices.push_back(var); indices.Set(height_axis, output[height_axis] * stride_height + dheight); indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::sum(temp(indices) / divide_factor, { dheight, dwidth }); - }; + return tvm::sum(temp(indices), { dheight, dwidth }); + }, "tensor", "pool_sum"); + // TVM compute for dividing the reduced window sum by kernel size. return tvm::compute(out_shape, [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); if (count_include_pad) { - return tavg(output, kernel_height * kernel_width); + return pool_sum(indices) / (kernel_height * kernel_width); } else { Expr h_start = output[height_axis] * stride_height - pad_top; Expr w_start = output[width_axis] * stride_width - pad_left; @@ -152,9 +159,9 @@ inline Tensor pool_impl(const Tensor& x, w_start = ir::Max::make(w_start, make_const(Int(32), 0)); Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start), make_const(Int(32), 1)); - return tavg(output, divide_factor); + return pool_sum(indices) / divide_factor; } - }, "tensor", "pool_avg"); + }, "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x;