Skip to content

Commit

Permalink
Enable bool type as storage type (#1853)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Oct 7, 2018
1 parent ea07f74 commit f1d815c
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ inline TVMType Type2TVMType(Type t) {
// Get number of bytes considering vector type.
inline int GetVectorBytes(Type dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
Expand Down
12 changes: 11 additions & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,9 @@ inline const char* TypeCode2Str(int type_code) {

#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
os << TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
Expand All @@ -890,7 +893,9 @@ inline std::string TVMType2String(TVMType t) {
os << t;
return os.str();
#else
std::string repr = "";
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
return "bool";
}
repr += TypeCode2Str(t.code);
if (t.code == kHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
Expand Down Expand Up @@ -920,6 +925,11 @@ inline TVMType String2TVMType(std::string s) {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def __init__(self, type_str):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)

if type_str == "bool":
self.bits = 1
self.type_code = 1
self.lanes = 1
return

arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
Expand All @@ -73,6 +80,8 @@ def __init__(self, type_str):


def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
} else if (t == Bool()) {
os << "bool"; return;
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
if (t.lanes() != 1) {
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<< "do not yet support vector types";
os << "void*"; return;
}
if (t == Bool()) {
os << "bool"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
Expand Down
3 changes: 3 additions & 0 deletions src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
<< "do not yet support vector types";
os << "void*"; return;
}
if (t == Bool()) {
os << "bool"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
Expand Down
21 changes: 19 additions & 2 deletions src/codegen/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,25 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
const tvm::Type& from = value.stype.type;
const tvm::Type& to = dst_type.type;
CHECK_EQ(from.lanes(), to.lanes());

if (from.is_int() && to.is_int()) {
if (from == Bool()) {
if (to.is_int()) {
return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
} else if (to.is_uint()) {
return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
}
} else if (to == Bool()) {
if (from.is_int()) {
return NE(value, IntImm(value.stype, 0));
} else if (to.is_uint()) {
return NE(value, UIntImm(value.stype, 0));
} else {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
}
} else if (from.is_int() && to.is_int()) {
return MakeValue(spv::OpSConvert, dst_type, value);
} else if (from.is_uint() && to.is_uint()) {
return MakeValue(spv::OpUConvert, dst_type, value);
Expand Down
27 changes: 22 additions & 5 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,25 +260,42 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
}

Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
// specially handle bool, stored as Int(8)
const BufferNode* n = operator->();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
return ir::Load::make(
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
if (dtype == Bool()) {
return ir::Cast::make(
Bool(),
ir::Load::make(
Int(8), n->data, BufferOffset(n, begin, Int(8)),
const_true()));
} else {
return ir::Load::make(
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}

Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
// specially handle bool, stored as Int(8)
const BufferNode* n = operator->();
Type dtype = value.type();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
if (value.type() == Bool()) {
return ir::Store::make(n->data,
ir::Cast::make(Int(8), value),
BufferOffset(n, begin, Int(8)),
const_true());
} else {
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}

Buffer Buffer::MakeStrideView() const {
Expand Down
10 changes: 8 additions & 2 deletions src/pass/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,16 @@ class StorageFlattener : public IRMutator {
buf_map_[key].released = true;
Stmt ret;

Type storage_type = e.buffer->dtype;
// specially handle bool, lower its storage
// type to be Int(8)(byte)
if (storage_type == Bool()) {
storage_type = Int(8);
}
if (strides.size() != 0) {
int first_dim = 0;
ret = Allocate::make(
e.buffer->data, e.buffer->dtype,
e.buffer->data, storage_type,
{arith::ComputeExpr<Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
make_const(Bool(e.buffer->dtype.lanes()), true), body);
} else {
Expand All @@ -203,7 +209,7 @@ class StorageFlattener : public IRMutator {
shape.push_back(make_const(Int(32), 1));
}
ret = Allocate::make(
e.buffer->data, e.buffer->dtype, shape,
e.buffer->data, storage_type, shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt::make(
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/builtin_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
* \file builtin_fp16.cc
* \brief Functions for conversion between fp32 and fp16
*/

#include <builtin_fp16.h>
#include <tvm/runtime/c_runtime_api.h>

extern "C" {

// disable under msvc
#ifndef _MSC_VER

TVM_WEAK uint16_t __gnu_f2h_ieee(float a) {
return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
}
Expand All @@ -17,4 +19,5 @@ TVM_WEAK float __gnu_h2f_ieee(uint16_t a) {
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
}

#endif
}
2 changes: 2 additions & 0 deletions src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ inline void VerifyDataType(DLDataType dtype) {
if (dtype.code == kDLFloat) {
CHECK_EQ(dtype.bits % 8, 0);
} else {
// allow uint1 as a special flag for bool.
if (dtype.bits == 1 && dtype.code == kDLUInt) return;
CHECK_EQ(dtype.bits % 8, 0);
}
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
Expand Down
58 changes: 58 additions & 0 deletions tests/python/unittest/test_codegen_bool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""codegen related to bool types"""

import tvm
import numpy as np

def test_cmp_load_store():
n = 32
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")


def check_llvm():
if not tvm.module.enabled("llvm"):
return
s = tvm.create_schedule(D.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
xo1, xo2 = s[C].split(xo, factor=13)
s[C].parallel(xo2)
# BUILD and invoke the kernel.
f = tvm.build(s, [A, B, D], "llvm")
ctx = tvm.cpu(0)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
s = tvm.create_schedule(D.op)
for stage in [C, D]:
xo, xi = s[stage].split(stage.op.axis[0], factor=4)
s[stage].bind(xo, tvm.thread_axis("blockIdx.x"))
s[stage].bind(xi, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B, D], device)
a_np = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
np.testing.assert_equal(
d.asnumpy(), np.logical_and(a.asnumpy()> b.asnumpy(), a.asnumpy() > 1))


check_llvm()
for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
check_device(device)



if __name__ == "__main__":
test_cmp_load_store()
2 changes: 1 addition & 1 deletion tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_dtype():
x = tvm.var('x')
assert x.dtype == 'int32'
y = tvm.var('y')
assert (x > y).dtype == 'uint1'
assert (x > y).dtype == 'bool'


def test_any():
Expand Down

0 comments on commit f1d815c

Please sign in to comment.