Skip to content

Commit

Permalink
leaky_relu bug fix (#1218)
Browse files Browse the repository at this point in the history
  • Loading branch information
PariksheetPinjari909 authored and tqchen committed Jun 1, 2018
1 parent 0a2280e commit 4e6df19
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion nnvm/src/top/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ NNVM_REGISTER_OP(leaky_relu)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
return Array<Tensor>{ topi::leaky_relu(inputs[0], param.alpha) };
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
Expand Down
7 changes: 2 additions & 5 deletions topi/include/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,14 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
* \brief Creates an operation that performs a leaky rectified linear unit
*
* \param t The input tensor
* \param threshold The relu threshold (default 0)
* \param alpha The slope for the small gradient when t < threshold
* \param alpha The slope for the small gradient when t < 0
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
T threshold = static_cast<T>(0),
T alpha = static_cast<T>(0.1),
double alpha = 0.1,
std::string name = "tensor",
std::string tag = kElementWise) {
return tvm::compute(
Expand Down
2 changes: 1 addition & 1 deletion topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.relu")

TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = leaky_relu<float>(args[0]);
*rv = leaky_relu(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.nn.prelu")
Expand Down
4 changes: 2 additions & 2 deletions topi/tests/python_cpp/test_topi_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def verify_leaky_relu(m, alpha):
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.schedule_injective(target, [B])

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0) + a_np * (a_np < 0) * alpha
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_relu():
verify_relu(10, 128, dtype)

def test_leaky_relu():
verify_leaky_relu(100, 0.1)
verify_leaky_relu(100, 0.5)

def test_prelu():
verify_prelu((1, 3, 2, 2), (3,))
Expand Down

0 comments on commit 4e6df19

Please sign in to comment.