From 4aaf241e49ae430b0534328cd72175f38e5c0ea5 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 8 Jul 2019 20:20:22 -0700 Subject: [PATCH] [Vulkan] Added conversion from bool to float. (#3513) * Added bool to float conversion support to spirv ir builder. * Added unittest for vulkan bool conversion. * Typo fix. --- src/codegen/spirv/ir_builder.cc | 3 +++ src/runtime/vulkan/vulkan_device_api.cc | 2 +- tests/python/unittest/test_codegen_bool.py | 7 ++++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 72b68807a7d9..d6ba9e40c123 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -462,6 +462,9 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); } else if (to.is_uint()) { return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0)); + } else if (to.is_float()) { + return MakeValue(spv::OpConvertUToF, dst_type, + Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0))); } else { LOG(FATAL) << "cannot cast from " << from << " to " << to; return Value(); diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 7ebe278f3032..da04acdcbc31 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -683,7 +683,7 @@ void VulkanWorkspace::Init() { try { instance_ = CreateInstance(); context_ = GetContext(instance_); - LOG(INFO) << "Initialzie Vulkan with " << context_.size() << " devices.."; + LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; for (size_t i = 0; i < context_.size(); ++i) { LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName diff --git a/tests/python/unittest/test_codegen_bool.py b/tests/python/unittest/test_codegen_bool.py index 0a9f8fb8331b..934812b36a6a 100644 --- a/tests/python/unittest/test_codegen_bool.py +++ b/tests/python/unittest/test_codegen_bool.py @@ -24,7 +24,8 @@ def test_cmp_load_store(): A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C') - D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D") + D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), + A(*i) > 1).astype('float32'), name="D") def check_llvm(): @@ -43,7 +44,7 @@ def check_llvm(): d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx) f(a, b, d) np.testing.assert_equal( - d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1)) + d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32')) def check_device(device): ctx = tvm.context(device, 0) @@ -61,7 +62,7 @@ def check_device(device): d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx) f(a, b, d) np.testing.assert_equal( - d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1)) + d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32')) check_llvm()