[API][v0.3] Introduce Struct Data Type to HeteroCL #157

Merged: 7 commits, Feb 14, 2020
1 change: 1 addition & 0 deletions hlib/python/hlib/op/nn.py
@@ -173,6 +173,7 @@ def conv2d(
d = []
for i in range(len(padding)):
p.append(tvm_to_primitive(padding[i]))
for i in range(len(strides)):
s.append(tvm_to_primitive(strides[i]))
d.append(tvm_to_primitive(dilation[i]))
strides = s
7 changes: 4 additions & 3 deletions python/heterocl/api.py
@@ -90,11 +90,12 @@ def placeholder(shape, name=None, dtype=None):
"""
name = util.get_name("placeholder", name)
dtype = util.get_dtype(dtype)

tvm_dtype = types.dtype_to_str(dtype)

if shape == ():
return Scalar(tvm_api._Var(name, dtype))
return Scalar(tvm_api._Var(name, tvm_dtype))
tensor = Tensor(shape, dtype, name)
tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, dtype, name)
tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, tvm_dtype, name)

# placeholder is also a stage
stage = Stage(name)
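The two replaced lines above keep the HeteroCL type on the Python side and hand TVM only the lowered dtype string. A minimal sketch of the resulting behavior, assuming `hcl.Struct`, `hcl.Int`, and `hcl.UInt` are re-exported at the package level:

```python
# Hypothetical usage sketch (not part of this diff): a struct-typed placeholder
# keeps its HeteroCL type for field access while TVM only sees a packed uint.
import heterocl as hcl

hcl.init()
stype = hcl.Struct({"x": hcl.Int(8), "y": hcl.UInt(8)})
A = hcl.placeholder((4,), "A", dtype=stype)

print(A.hcl_dtype)  # Struct(...): field names and widths retained on the Tensor
print(A.dtype)      # "uint16": the string produced by types.dtype_to_str for TVM
```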
31 changes: 27 additions & 4 deletions python/heterocl/compute_api.py
@@ -5,8 +5,9 @@
from collections import OrderedDict
from .tvm import expr as _expr, stmt as _stmt, make as _make
from .tvm.api import _IterVar, min_value
from .util import get_index, get_name, get_type, get_dtype, make_for, CastRemover
from .util import get_index, get_name, get_type, get_tvm_dtype, make_for, CastRemover
from .tensor import Scalar, Tensor, TensorSlice
from .types import Struct, dtype_to_str
from .schedule import Stage
from .debug import APIError
from .dsl import if_, for_
@@ -117,7 +118,7 @@ def compute_body(name,
if not return_tensor:
stage.input_stages.add(tensor.last_update)
else:
tensor = Tensor(shape, stage._dtype, name, stage._buf)
tensor = Tensor(shape, stage._hcl_dtype, name, stage._buf)
buffer_var = tensor._buf.data
dtype = tensor.dtype
shape = tensor.shape
@@ -137,6 +138,28 @@
stmt = stage.pop_stmt()
stmt = ReplaceReturn(buffer_var, dtype, index).mutate(stmt)
stmt = make_for(indices, stmt, 0)
elif isinstance(ret, (tuple, list)):
indices = lambda_ivs
index, _, _ = get_index(shape, indices, 0)
hcl_dtype = tensor.hcl_dtype
if not isinstance(hcl_dtype, Struct):
raise TensorError("Cannot assign a tuple/list to a non-struct-type tensor")
start = 0
end = 0
for sdtype, expr in zip(hcl_dtype.dtype_dict.values(), ret):
end = start + sdtype.bits
sdtype = dtype_to_str(sdtype)
load = _make.Load(dtype, buffer_var, index)
expr = _make.Cast(sdtype, expr)
if get_type(sdtype) != "uint":
ty = "uint" + str(get_type(sdtype)[1])
expr = _make.Call(ty, "bitcast", [expr], _expr.Call.PureIntrinsic, None, 0)
expr = _make.SetSlice(load, expr, end, start)
stage.emit(_make.Store(buffer_var,
_make.Cast(dtype, expr),
index))
start = end
stmt = make_for(indices, stage.pop_stmt(), 0)
elif isinstance(ret, (TensorSlice, Scalar, _expr.Expr, numbers.Number)):
indices = lambda_ivs
index, _, _ = get_index(shape, indices, 0)
@@ -539,7 +562,7 @@ def unpack_A(A):
# to do so, we will need the name
name_ = name if Stage.get_len() == 0 \
else Stage.get_current().name_with_prefix + "." + name
dtype = get_dtype(dtype, name_)
dtype = get_tvm_dtype(dtype, name_)
ret = get_type(dtype)
factor = tensor.type.bits // ret[1]
bitwidth = ret[1]
@@ -612,7 +635,7 @@ def pack(tensor, axis=0, factor=None, name=None, dtype=None):
# to do so, we will need the name
name_ = name if Stage.get_len() == 0 \
else Stage.get_current().name_with_prefix + "." + name
dtype = get_dtype(dtype, name_)
dtype = get_tvm_dtype(dtype, name_)
ret = get_type(dtype)
factor = ret[1] // tensor.type.bits
bitwidth = tensor.type.bits
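The new `tuple`/`list` branch in `compute_body` packs one expression per struct field, casting each to its field type and writing it into the corresponding bit slice. A hedged sketch of how it would be exercised from user code, assuming `hcl.Struct` is exported at the package level:

```python
# Hypothetical sketch (not part of this diff): returning a tuple from the
# compute lambda fills the struct fields ("x", then "y") in declaration order.
import heterocl as hcl

hcl.init()
stype = hcl.Struct({"x": hcl.Int(8), "y": hcl.UInt(16)})
A = hcl.placeholder((8,), "A", dtype=hcl.Int(8))

def pack_fields(A):
    # Each tuple element is cast to its field dtype, bitcast to uint when the
    # field is not already unsigned, and stored into the matching bit slice.
    return hcl.compute(A.shape, lambda i: (A[i], A[i] * 2), "B", dtype=stype)

s = hcl.create_schedule([A], pack_fields)
f = hcl.build(s)
```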
14 changes: 7 additions & 7 deletions python/heterocl/dsl.py
@@ -399,19 +399,19 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n
if dtypes is None:
dtypes = []
for name_ in new_names:
dtypes.append(util.get_dtype(None, name_))
dtypes.append(util.get_tvm_dtype(None, name_))
elif isinstance(dtypes, list):
if len(dtypes) != nargs:
raise APIError("The number of data types does not match the number of arguments")
for (name_, dtype_) in zip(new_names, dtypes):
dtypes.append(util.get_dtype(dtype_, name_))
dtypes.append(util.get_tvm_dtype(dtype_, name_))
dtypes = dtypes[int(len(dtypes)/2):]
else:
dtype = util.get_dtype(dtypes)
dtype = util.get_tvm_dtype(dtypes)
dtypes = []
for name_ in new_names:
dtypes.append(util.get_dtype(dtype, name_))
ret_dtype = util.get_dtype(ret_dtype, s.name_with_prefix)
dtypes.append(util.get_tvm_dtype(dtype, name_))
ret_dtype = util.get_tvm_dtype(ret_dtype, s.name_with_prefix)
# prepare inputs for IR generation
inputs = []
inputs_tvm = []
@@ -441,7 +441,7 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n
ret_void = _make.UIntImm("uint1", 0) if s.has_return else _make.UIntImm("uint1", 1)
body = s.pop_stmt()
s.stmt_stack.append([])
s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes,
s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes,
body, ret_void, ret_dtype, name, []))
for name_, i in zip(names, inputs):
s.var_dict[name_] = i
@@ -499,6 +499,6 @@ def compute_out(A, x):
if not Stage.get_len():
raise DSLError("Imperative DSL must be used with other compute APIs")
stage = Stage.get_current()
dtype = util.get_dtype(stage.ret_dtype)
dtype = util.get_tvm_dtype(stage.ret_dtype)
stage.emit(_make.Return(_make.Cast(dtype, val)))
stage.has_return = True
4 changes: 2 additions & 2 deletions python/heterocl/nparray.py
@@ -2,7 +2,7 @@
#pylint: disable=missing-docstring
import numpy as np
from .tvm.ndarray import array, cpu
from .util import get_dtype
from .util import get_tvm_dtype
from . import types

def cast_np(np_in, dtype):
@@ -79,7 +79,7 @@ def asarray(arr, dtype=None, ctx=cpu(0)):
np_A = numpy.zeros(10)
hcl_A = np_A.asarray()
"""
dtype = get_dtype(dtype)
dtype = get_tvm_dtype(dtype)
return array(arr, dtype, ctx)

def pack_np(np_in, dtype_in, dtype_out):
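With `asarray` now going through `get_tvm_dtype`, host-side arrays are typed with the lowered TVM dtype string. A small illustration, assuming `hcl.asarray` is the package-level wrapper around this function:

```python
# Hypothetical sketch (not part of this diff): the numpy data is stored under
# the lowered TVM dtype string ("uint16" here), which is also how struct-typed
# tensors travel between host and device as packed unsigned integers.
import numpy as np
import heterocl as hcl

hcl.init()
np_A = np.random.randint(0, 10, size=(10,))
hcl_A = hcl.asarray(np_A, dtype=hcl.UInt(16))
print(hcl_A.asnumpy())
```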
3 changes: 2 additions & 1 deletion python/heterocl/schedule.py
@@ -324,7 +324,8 @@ def __init__(self, name=None, dtype=None, shape=()):
else Stage.get_current().name_with_prefix + "." + self.name
# Private attributes for building a stage
self._op = None
self._dtype = util.get_dtype(dtype, self.name_with_prefix)
self._hcl_dtype = util.get_dtype(dtype, self.name_with_prefix)
self._dtype = util.get_tvm_dtype(dtype, self.name_with_prefix)
self._buf = tvm_api.decl_buffer(shape, self._dtype, self.name)
self._shape = self._buf.shape

84 changes: 75 additions & 9 deletions python/heterocl/tensor.py
@@ -113,11 +113,12 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):

# not allowed: A[5:7]
"""
def __init__(self, tensor, indices):
def __init__(self, tensor, indices, dtype=None):
if not isinstance(indices, tuple):
indices = (indices,)
self.tensor = tensor
self.indices = indices
self._dtype = dtype if dtype is not None else self.tensor.dtype

def __getitem__(self, indices):
if not isinstance(indices, tuple):
@@ -134,10 +135,17 @@ def __setitem__(self, indices, expr):
builder = Stage.get_current()
if bit is None:
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
_make.Cast(self._dtype, expr),
index))
elif isinstance(bit, slice):
load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
# special handle for struct: we need to make sure the bitwidths
# are the same before and after bitcast
if (isinstance(self.tensor.type, types.Struct)
and util.get_type(self._dtype) != "uint"):
ty = "uint" + str(util.get_type(self._dtype)[1])
expr = _make.Call(ty, "bitcast",
[expr], _expr.Call.PureIntrinsic, None, 0)
expr = _make.SetSlice(load, expr, bit.start, bit.stop)
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
@@ -146,9 +154,53 @@ def __setitem__(self, indices, expr):
load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
expr = _make.SetBit(load, expr, bit)
builder.emit(_make.Store(self.tensor.buf.data,
_make.Cast(self.tensor.dtype, expr),
_make.Cast(self._dtype, expr),
index))

def __getattr__(self, key):
hcl_dtype = self.tensor.hcl_dtype
if not isinstance(hcl_dtype, types.Struct):
raise TensorError(
"Cannot access attribute if type is not struct")
start = 0
end = 0
dtype = None
for dkey, dval in hcl_dtype.dtype_dict.items():
if dkey == key:
end = start + dval.bits
dtype = types.dtype_to_str(dval)
break
else:
start += dval.bits
if dtype is None:
raise DTypeError("Field " + key
+ " is not in struct " + str(hcl_dtype))
indices = (slice(end, start),)
return TensorSlice(self.tensor, self.indices + indices, dtype)

def __setattr__(self, key, expr):
if key in ("tensor", "indices", "_dtype"):
super().__setattr__(key, expr)
else:
hcl_dtype = self.tensor.hcl_dtype
if not isinstance(hcl_dtype, types.Struct):
raise TensorError(
"Cannot access attribute if type is not struct")
start = 0
end = 0
for dkey, dval in hcl_dtype.dtype_dict.items():
if dkey == key:
end = start + dval.bits
self._dtype = types.dtype_to_str(dval)
break
else:
start += dval.bits
if start == end:
raise DTypeError("Field " + key
+ " is not in struct " + str(hcl_dtype))
indices = (slice(end, start),)
self.__setitem__(indices, expr)

@property
def dtype(self):
return self.tensor.dtype
@@ -158,12 +210,25 @@ def asnode(self):
raise TensorError("Accessing a slice of tensor is not allowed")
index, bit, _ = util.get_index(self.tensor.shape, self.indices, 0)
if bit is None:
return _make.Load(self.tensor.dtype, self.tensor.buf.data, index)
return _make.Load(self._dtype, self.tensor.buf.data, index)
elif isinstance(bit, slice):
return _make.GetSlice(_make.Load(self.tensor.dtype, self.tensor.buf.data, index),
load = _make.GetSlice(_make.Load(self.tensor.dtype,
self.tensor.buf.data, index),
bit.start,
bit.stop)
return _make.GetBit(_make.Load(self.tensor.dtype, self.tensor.buf.data, index), bit)
if self.tensor.dtype != self._dtype:
bw_from = types.get_bitwidth(self.tensor.dtype)
bw_to = types.get_bitwidth(self._dtype)
if bw_from != bw_to:
ty = util.get_type(self.tensor.dtype)[0] + str(bw_to)
load = _make.Cast(ty, load)
return _make.Call(self._dtype, "bitcast",
[load], _expr.Call.PureIntrinsic, None, 0)
else:
return load
return _make.GetBit(_make.Load(self._dtype,
self.tensor.buf.data,
index), bit)

class Tensor(NodeGeneric, _expr.ExprOp):
"""A HeteroCL tensor.
@@ -230,14 +295,15 @@ class Tensor(NodeGeneric, _expr.ExprOp):
def __init__(self, shape, dtype="int32", name="tensor", buf=None):
self._tensor = None
self._buf = buf
self.dtype = dtype
self.hcl_dtype = dtype
self.dtype = types.dtype_to_str(dtype)
self.shape = shape
self.name = name
self.var_dict = {}
self.first_update = None
self.last_update = None
if buf is None:
self._buf = decl_buffer(shape, dtype, name)
self._buf = decl_buffer(shape, self.dtype, name)

def __repr__(self):
return "Tensor('" + self.name + "', " + str(self.shape) + ", " + str(self.dtype) + ")"
@@ -291,7 +357,7 @@ def buf(self):

@property
def type(self):
return types.dtype_to_hcl(self.dtype)
return self.hcl_dtype

@property
def op(self):
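Together, `__getattr__` and `__setattr__` on `TensorSlice` make struct members addressable as plain attributes, with non-uint fields bitcast around the underlying slice operations. A hedged end-to-end sketch, assuming `hcl.Struct` is exported at the package level:

```python
# Hypothetical sketch (not part of this diff): reading A[i].x goes through the
# GetSlice-plus-bitcast path (signed field), while A[i].y is a plain bit slice.
import heterocl as hcl

hcl.init()
stype = hcl.Struct({"x": hcl.Int(8), "y": hcl.UInt(8)})
A = hcl.placeholder((10,), "A", dtype=stype)

def split(A):
    xs = hcl.compute(A.shape, lambda i: A[i].x, "xs", dtype=hcl.Int(8))
    ys = hcl.compute(A.shape, lambda i: A[i].y, "ys", dtype=hcl.UInt(8))
    return xs, ys

s = hcl.create_schedule([A], split)
f = hcl.build(s)
```

Field assignment (`A[i].x = expr`) follows the same layout through `__setattr__`: the right-hand side is cast to the field type, bitcast to uint when needed, and written back into only that field's bit slice.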
28 changes: 27 additions & 1 deletion python/heterocl/types.py
@@ -1,6 +1,7 @@
"""Define HeteroCL data types"""
#pylint: disable=too-few-public-methods, too-many-return-statements
import numbers
from collections import OrderedDict
from .debug import DTypeError

class Type(object):
@@ -48,6 +49,30 @@ class UFixed(Type):
def __repr__(self):
return "UFixed(" + str(self.bits) + ", " + str(self.fracs) + ")"

class Struct(Type):
"""A C-like struct

The struct members are defined with a Python dictionary
"""
def __init__(self, dtype_dict):
self.dtype_dict = OrderedDict(dtype_dict)
self.bits = 0
for dtype in dtype_dict.values():
self.bits += dtype.bits
Type.__init__(self, self.bits, 0)

def __repr__(self):
return "Struct(" + str(self.dtype_dict) + ")"

def __getattr__(self, key):
try:
return self.dtype_dict[key]
except KeyError:
raise DTypeError(key + " is not in struct")

def __getitem__(self, key):
return self.__getattr__(key)

def dtype_to_str(dtype):
"""Convert a data type to string format.

@@ -66,7 +91,8 @@ def dtype_to_str(dtype):
if isinstance(dtype, Type):
if isinstance(dtype, Int):
return "int" + str(dtype.bits)
elif isinstance(dtype, UInt):
# struct is treated as uint
elif isinstance(dtype, (UInt, Struct)):
return "uint" + str(dtype.bits)
elif isinstance(dtype, Fixed):
bits = dtype.bits
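A minimal sketch of how the new `Struct` type composes member widths and lowers to a plain unsigned integer, using only the classes defined in this file:

```python
# The field order is preserved by the OrderedDict, so the bit layout is
# deterministic: "x" occupies bits [0, 8), "y" bits [8, 16), "z" bits [16, 29).
from collections import OrderedDict
from heterocl.types import Struct, Int, UInt, Fixed, dtype_to_str

stype = Struct(OrderedDict([("x", Int(8)), ("y", UInt(8)), ("z", Fixed(13, 11))]))

assert stype.bits == 8 + 8 + 13          # total width is the sum of the members
assert stype.x.bits == 8                 # members are reachable as attributes ...
assert stype["y"].bits == 8              # ... or with dictionary-style indexing
assert dtype_to_str(stype) == "uint29"   # a struct lowers to a uint of equal width
```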
5 changes: 4 additions & 1 deletion python/heterocl/util.py
@@ -75,7 +75,10 @@ def get_dtype(dtype, name=None):
dtype_ = Scheme.current.dtype_dict.get(name)
dtype = dtype if dtype_ is None else dtype_
dtype = config.init_dtype if dtype is None else dtype
return types.dtype_to_str(dtype)
return dtype

def get_tvm_dtype(dtype, name=None):
return types.dtype_to_str(get_dtype(dtype, name))

def true():
return _make.UIntImm("uint1", 1)
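The split keeps `get_dtype` returning the resolved HeteroCL type object (so field information survives) while `get_tvm_dtype` lowers it to the TVM dtype string. A rough illustration, assuming the helpers can be called with an explicit dtype outside any scheme context:

```python
# Hypothetical sketch (not part of this diff): the same dtype argument resolves
# to an hcl type object on one path and to a lowered TVM string on the other.
from heterocl import types
from heterocl.util import get_dtype, get_tvm_dtype

stype = types.Struct({"x": types.Int(8), "y": types.UInt(8)})

hcl_dtype = get_dtype(stype)      # the Struct object itself, fields preserved
tvm_dtype = get_tvm_dtype(stype)  # "uint16", what TVM IR and buffers will see

assert isinstance(hcl_dtype, types.Struct)
assert tvm_dtype == "uint16"
```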