Skip to content

Commit

Permalink
[Vulkan] Added conversion from bool to float. (apache#3513)
Browse files Browse the repository at this point in the history
* Added bool to float conversion support to spirv ir builder.

* Added unittest for vulkan bool conversion.

* Typo fix.
  • Loading branch information
jwfromm authored and wweic committed Jul 11, 2019
1 parent 71872f6 commit 4aaf241
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/codegen/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
} else if (to.is_uint()) {
return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
} else if (to.is_float()) {
return MakeValue(spv::OpConvertUToF, dst_type,
Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0)));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ void VulkanWorkspace::Init() {
try {
instance_ = CreateInstance();
context_ = GetContext(instance_);
LOG(INFO) << "Initialzie Vulkan with " << context_.size() << " devices..";
LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
for (size_t i = 0; i < context_.size(); ++i) {
LOG(INFO) << "vulkan(" << i
<< ")=\'" << context_[i].phy_device_prop.deviceName
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_codegen_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def test_cmp_load_store():
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i),
A(*i) > 1).astype('float32'), name="D")


def check_llvm():
Expand All @@ -43,7 +44,7 @@ def check_llvm():
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))

def check_device(device):
ctx = tvm.context(device, 0)
Expand All @@ -61,7 +62,7 @@ def check_device(device):
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))
d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1).astype('float32'))


check_llvm()
Expand Down

0 comments on commit 4aaf241

Please sign in to comment.