From 2899d0aa5d748d3ac8c791ddcf8fe739fe64cbdf Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 6 Oct 2018 22:17:32 -0700 Subject: [PATCH] Enable bool type as storage type (#1853) --- include/tvm/expr.h | 2 + include/tvm/runtime/packed_func.h | 12 ++++- python/tvm/_ffi/runtime_ctypes.py | 9 ++++ src/codegen/codegen_cuda.cc | 2 + src/codegen/codegen_metal.cc | 3 ++ src/codegen/codegen_opencl.cc | 3 ++ src/codegen/spirv/ir_builder.cc | 21 +++++++- src/lang/buffer.cc | 27 ++++++++-- src/pass/storage_flatten.cc | 10 +++- src/runtime/builtin_fp16.cc | 5 +- src/runtime/ndarray.cc | 2 + tests/python/unittest/test_codegen_bool.py | 58 ++++++++++++++++++++++ tests/python/unittest/test_lang_basic.py | 2 +- 13 files changed, 144 insertions(+), 12 deletions(-) create mode 100644 tests/python/unittest/test_codegen_bool.py diff --git a/include/tvm/expr.h b/include/tvm/expr.h index e41f5f28d35b..7fdca7f6af8e 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -56,6 +56,8 @@ inline TVMType Type2TVMType(Type t) { // Get number of bytes considering vector type. 
inline int GetVectorBytes(Type dtype) { int data_bits = dtype.bits() * dtype.lanes(); + // allow bool to exist + if (dtype == Bool()) return 1; CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index d204f8624a64..a8fa096e51c4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -873,6 +873,9 @@ inline const char* TypeCode2Str(int type_code) { #ifndef _LIBCPP_SGX_NO_IOSTREAMS inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) + if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { + os << "bool"; return os; + } os << TypeCode2Str(t.code); if (t.code == kHandle) return os; os << static_cast(t.bits); @@ -890,7 +893,10 @@ inline std::string TVMType2String(TVMType t) { os << t; return os.str(); #else std::string repr = ""; + if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { + return "bool"; + } repr += TypeCode2Str(t.code); if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); @@ -920,6 +926,11 @@ inline TVMType String2TVMType(std::string s) { t.code = kHandle; t.bits = 64; // handle uses 64 bit by default. 
scan = s.c_str() + 6; + } else if (s == "bool") { + t.code = kDLUInt; + t.bits = 1; + t.lanes = 1; + return t; } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 2aced1aef7d2..b17487559e50 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -48,6 +48,13 @@ def __init__(self, type_str): super(TVMType, self).__init__() if isinstance(type_str, np.dtype): type_str = str(type_str) + + if type_str == "bool": + self.bits = 1 + self.type_code = 1 + self.lanes = 1 + return + arr = type_str.split("x") head = arr[0] self.lanes = int(arr[1]) if len(arr) > 1 else 1 @@ -73,6 +80,8 @@ def __init__(self, type_str): def __repr__(self): + if self.bits == 1 and self.lanes == 1 and self.type_code == 1: + return "bool" x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) if self.lanes != 1: x += "x%d" % self.lanes diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 0960106ae471..2ed8d8e3ff78 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -77,6 +77,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } + } else if (t == Bool()) { + os << "bool"; return; } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { if (t.lanes() != 1) { diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc index 3bbe98289439..031313190370 100644 --- a/src/codegen/codegen_metal.cc +++ b/src/codegen/codegen_metal.cc @@ -141,6 +141,9 @@ void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*) << "do not yet support vector types"; os << "void*"; return; } + if (t == Bool()) { + os << "bool"; return; + } bool fail = false; if (t.is_float()) { switch (t.bits()) { diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 3d3de5e3bcf4..a0b3c2000a80 100644 --- a/src/codegen/codegen_opencl.cc +++ 
b/src/codegen/codegen_opencl.cc @@ -80,6 +80,9 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*) << "do not yet support vector types"; os << "void*"; return; } + if (t == Bool()) { + os << "bool"; return; + } bool fail = false; if (t.is_float()) { switch (t.bits()) { diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 41cb48c5854b..fdf4b9852430 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -438,8 +438,25 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { const tvm::Type& from = value.stype.type; const tvm::Type& to = dst_type.type; CHECK_EQ(from.lanes(), to.lanes()); - - if (from.is_int() && to.is_int()) { + if (from == Bool()) { + if (to.is_int()) { + return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); + } else if (to.is_uint()) { + return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0)); + } else { + LOG(FATAL) << "cannot cast from " << from << " to " << to; + return Value(); + } + } else if (to == Bool()) { + if (from.is_int()) { + return NE(value, IntImm(value.stype, 0)); + } else if (from.is_uint()) { + return NE(value, UIntImm(value.stype, 0)); + } else { + LOG(FATAL) << "cannot cast from " << from << " to " << to; + return Value(); + } + } else if (from.is_int() && to.is_int()) { return MakeValue(spv::OpSConvert, dst_type, value); } else if (from.is_uint() && to.is_uint()) { return MakeValue(spv::OpUConvert, dst_type, value); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 69967c55a7ff..183a52f785bd 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -260,25 +260,42 @@ inline Expr BufferOffset(const BufferNode* n, Array index, Type dtype) { } Expr Buffer::vload(Array begin, Type dtype) const { + // specially handle bool, stored as Int(8) const BufferNode* n = operator->(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer 
of " << n->dtype; - return ir::Load::make( - dtype, n->data, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + if (dtype == Bool()) { + return ir::Cast::make( + Bool(), + ir::Load::make( + Int(8), n->data, BufferOffset(n, begin, Int(8)), + const_true())); + } else { + return ir::Load::make( + dtype, n->data, BufferOffset(n, begin, dtype), + const_true(dtype.lanes())); + } } Stmt Buffer::vstore(Array begin, Expr value) const { + // specially handle bool, stored as Int(8) const BufferNode* n = operator->(); Type dtype = value.type(); CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; - return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + if (value.type() == Bool()) { + return ir::Store::make(n->data, + ir::Cast::make(Int(8), value), + BufferOffset(n, begin, Int(8)), + const_true()); + } else { + return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), + const_true(dtype.lanes())); + } } Buffer Buffer::MakeStrideView() const { diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 28a6ace9bfa6..993f6294e15b 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -191,10 +191,16 @@ class StorageFlattener : public IRMutator { buf_map_[key].released = true; Stmt ret; + Type storage_type = e.buffer->dtype; + // specially handle bool, lower its storage + // type to be Int(8)(byte) + if (storage_type == Bool()) { + storage_type = Int(8); + } if (strides.size() != 0) { int first_dim = 0; ret = Allocate::make( - e.buffer->data, e.buffer->dtype, + e.buffer->data, storage_type, {arith::ComputeExpr(e.buffer->strides[first_dim], e.buffer->shape[first_dim])}, make_const(Bool(e.buffer->dtype.lanes()), true), body); } else { @@ -203,7 +209,7 @@ class StorageFlattener : public IRMutator { shape.push_back(make_const(Int(32), 1)); } ret = Allocate::make( - 
e.buffer->data, e.buffer->dtype, shape, + e.buffer->data, storage_type, shape, make_const(Bool(e.buffer->dtype.lanes()), true), body); } ret = AttrStmt::make( diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 79c3cc474269..c920c9571f38 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -3,12 +3,14 @@ * \file builtin_fp16.cc * \brief Functions for conversion between fp32 and fp16 */ - #include #include extern "C" { +// disable under msvc +#ifndef _MSC_VER + TVM_WEAK uint16_t __gnu_f2h_ieee(float a) { return __truncXfYf2__(a); } @@ -17,4 +19,5 @@ TVM_WEAK float __gnu_h2f_ieee(uint16_t a) { return __extendXfYf2__(a); } +#endif } diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 574111e39b64..0ffa4c174544 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -20,6 +20,8 @@ inline void VerifyDataType(DLDataType dtype) { if (dtype.code == kDLFloat) { CHECK_EQ(dtype.bits % 8, 0); } else { + // allow uint1 as a special flag for bool. + if (dtype.bits == 1 && dtype.code == kDLUInt) return; CHECK_EQ(dtype.bits % 8, 0); } CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); diff --git a/tests/python/unittest/test_codegen_bool.py b/tests/python/unittest/test_codegen_bool.py new file mode 100644 index 000000000000..e2592c416345 --- /dev/null +++ b/tests/python/unittest/test_codegen_bool.py @@ -0,0 +1,58 @@ +"""codegen related to bool types""" + +import tvm +import numpy as np + +def test_cmp_load_store(): + n = 32 + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C') + D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D") + + + def check_llvm(): + if not tvm.module.enabled("llvm"): + return + s = tvm.create_schedule(D.op) + xo, xi = s[C].split(C.op.axis[0], factor=4) + xo1, xo2 = s[C].split(xo, factor=13) + s[C].parallel(xo2) + # BUILD and invoke the kernel. 
+ f = tvm.build(s, [A, B, D], "llvm") + ctx = tvm.cpu(0) + a_np = np.random.uniform(size=n).astype(A.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx) + f(a, b, d) + np.testing.assert_equal( + d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1)) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + return + s = tvm.create_schedule(D.op) + for stage in [C, D]: + xo, xi = s[stage].split(stage.op.axis[0], factor=4) + s[stage].bind(xo, tvm.thread_axis("blockIdx.x")) + s[stage].bind(xi, tvm.thread_axis("threadIdx.x")) + f = tvm.build(s, [A, B, D], device) + a_np = np.random.uniform(size=n).astype(A.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx) + f(a, b, d) + np.testing.assert_equal( + d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1)) + + + check_llvm() + for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]: + check_device(device) + + + +if __name__ == "__main__": + test_cmp_load_store() diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index bf25ca3dfc85..079123d96ca0 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -79,7 +79,7 @@ def test_dtype(): x = tvm.var('x') assert x.dtype == 'int32' y = tvm.var('y') - assert (x > y).dtype == 'uint1' + assert (x > y).dtype == 'bool' def test_any():