Skip to content

Commit

Permalink
Fix bugs with C++ TOPI flatten and relu (#869)
Browse files Browse the repository at this point in the history
* Fix bugs with C++ TOPI flatten and relu

* Added regression tests. Fixed typo in CMakeLists.txt. Fixed topi cpp import removed.
  • Loading branch information
alex-weaver authored and tqchen committed Feb 4, 2018
1 parent 01c088c commit 633f7f2
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME)

if(USE_LLVM)
find_spackage(LLVM CONFIG REQUIRED)
find_package(LLVM CONFIG REQUIRED)
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
Expand Down
5 changes: 4 additions & 1 deletion topi/include/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
std::string tag = kElementWise) {
return tvm::compute(
t->shape,
[&](const tvm::Array<tvm::Var>& i) { return tvm::max(t(i), threshold); },
[&](const tvm::Array<tvm::Var>& i) {
auto threshold_const = tvm::make_const(t->dtype, threshold);
return tvm::max(t(i), threshold_const);
},
name,
tag);
}
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/nn/flatten.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ inline Tensor flatten(const Tensor& x,
index.push_back(i);
std::reverse(index.begin(), index.end());
return x(index);
});
}, name, tag);
}

} // namespace nn
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from . import testing
from . import util
from . import rocm
from . import cpp
5 changes: 5 additions & 0 deletions topi/tests/python_cpp/test_topi_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def test_apply(func, name):
test_apply(topi.cpp.log, "log")
test_apply(topi.cpp.sqrt, "sqrt")

def test_flatten_tag():
A = tvm.placeholder((3, 4), name='A')
B = topi.cpp.nn.flatten(A)
assert B.op.tag == topi.tag.INJECTIVE

if __name__ == "__main__":
test_util()
test_ewise()
test_flatten_tag()
8 changes: 5 additions & 3 deletions topi/tests/python_cpp/test_topi_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import topi
from topi.util import get_const_tuple

def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
def verify_relu(m, n, dtype):
A = tvm.placeholder((m, n), name='A', dtype=dtype)
B = topi.cpp.nn.relu(A)
assert B.dtype == dtype

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0)
Expand Down Expand Up @@ -51,7 +52,8 @@ def verify_leaky_relu(m, alpha):


def test_relu():
verify_relu(10, 128)
for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
verify_relu(10, 128, dtype)

def test_leaky_relu():
verify_leaky_relu(100, 0.1)
Expand Down

0 comments on commit 633f7f2

Please sign in to comment.