leaky_relu bug fix #1218

Merged · 1 commit · Jun 1, 2018
2 changes: 1 addition & 1 deletion nnvm/src/top/nn/nn.cc
@@ -466,7 +466,7 @@ NNVM_REGISTER_OP(leaky_relu)
     const Array<Tensor>& inputs,
     const Array<Tensor>& out_info) {
   const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
-  return Array<Tensor>{ topi::leaky_relu<float>(inputs[0], 0.0, param.alpha) };
+  return Array<Tensor>{ topi::leaky_relu(inputs[0], param.alpha) };
   })
 .set_attr<FGradient>(
   "FGradient", [](const NodePtr& n,
7 changes: 2 additions & 5 deletions topi/include/topi/nn.h
@@ -60,17 +60,14 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
  * \brief Creates an operation that performs a leaky rectified linear unit
  *
  * \param t The input tensor
- * \param threshold The relu threshold (default 0)
- * \param alpha The slope for the small gradient when t < threshold
+ * \param alpha The slope for the small gradient when t < 0
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
  * \return A Tensor whose op member is the relu operation
  */
-template <typename T>
 inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
-                              T threshold = static_cast<T>(0),
-                              T alpha = static_cast<T>(0.1),
+                              double alpha = 0.1,
                               std::string name = "tensor",
                               std::string tag = kElementWise) {
   return tvm::compute(
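
Note: this hunk drops the unused threshold parameter and the <T> template, making the slope a plain double. Since the only in-tree caller passed threshold = 0.0, the change is behavior-preserving. A minimal numpy sketch (hypothetical helper names, assuming the select(x > threshold, x, alpha * x) reading of the old doc comment) checks that the old three-argument form at its default threshold agrees with the new two-argument form:

import numpy as np

def leaky_relu_old(x, threshold=0.0, alpha=0.1):
    # semantics of the removed signature: slope alpha below `threshold`
    return np.where(x > threshold, x, alpha * x)

def leaky_relu_new(x, alpha=0.1):
    # semantics of the new signature: slope alpha below zero
    return np.where(x > 0, x, alpha * x)

x = np.random.uniform(-1.0, 1.0, size=100)
assert np.allclose(leaky_relu_old(x, 0.0, 0.25), leaky_relu_new(x, 0.25))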
2 changes: 1 addition & 1 deletion topi/src/topi.cc
@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.relu")
 
 TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = leaky_relu<float>(args[0]);
+  *rv = leaky_relu(args[0], args[1]);
   });
 
 TVM_REGISTER_GLOBAL("topi.nn.prelu")
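
This registration is where the bug lived: the old body called leaky_relu<float>(args[0]) and never read args[1], so an alpha passed in through the FFI was silently dropped and the default slope of 0.1 was always used. A hypothetical repro, assuming the topi.cpp bindings used by the test file below:

import tvm
import topi

A = tvm.placeholder((100,), name="A")
# Before this fix the second argument never reached the C++ kernel,
# so this computed leaky_relu as if alpha were the default 0.1.
B = topi.cpp.nn.leaky_relu(A, 0.5)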
4 changes: 2 additions & 2 deletions topi/tests/python_cpp/test_topi_relu.py
@@ -41,7 +41,7 @@ def verify_leaky_relu(m, alpha):
         target = topi.cpp.TEST_create_target(device)
         s = topi.cpp.generic.schedule_injective(target, [B])
 
-    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
+    a_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = a_np * (a_np > 0) + a_np * (a_np < 0) * alpha
     ctx = tvm.cpu(0)
     a = tvm.nd.array(a_np, ctx)
@@ -78,7 +78,7 @@ def test_relu():
     verify_relu(10, 128, dtype)
 
 def test_leaky_relu():
-    verify_leaky_relu(100, 0.1)
+    verify_leaky_relu(100, 0.5)
 
 def test_prelu():
     verify_prelu((1, 3, 2, 2), (3,))
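
The test changes close the two gaps that let the bug pass CI: np.random.uniform with no bounds draws from [0, 1), so the old inputs were never negative and the leaky branch was never exercised by the check, and alpha = 0.1 happened to equal the C++ default, so even a dropped argument produced the expected output. A small numpy sketch of both effects:

import numpy as np

a = np.random.uniform(size=1000)     # old test data: all in [0, 1)
assert (a >= 0).all()                # negative branch never exercised
assert (a * (a < 0)).sum() == 0.0    # the alpha term contributes nothing

# With alpha = 0.1 equal to the default slope, a dropped alpha argument is
# indistinguishable from a forwarded one; alpha = 0.5 is not.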