Commit df2d286

Prelu bug fix (apache#1358)
ANSHUMAN87 authored and tqchen committed Jun 30, 2018
1 parent b6ead6d commit df2d286
Showing 5 changed files with 12 additions and 11 deletions.
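
What the fix does: the packed-function registration in topi/src/topi.cc called prelu<float>(args[0], args[1]), hard-coding the element type and silently dropping the axis argument, so the slope was always broadcast along axis 1. The commit de-templates topi::prelu and forwards the axis through. A rough numpy sketch of the intended semantics (prelu_reference is an illustrative name, not part of the commit):

    import numpy as np

    def prelu_reference(x, slope, axis=1):
        # Line the 1-D slope up with x's `axis` dimension so it
        # broadcasts across every other dimension.
        shape = [1] * x.ndim
        shape[axis] = slope.shape[0]
        return np.where(x < 0, x * slope.reshape(shape), x)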
2 changes: 1 addition & 1 deletion nnvm/src/top/nn/nn.cc
@@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
-    return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
+    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)};
   })
 .set_support_level(4);

1 change: 0 additions & 1 deletion topi/include/topi/nn.h
@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
  *
  * \return A Tensor whose op member is the relu operation
  */
-template <typename T>
 inline tvm::Tensor prelu(const tvm::Tensor &x,
                          const tvm::Tensor &slope,
                          const int axis = 1,
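
Dropping template <typename T> from the header means the operator is no longer pinned to one element type; the computation is expressed in terms of the input tensors, so the output dtype should simply follow them. A sketch of what that permits (assuming the float16 path behaves like float32 here):

    import tvm
    import topi

    X = tvm.placeholder((1, 3, 2, 2), dtype='float16', name='X')
    W = tvm.placeholder((3,), dtype='float16', name='W')
    B = topi.cpp.nn.prelu(X, W, 1)  # dtype tracks the inputs, no forced float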
2 changes: 1 addition & 1 deletion topi/src/topi.cc
@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")

 TVM_REGISTER_GLOBAL("topi.nn.prelu")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = prelu<float>(args[0], args[1]);
+  *rv = prelu(args[0], args[1], args[2]);
   });

 TVM_REGISTER_GLOBAL("topi.nn.pad")
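
With args[2] forwarded, callers can finally reach a non-default axis from Python. A sketch of the call the fix enables, using the shapes the updated tests use:

    import tvm
    import topi

    X = tvm.placeholder((1, 3, 2, 2), name='X')
    W = tvm.placeholder((2,), name='W')
    B = topi.cpp.nn.prelu(X, W, 2)  # before this commit the axis was ignored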
9 changes: 5 additions & 4 deletions topi/tests/python/test_topi_relu.py
@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha):
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)


-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)

     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x

-    B = topi.nn.prelu(X, W)
+    B = topi.nn.prelu(X, W, axis)
     s = tvm.create_schedule([B.op])

     ctx = tvm.cpu(0)
@@ -79,7 +79,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.1)

 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))

 if __name__ == "__main__":
     test_schedule_big_array()
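
The second test case is the one the old code could not reach: a (2,)-element slope applied along axis 2 of a (1, 3, 2, 2) input, reshaped to (2, 1) so it broadcasts over the trailing dimension. A standalone numpy check of that broadcast, mirroring _prelu_numpy (values here are illustrative):

    import numpy as np

    x = np.random.uniform(low=-1.0, high=1.0, size=(1, 3, 2, 2))
    w = np.array([0.25, 0.5])

    out = (x < 0) * (x * w.reshape(2, 1)) + (x >= 0) * x
    assert out.shape == (1, 3, 2, 2)  # slope varies along axis 2, broadcasts over axis 3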
9 changes: 5 additions & 4 deletions topi/tests/python_cpp/test_topi_relu.py
@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha):
     foo(a, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x

     out_np = _prelu_numpy(x_np, w_np)
-    B = topi.cpp.nn.prelu(X, W)
+    B = topi.cpp.nn.prelu(X, W, axis)
     device = "llvm"
     target = topi.cpp.TEST_create_target(device)
     s = topi.cpp.generic.schedule_injective(target, [B])
@@ -81,7 +81,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.5)

 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))

 if __name__ == "__main__":
     test_relu()
