Skip to content

Commit

Permalink
[BugFix] Fix bug in cast to bool (apache#3207)
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon authored and Wei Chen committed Jun 26, 2019
1 parent 2039d2a commit c174537
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_uint() && to.bits() == 1) {
if (from.is_float()) {
llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
return builder_->CreateFCmpONE(value, zero);
} else {
llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
return builder_->CreateICmpNE(value, zero);
}
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
Expand Down
44 changes: 42 additions & 2 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import topi
import topi.testing
from topi import util
from common import get_all_backend


def test_util():
Expand Down Expand Up @@ -59,8 +60,7 @@ def check_device(device):
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel',
'aocl_sw_emu']:
for device in get_all_backend():
check_device(device)


Expand All @@ -77,6 +77,46 @@ def check_device(device):
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)


def test_cast():
def verify(from_dtype, to_dtype, low=-100, high=100):
shape = (5, 4)
A = tvm.placeholder(shape, dtype=from_dtype, name="A")
B = topi.cast(A, to_dtype)

if from_dtype == "bool":
a_np = np.random.choice([True, False], size=shape)
else:
a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
if to_dtype == "bool":
a_np = a_np - a_np[2, 3]
b_np = a_np.astype(to_dtype)

for device in get_all_backend():
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
continue
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np)

verify("int32", "float32")
verify("int32", "float64")
verify("int32", "bool")
verify("float32", "int32")
verify("float32", "float64")
verify("float32", "bool")
verify("bool", "float32")
verify("bool", "int32")


if __name__ == "__main__":
test_util()
test_ewise()
test_cast()

0 comments on commit c174537

Please sign in to comment.