Skip to content

Commit

Permalink
[API][v0.3] Introduce Struct Data Type to HeteroCL (#157)
Browse files Browse the repository at this point in the history
* initial attempt for introducing struct

* add a simple test for struct; more are coming ...

* add complex test

* errors from nn.py; need to look into that ...

* add C codegen for struct

* the previous C codegen was incomplete: the other conversion direction was not handled

* clean up the code
  • Loading branch information
seanlatias authored Feb 14, 2020
1 parent f1570b1 commit 8906f99
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 29 deletions.
1 change: 1 addition & 0 deletions hlib/python/hlib/op/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def conv2d(
d = []
for i in range(len(padding)):
p.append(tvm_to_primitive(padding[i]))
for i in range(len(strides)):
s.append(tvm_to_primitive(strides[i]))
d.append(tvm_to_primitive(dilation[i]))
strides = s
Expand Down
7 changes: 4 additions & 3 deletions python/heterocl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ def placeholder(shape, name=None, dtype=None):
"""
name = util.get_name("placeholder", name)
dtype = util.get_dtype(dtype)

tvm_dtype = types.dtype_to_str(dtype)

if shape == ():
return Scalar(tvm_api._Var(name, dtype))
return Scalar(tvm_api._Var(name, tvm_dtype))
tensor = Tensor(shape, dtype, name)
tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, dtype, name)
tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, tvm_dtype, name)

# placeholder is also a stage
stage = Stage(name)
Expand Down
31 changes: 27 additions & 4 deletions python/heterocl/compute_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from collections import OrderedDict
from .tvm import expr as _expr, stmt as _stmt, make as _make
from .tvm.api import _IterVar, min_value
from .util import get_index, get_name, get_type, get_dtype, make_for, CastRemover
from .util import get_index, get_name, get_type, get_tvm_dtype, make_for, CastRemover
from .tensor import Scalar, Tensor, TensorSlice
from .types import Struct, dtype_to_str
from .schedule import Stage
from .debug import APIError
from .dsl import if_, for_
Expand Down Expand Up @@ -117,7 +118,7 @@ def compute_body(name,
if not return_tensor:
stage.input_stages.add(tensor.last_update)
else:
tensor = Tensor(shape, stage._dtype, name, stage._buf)
tensor = Tensor(shape, stage._hcl_dtype, name, stage._buf)
buffer_var = tensor._buf.data
dtype = tensor.dtype
shape = tensor.shape
Expand All @@ -137,6 +138,28 @@ def compute_body(name,
stmt = stage.pop_stmt()
stmt = ReplaceReturn(buffer_var, dtype, index).mutate(stmt)
stmt = make_for(indices, stmt, 0)
elif isinstance(ret, (tuple, list)):
indices = lambda_ivs
index, _, _ = get_index(shape, indices, 0)
hcl_dtype = tensor.hcl_dtype
if not isinstance(hcl_dtype, Struct):
raise TensorError("Cannot assign a tuple/list to a non-struct-type tensor")
start = 0
end = 0
for sdtype, expr in zip(hcl_dtype.dtype_dict.values(), ret):
end = start + sdtype.bits
sdtype = dtype_to_str(sdtype)
load = _make.Load(dtype, buffer_var, index)
expr = _make.Cast(sdtype, expr)
if get_type(sdtype) != "uint":
ty = "uint" + str(get_type(sdtype)[1])
expr = _make.Call(ty, "bitcast", [expr], _expr.Call.PureIntrinsic, None, 0)
expr = _make.SetSlice(load, expr, end, start)
stage.emit(_make.Store(buffer_var,
_make.Cast(dtype, expr),
index))
start = end
stmt = make_for(indices, stage.pop_stmt(), 0)
elif isinstance(ret, (TensorSlice, Scalar, _expr.Expr, numbers.Number)):
indices = lambda_ivs
index, _, _ = get_index(shape, indices, 0)
Expand Down Expand Up @@ -539,7 +562,7 @@ def unpack_A(A):
# to do so, we will need the name
name_ = name if Stage.get_len() == 0 \
else Stage.get_current().name_with_prefix + "." + name
dtype = get_dtype(dtype, name_)
dtype = get_tvm_dtype(dtype, name_)
ret = get_type(dtype)
factor = tensor.type.bits // ret[1]
bitwidth = ret[1]
Expand Down Expand Up @@ -612,7 +635,7 @@ def pack(tensor, axis=0, factor=None, name=None, dtype=None):
# to do so, we will need the name
name_ = name if Stage.get_len() == 0 \
else Stage.get_current().name_with_prefix + "." + name
dtype = get_dtype(dtype, name_)
dtype = get_tvm_dtype(dtype, name_)
ret = get_type(dtype)
factor = ret[1] // tensor.type.bits
bitwidth = tensor.type.bits
Expand Down
14 changes: 7 additions & 7 deletions python/heterocl/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,19 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n
if dtypes is None:
dtypes = []
for name_ in new_names:
dtypes.append(util.get_dtype(None, name_))
dtypes.append(util.get_tvm_dtype(None, name_))
elif isinstance(dtypes, list):
if len(dtypes) != nargs:
raise APIError("The number of data types does not match the of arguments")
for (name_, dtype_) in zip(new_names, dtypes):
dtypes.append(util.get_dtype(dtype_, name_))
dtypes.append(util.get_tvm_dtype(dtype_, name_))
dtypes = dtypes[int(len(dtypes)/2):]
else:
dtype = util.get_dtype(dtypes)
dtype = util.get_tvm_dtype(dtypes)
dtypes = []
for name_ in new_names:
dtypes.append(util.get_dtype(dtype, name_))
ret_dtype = util.get_dtype(ret_dtype, s.name_with_prefix)
dtypes.append(util.get_tvm_dtype(dtype, name_))
ret_dtype = util.get_tvm_dtype(ret_dtype, s.name_with_prefix)
# prepare inputs for IR generation
inputs = []
inputs_tvm = []
Expand Down Expand Up @@ -441,7 +441,7 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n
ret_void = _make.UIntImm("uint1", 0) if s.has_return else _make.UIntImm("uint1", 1)
body = s.pop_stmt()
s.stmt_stack.append([])
s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes,
s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes,
body, ret_void, ret_dtype, name, []))
for name_, i in zip(names, inputs):
s.var_dict[name_] = i
Expand Down Expand Up @@ -499,6 +499,6 @@ def compute_out(A, x):
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
dtype = util.get_dtype(stage.ret_dtype)
dtype = util.get_tvm_dtype(stage.ret_dtype)
stage.emit(_make.Return(_make.Cast(dtype, val)))
stage.has_return = True
4 changes: 2 additions & 2 deletions python/heterocl/nparray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#pylint: disable=missing-docstring
import numpy as np
from .tvm.ndarray import array, cpu
from .util import get_dtype
from .util import get_tvm_dtype
from . import types

def cast_np(np_in, dtype):
Expand Down Expand Up @@ -79,7 +79,7 @@ def asarray(arr, dtype=None, ctx=cpu(0)):
np_A = numpy.zeros(10)
hcl_A = np_A.asarray()
"""
dtype = get_dtype(dtype)
dtype = get_tvm_dtype(dtype)
return array(arr, dtype, ctx)

def pack_np(np_in, dtype_in, dtype_out):
Expand Down
3 changes: 2 additions & 1 deletion python/heterocl/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def __init__(self, name=None, dtype=None, shape=()):
else Stage.get_current().name_with_prefix + "." + self.name
# Private attributes for building a stage
self._op = None
self._dtype = util.get_dtype(dtype, self.name_with_prefix)
self._hcl_dtype = util.get_dtype(dtype, self.name_with_prefix)
self._dtype = util.get_tvm_dtype(dtype, self.name_with_prefix)
self._buf = tvm_api.decl_buffer(shape, self._dtype, self.name)
self._shape = self._buf.shape

Expand Down
84 changes: 75 additions & 9 deletions python/heterocl/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,12 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
# not allowed: A[5:7]
"""
def __init__(self, tensor, indices):
def __init__(self, tensor, indices, dtype=None):
if not isinstance(indices, tuple):
indices = (indices,)
self.tensor = tensor
self.indices = indices
self._dtype = dtype if dtype is not None else self.tensor.dtype

def __getitem__(self, indices):
if not isinstance(indices, tuple):
Expand All @@ -134,10 +135,17 @@ def __setitem__(self, indices, expr):
builder = Stage.get_current()
if bit is None:
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
_make.Cast(self._dtype, expr),
index))
elif isinstance(bit, slice):
load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
# special handle for struct: we need to make sure the bitwidths
# are the same before and after bitcast
if (isinstance(self.tensor.type, types.Struct)
and util.get_type(self._dtype) != "uint"):
ty = "uint" + str(util.get_type(self._dtype)[1])
expr = _make.Call(ty, "bitcast",
[expr], _expr.Call.PureIntrinsic, None, 0)
expr = _make.SetSlice(load, expr, bit.start, bit.stop)
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
Expand All @@ -146,9 +154,53 @@ def __setitem__(self, indices, expr):
load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
expr = _make.SetBit(load, expr, bit)
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
_make.Cast(self._dtype, expr),
index))

def __getattr__(self, key):
    """Return a slice that selects struct field ``key`` of this element.

    This hook only runs when normal attribute lookup fails. The bit
    offset of the field is the sum of the bitwidths of the fields that
    precede it in ``hcl_dtype.dtype_dict``.

    Raises
    ------
    AttributeError
        For dunder names and for the instance attributes set in
        ``__init__`` when they are not available yet.
    TensorError
        If the underlying tensor is not of struct type.
    DTypeError
        If ``key`` is not a field of the struct.
    """
    # Protocol probes (hasattr, copy, pickle, ...) look up dunder names
    # and expect AttributeError. The instance attributes themselves can
    # only get here when they are unset, in which case reading
    # self.tensor below would recurse without end.
    if key in ("tensor", "indices", "_dtype") or (
            key.startswith("__") and key.endswith("__")):
        raise AttributeError(key)
    hcl_dtype = self.tensor.hcl_dtype
    if not isinstance(hcl_dtype, types.Struct):
        raise TensorError(
            "Cannot access attribute if type is not struct")
    start = 0
    for dkey, dval in hcl_dtype.dtype_dict.items():
        if dkey == key:
            end = start + dval.bits
            # the bit range is appended as one extra index, built as
            # slice(end, start)
            indices = (slice(end, start),)
            return TensorSlice(self.tensor, self.indices + indices,
                               types.dtype_to_str(dval))
        start += dval.bits
    raise DTypeError("Field " + key
                     + " is not in struct " + str(hcl_dtype))

def __setattr__(self, key, expr):
    """Assign ``expr`` to struct field ``key`` of this element.

    The names ``tensor``, ``indices`` and ``_dtype`` are stored as
    ordinary instance attributes. Any other name is treated as a struct
    field: its bit range is computed from the bitwidths of the preceding
    fields and the value is written through ``__setitem__``.

    Raises
    ------
    TensorError
        If the underlying tensor is not of struct type.
    DTypeError
        If ``key`` is not a field of the struct.
    """
    if key in ("tensor", "indices", "_dtype"):
        super().__setattr__(key, expr)
        return
    hcl_dtype = self.tensor.hcl_dtype
    if not isinstance(hcl_dtype, types.Struct):
        raise TensorError(
            "Cannot access attribute if type is not struct")
    start = 0
    for dkey, dval in hcl_dtype.dtype_dict.items():
        if dkey == key:
            break
        start += dval.bits
    else:
        # The loop finished without a match. Comparing start with end
        # cannot detect this: start has advanced past every field, so
        # an unknown name would be written over the whole struct.
        raise DTypeError("Field " + key
                         + " is not in struct " + str(hcl_dtype))
    end = start + dval.bits
    # __setitem__ reads self._dtype, so set the field type first
    self._dtype = types.dtype_to_str(dval)
    indices = (slice(end, start),)
    self.__setitem__(indices, expr)

@property
def dtype(self):
return self.tensor.dtype
Expand All @@ -158,12 +210,25 @@ def asnode(self):
raise TensorError("Accessing a slice of tensor is not allowed")
index, bit, _ = util.get_index(self.tensor.shape, self.indices, 0)
if bit is None:
return _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
return _make.Load(self._dtype, self.tensor.buf.data, index)
elif isinstance(bit, slice):
return _make.GetSlice(_make.Load(self.tensor.dtype, self.tensor.buf.data, index),
load = _make.GetSlice(_make.Load(self.tensor.dtype,
self.tensor.buf.data, index),
bit.start,
bit.stop)
return _make.GetBit(_make.Load(self.tensor.dtype, self.tensor.buf.data, index), bit)
if self.tensor.dtype != self._dtype:
bw_from = types.get_bitwidth(self.tensor.dtype)
bw_to = types.get_bitwidth(self._dtype)
if bw_from != bw_to:
ty = util.get_type(self.tensor.dtype)[0] + str(bw_to)
load = _make.Cast(ty, load)
return _make.Call(self._dtype, "bitcast",
[load], _expr.Call.PureIntrinsic, None, 0)
else:
return load
return _make.GetBit(_make.Load(self._dtype,
self.tensor.buf.data,
index), bit)

class Tensor(NodeGeneric, _expr.ExprOp):
"""A HeteroCL tensor.
Expand Down Expand Up @@ -230,14 +295,15 @@ class Tensor(NodeGeneric, _expr.ExprOp):
def __init__(self, shape, dtype="int32", name="tensor", buf=None):
self._tensor = None
self._buf = buf
self.dtype = dtype
self.hcl_dtype = dtype
self.dtype = types.dtype_to_str(dtype)
self.shape = shape
self.name = name
self.var_dict = {}
self.first_update = None
self.last_update = None
if buf is None:
self._buf = decl_buffer(shape, dtype, name)
self._buf = decl_buffer(shape, self.dtype, name)

def __repr__(self):
return "Tensor('" + self.name + "', " + str(self.shape) + ", " + str(self.dtype) + ")"
Expand Down Expand Up @@ -291,7 +357,7 @@ def buf(self):

@property
def type(self):
return types.dtype_to_hcl(self.dtype)
return self.hcl_dtype

@property
def op(self):
Expand Down
28 changes: 27 additions & 1 deletion python/heterocl/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Define HeteroCL data types"""
#pylint: disable=too-few-public-methods, too-many-return-statements
import numbers
from collections import OrderedDict
from .debug import DTypeError

class Type(object):
Expand Down Expand Up @@ -48,6 +49,30 @@ class UFixed(Type):
def __repr__(self):
return "UFixed(" + str(self.bits) + ", " + str(self.fracs) + ")"

class Struct(Type):
    """A C-like struct

    The struct members are defined with a Python dictionary (or any other
    input accepted by ``OrderedDict``) that maps each field name to its
    data type. The fields keep the order of that mapping, and the
    bitwidth of the struct is the sum of the bitwidths of its fields.
    """
    def __init__(self, dtype_dict):
        self.dtype_dict = OrderedDict(dtype_dict)
        # sum over the normalized copy so that an iterable of
        # (name, dtype) pairs works as well as a dictionary
        self.bits = sum(dtype.bits for dtype in self.dtype_dict.values())
        Type.__init__(self, self.bits, 0)

    def __repr__(self):
        return "Struct(" + str(self.dtype_dict) + ")"

    def _field(self, key):
        """Return the data type of field ``key`` or raise DTypeError."""
        try:
            return self.dtype_dict[key]
        except KeyError:
            raise DTypeError(key + " is not in struct")

    def __getattr__(self, key):
        # This hook only runs when normal lookup fails. Protocol probes
        # (hasattr, copy, pickle, ...) look up dunder names and expect
        # AttributeError. ``dtype_dict`` itself can only get here when it
        # is unset, in which case reading it below would recurse forever.
        if key == "dtype_dict" or (
                key.startswith("__") and key.endswith("__")):
            raise AttributeError(key)
        return self._field(key)

    def __getitem__(self, key):
        return self._field(key)

def dtype_to_str(dtype):
"""Convert a data type to string format.
Expand All @@ -66,7 +91,8 @@ def dtype_to_str(dtype):
if isinstance(dtype, Type):
if isinstance(dtype, Int):
return "int" + str(dtype.bits)
elif isinstance(dtype, UInt):
# struct is treated as uint
elif isinstance(dtype, (UInt, Struct)):
return "uint" + str(dtype.bits)
elif isinstance(dtype, Fixed):
bits = dtype.bits
Expand Down
5 changes: 4 additions & 1 deletion python/heterocl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def get_dtype(dtype, name=None):
dtype_ = Scheme.current.dtype_dict.get(name)
dtype = dtype if dtype_ is None else dtype_
dtype = config.init_dtype if dtype is None else dtype
return types.dtype_to_str(dtype)
return dtype

def get_tvm_dtype(dtype, name=None):
    """Resolve a data type with ``get_dtype`` and return it as a string.

    The arguments are forwarded unchanged to ``get_dtype``; the result is
    converted with ``types.dtype_to_str``.
    """
    resolved = get_dtype(dtype, name)
    return types.dtype_to_str(resolved)

def true():
return _make.UIntImm("uint1", 1)
Expand Down
Loading

0 comments on commit 8906f99

Please sign in to comment.