From c1745377b01902ff54abe251e2a9b43953a90e06 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 20 May 2019 10:07:01 -0700 Subject: [PATCH] [BugFix] Fix bug in cast to bool (#3207) --- src/codegen/llvm/codegen_llvm.cc | 8 ++++++ topi/tests/python/test_topi_math.py | 44 +++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 7946f906125f..bedcdc79ff1f 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); + } else if (to.is_uint() && to.bits() == 1) { + if (from.is_float()) { + llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.); + return builder_->CreateFCmpONE(value, zero); + } else { + llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0); + return builder_->CreateICmpNE(value, zero); + } } else if (!from.is_float() && !to.is_float()) { return builder_->CreateIntCast(value, target, from.is_int()); } else if (from.is_float() && to.is_int()) { diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index c180bc77e829..d6df450628d2 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -19,6 +19,7 @@ import topi import topi.testing from topi import util +from common import get_all_backend def test_util(): @@ -59,8 +60,7 @@ def check_device(device): foo(a, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) - for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel', - 'aocl_sw_emu']: + for device in get_all_backend(): check_device(device) @@ -77,6 +77,46 @@ def check_device(device): test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100) test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True) + +def test_cast(): + def verify(from_dtype, to_dtype, low=-100, high=100): + shape = (5, 4) + A = tvm.placeholder(shape, dtype=from_dtype, name="A") + B = topi.cast(A, to_dtype) + + if from_dtype == "bool": + a_np = np.random.choice([True, False], size=shape) + else: + a_np = np.random.uniform(low, high, size=shape).astype(from_dtype) + if to_dtype == "bool": + a_np = a_np - a_np[2, 3] + b_np = a_np.astype(to_dtype) + + for device in get_all_backend(): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + foo = tvm.build(s, [A, B], device) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx) + foo(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np) + + verify("int32", "float32") + verify("int32", "float64") + verify("int32", "bool") + verify("float32", "int32") + verify("float32", "float64") + verify("float32", "bool") + verify("bool", "float32") + verify("bool", "int32") + + if __name__ == "__main__": test_util() test_ewise() + test_cast()