From 2fc766225fcf0a74a0fec52c188246acc4f7ae1a Mon Sep 17 00:00:00 2001 From: XiaolongMeng Date: Mon, 25 Mar 2019 10:27:19 +0800 Subject: [PATCH] fix prelu, now can use on 2d input and add one test (#2875) --- topi/include/topi/nn.h | 1 - topi/python/topi/nn/elemwise.py | 2 +- topi/tests/python/test_topi_relu.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 00c3f999853d..653c0a5f70ce 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -97,7 +97,6 @@ inline tvm::Tensor prelu(const tvm::Tensor &x, const int axis = 1, std::string name = "tensor", std::string tag = kBroadcast) { - CHECK_EQ(4, x->shape.size()); CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; CHECK(topi::detail::GetConstInt(slope->shape[0]) == diff --git a/topi/python/topi/nn/elemwise.py b/topi/python/topi/nn/elemwise.py index 14a747e67610..6a2697795f4d 100644 --- a/topi/python/topi/nn/elemwise.py +++ b/topi/python/topi/nn/elemwise.py @@ -69,7 +69,7 @@ def prelu(x, slope, axis=1): [http://arxiv.org/pdf/1502.01852v1.pdf] """ - assert len(x.shape) == 4 and len(slope.shape) == 1 + assert len(slope.shape) == 1 assert axis < len(x.shape) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py index a7ff64f0f759..5aa9c1ee57a0 100644 --- a/topi/tests/python/test_topi_relu.py +++ b/topi/tests/python/test_topi_relu.py @@ -83,6 +83,7 @@ def test_leaky_relu(): def test_prelu(): verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1)) verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1)) + verify_prelu((1, 3), (3,), 1, (3, )) if __name__ == "__main__": test_schedule_big_array()