diff --git a/hlib/python/hlib/op/nn.py b/hlib/python/hlib/op/nn.py index da5393196..25c71f226 100644 --- a/hlib/python/hlib/op/nn.py +++ b/hlib/python/hlib/op/nn.py @@ -173,6 +173,7 @@ def conv2d( d = [] for i in range(len(padding)): p.append(tvm_to_primitive(padding[i])) + for i in range(len(strides)): s.append(tvm_to_primitive(strides[i])) d.append(tvm_to_primitive(dilation[i])) strides = s diff --git a/python/heterocl/api.py b/python/heterocl/api.py index f3e2151c8..b0c9351b5 100644 --- a/python/heterocl/api.py +++ b/python/heterocl/api.py @@ -90,11 +90,12 @@ def placeholder(shape, name=None, dtype=None): """ name = util.get_name("placeholder", name) dtype = util.get_dtype(dtype) - + tvm_dtype = types.dtype_to_str(dtype) + if shape == (): - return Scalar(tvm_api._Var(name, dtype)) + return Scalar(tvm_api._Var(name, tvm_dtype)) tensor = Tensor(shape, dtype, name) - tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, dtype, name) + tensor.tensor = tvm_api._Placeholder(tensor.buf.shape, tvm_dtype, name) # placeholder is also a stage stage = Stage(name) diff --git a/python/heterocl/compute_api.py b/python/heterocl/compute_api.py index 304642d39..05539e26f 100644 --- a/python/heterocl/compute_api.py +++ b/python/heterocl/compute_api.py @@ -5,8 +5,9 @@ from collections import OrderedDict from .tvm import expr as _expr, stmt as _stmt, make as _make from .tvm.api import _IterVar, min_value -from .util import get_index, get_name, get_type, get_dtype, make_for, CastRemover +from .util import get_index, get_name, get_type, get_tvm_dtype, make_for, CastRemover from .tensor import Scalar, Tensor, TensorSlice +from .types import Struct, dtype_to_str from .schedule import Stage from .debug import APIError from .dsl import if_, for_ @@ -117,7 +118,7 @@ def compute_body(name, if not return_tensor: stage.input_stages.add(tensor.last_update) else: - tensor = Tensor(shape, stage._dtype, name, stage._buf) + tensor = Tensor(shape, stage._hcl_dtype, name, stage._buf) buffer_var = tensor._buf.data dtype = tensor.dtype shape = tensor.shape @@ -137,6 +138,28 @@ def compute_body(name, stmt = stage.pop_stmt() stmt = ReplaceReturn(buffer_var, dtype, index).mutate(stmt) stmt = make_for(indices, stmt, 0) + elif isinstance(ret, (tuple, list)): + indices = lambda_ivs + index, _, _ = get_index(shape, indices, 0) + hcl_dtype = tensor.hcl_dtype + if not isinstance(hcl_dtype, Struct): + raise TensorError("Cannot assign a tuple/list to a non-struct-type tensor") + start = 0 + end = 0 + for sdtype, expr in zip(hcl_dtype.dtype_dict.values(), ret): + end = start + sdtype.bits + sdtype = dtype_to_str(sdtype) + load = _make.Load(dtype, buffer_var, index) + expr = _make.Cast(sdtype, expr) + if get_type(sdtype) != "uint": + ty = "uint" + str(get_type(sdtype)[1]) + expr = _make.Call(ty, "bitcast", [expr], _expr.Call.PureIntrinsic, None, 0) + expr = _make.SetSlice(load, expr, end, start) + stage.emit(_make.Store(buffer_var, + _make.Cast(dtype, expr), + index)) + start = end + stmt = make_for(indices, stage.pop_stmt(), 0) elif isinstance(ret, (TensorSlice, Scalar, _expr.Expr, numbers.Number)): indices = lambda_ivs index, _, _ = get_index(shape, indices, 0) @@ -539,7 +562,7 @@ def unpack_A(A): # to do so, we will need the name name_ = name if Stage.get_len() == 0 \ else Stage.get_current().name_with_prefix + "." + name - dtype = get_dtype(dtype, name_) + dtype = get_tvm_dtype(dtype, name_) ret = get_type(dtype) factor = tensor.type.bits // ret[1] bitwidth = ret[1] @@ -612,7 +635,7 @@ def pack(tensor, axis=0, factor=None, name=None, dtype=None): # to do so, we will need the name name_ = name if Stage.get_len() == 0 \ else Stage.get_current().name_with_prefix + "." + name - dtype = get_dtype(dtype, name_) + dtype = get_tvm_dtype(dtype, name_) ret = get_type(dtype) factor = ret[1] // tensor.type.bits bitwidth = tensor.type.bits diff --git a/python/heterocl/dsl.py b/python/heterocl/dsl.py index b226cb0ab..2916dee97 100644 --- a/python/heterocl/dsl.py +++ b/python/heterocl/dsl.py @@ -399,19 +399,19 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n if dtypes is None: dtypes = [] for name_ in new_names: - dtypes.append(util.get_dtype(None, name_)) + dtypes.append(util.get_tvm_dtype(None, name_)) elif isinstance(dtypes, list): if len(dtypes) != nargs: raise APIError("The number of data types does not match the of arguments") for (name_, dtype_) in zip(new_names, dtypes): - dtypes.append(util.get_dtype(dtype_, name_)) + dtypes.append(util.get_tvm_dtype(dtype_, name_)) dtypes = dtypes[int(len(dtypes)/2):] else: - dtype = util.get_dtype(dtypes) + dtype = util.get_tvm_dtype(dtypes) dtypes = [] for name_ in new_names: - dtypes.append(util.get_dtype(dtype, name_)) - ret_dtype = util.get_dtype(ret_dtype, s.name_with_prefix) + dtypes.append(util.get_tvm_dtype(dtype, name_)) + ret_dtype = util.get_tvm_dtype(ret_dtype, s.name_with_prefix) # prepare inputs for IR generation inputs = [] inputs_tvm = [] @@ -441,7 +441,7 @@ def decorator(fmodule, shapes=shapes, dtypes=dtypes, ret_dtype=ret_dtype, name=n ret_void = _make.UIntImm("uint1", 0) if s.has_return else _make.UIntImm("uint1", 1) body = s.pop_stmt() s.stmt_stack.append([]) - s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes, + s.emit(_make.KernelDef(inputs_tvm, arg_shapes, arg_dtypes, body, ret_void, ret_dtype, name, [])) for name_, i in zip(names, inputs): s.var_dict[name_] = i @@ -499,6 +499,6 @@ def compute_out(A, x): if not Stage.get_len(): raise DSLError("Imperative DSL must be used with other compute APIs") stage = Stage.get_current() - dtype = util.get_dtype(stage.ret_dtype) + dtype = util.get_tvm_dtype(stage.ret_dtype) stage.emit(_make.Return(_make.Cast(dtype, val))) stage.has_return = True diff --git a/python/heterocl/nparray.py b/python/heterocl/nparray.py index be7f57dbf..36a4fb696 100644 --- a/python/heterocl/nparray.py +++ b/python/heterocl/nparray.py @@ -2,7 +2,7 @@ #pylint: disable=missing-docstring import numpy as np from .tvm.ndarray import array, cpu -from .util import get_dtype +from .util import get_tvm_dtype from . import types def cast_np(np_in, dtype): @@ -79,7 +79,7 @@ def asarray(arr, dtype=None, ctx=cpu(0)): np_A = numpy.zeros(10) hcl_A = np_A.asarray() """ - dtype = get_dtype(dtype) + dtype = get_tvm_dtype(dtype) return array(arr, dtype, ctx) def pack_np(np_in, dtype_in, dtype_out): diff --git a/python/heterocl/schedule.py b/python/heterocl/schedule.py index 2cf6eeb6e..da98f0a27 100644 --- a/python/heterocl/schedule.py +++ b/python/heterocl/schedule.py @@ -324,7 +324,8 @@ def __init__(self, name=None, dtype=None, shape=()): else Stage.get_current().name_with_prefix + "." + self.name # Private attributes for building a stage self._op = None - self._dtype = util.get_dtype(dtype, self.name_with_prefix) + self._hcl_dtype = util.get_dtype(dtype, self.name_with_prefix) + self._dtype = util.get_tvm_dtype(dtype, self.name_with_prefix) self._buf = tvm_api.decl_buffer(shape, self._dtype, self.name) self._shape = self._buf.shape diff --git a/python/heterocl/tensor.py b/python/heterocl/tensor.py index f95d48811..10387a044 100644 --- a/python/heterocl/tensor.py +++ b/python/heterocl/tensor.py @@ -113,11 +113,12 @@ class TensorSlice(NodeGeneric, _expr.ExprOp): # not allowed: A[5:7] """ - def __init__(self, tensor, indices): + def __init__(self, tensor, indices, dtype=None): if not isinstance(indices, tuple): indices = (indices,) self.tensor = tensor self.indices = indices + self._dtype = dtype if dtype is not None else self.tensor.dtype def __getitem__(self, indices): if not isinstance(indices, tuple): @@ -134,10 +135,17 @@ def __setitem__(self, indices, expr): builder = Stage.get_current() if bit is None: builder.emit(_make.Store(self.tensor.buf.data, - _make.Cast(self.tensor.dtype, expr), + _make.Cast(self._dtype, expr), index)) elif isinstance(bit, slice): load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index) + # special handle for struct: we need to make sure the bitwidths + # are the same before and after bitcast + if (isinstance(self.tensor.type, types.Struct) + and util.get_type(self._dtype) != "uint"): + ty = "uint" + str(util.get_type(self._dtype)[1]) + expr = _make.Call(ty, "bitcast", + [expr], _expr.Call.PureIntrinsic, None, 0) expr = _make.SetSlice(load, expr, bit.start, bit.stop) builder.emit(_make.Store(self.tensor.buf.data, _make.Cast(self.tensor.dtype, expr), @@ -146,9 +154,53 @@ def __setitem__(self, indices, expr): load = _make.Load(self.tensor.dtype, self.tensor.buf.data, index) expr = _make.SetBit(load, expr, bit) builder.emit(_make.Store(self.tensor.buf.data, - _make.Cast(self.tensor.dtype, expr), + _make.Cast(self._dtype, expr), index)) + def __getattr__(self, key): + hcl_dtype = self.tensor.hcl_dtype + if not isinstance(hcl_dtype, types.Struct): + raise TensorError( + "Cannot access attribute if type is not struct") + start = 0 + end = 0 + dtype = None + for dkey, dval in hcl_dtype.dtype_dict.items(): + if dkey == key: + end = start + dval.bits + dtype = types.dtype_to_str(dval) + break + else: + start += dval.bits + if dtype is None: + raise DTypeError("Field " + key + + " is not in struct " + str(hcl_dtype)) + indices = (slice(end, start),) + return TensorSlice(self.tensor, self.indices + indices, dtype) + + def __setattr__(self, key, expr): + if key in ("tensor", "indices", "_dtype"): + super().__setattr__(key, expr) + else: + hcl_dtype = self.tensor.hcl_dtype + if not isinstance(hcl_dtype, types.Struct): + raise TensorError( + "Cannot access attribute if type is not struct") + start = 0 + end = 0 + for dkey, dval in hcl_dtype.dtype_dict.items(): + if dkey == key: + end = start + dval.bits + self._dtype = types.dtype_to_str(dval) + break + else: + start += dval.bits + if start == end: + raise DTypeError("Field " + key + + " is not in struct " + str(hcl_dtype)) + indices = (slice(end, start),) + self.__setitem__(indices, expr) + @property def dtype(self): return self.tensor.dtype @@ -158,12 +210,25 @@ def asnode(self): raise TensorError("Accessing a slice of tensor is not allowed") index, bit, _ = util.get_index(self.tensor.shape, self.indices, 0) if bit is None: - return _make.Load(self.tensor.dtype, self.tensor.buf.data, index) + return _make.Load(self._dtype, self.tensor.buf.data, index) elif isinstance(bit, slice): - return _make.GetSlice(_make.Load(self.tensor.dtype, self.tensor.buf.data, index), + load = _make.GetSlice(_make.Load(self.tensor.dtype, + self.tensor.buf.data, index), bit.start, bit.stop) - return _make.GetBit(_make.Load(self.tensor.dtype, self.tensor.buf.data, index), bit) + if self.tensor.dtype != self._dtype: + bw_from = types.get_bitwidth(self.tensor.dtype) + bw_to = types.get_bitwidth(self._dtype) + if bw_from != bw_to: + ty = util.get_type(self.tensor.dtype)[0] + str(bw_to) + load = _make.Cast(ty, load) + return _make.Call(self._dtype, "bitcast", + [load], _expr.Call.PureIntrinsic, None, 0) + else: + return load + return _make.GetBit(_make.Load(self._dtype, + self.tensor.buf.data, + index), bit) class Tensor(NodeGeneric, _expr.ExprOp): """A HeteroCL tensor. @@ -230,14 +295,15 @@ class Tensor(NodeGeneric, _expr.ExprOp): def __init__(self, shape, dtype="int32", name="tensor", buf=None): self._tensor = None self._buf = buf - self.dtype = dtype + self.hcl_dtype = dtype + self.dtype = types.dtype_to_str(dtype) self.shape = shape self.name = name self.var_dict = {} self.first_update = None self.last_update = None if buf is None: - self._buf = decl_buffer(shape, dtype, name) + self._buf = decl_buffer(shape, self.dtype, name) def __repr__(self): return "Tensor('" + self.name + "', " + str(self.shape) + ", " + str(self.dtype) + ")" @@ -291,7 +357,7 @@ def buf(self): @property def type(self): - return types.dtype_to_hcl(self.dtype) + return self.hcl_dtype @property def op(self): diff --git a/python/heterocl/types.py b/python/heterocl/types.py index 5f22a87c6..44e3fd9a8 100644 --- a/python/heterocl/types.py +++ b/python/heterocl/types.py @@ -1,6 +1,7 @@ """Define HeteroCL data types""" #pylint: disable=too-few-public-methods, too-many-return-statements import numbers +from collections import OrderedDict from .debug import DTypeError class Type(object): @@ -48,6 +49,30 @@ class UFixed(Type): def __repr__(self): return "UFixed(" + str(self.bits) + ", " + str(self.fracs) + ")" +class Struct(Type): + """A C-like struct + + The struct members are defined with a Python dictionary + """ + def __init__(self, dtype_dict): + self.dtype_dict = OrderedDict(dtype_dict) + self.bits = 0 + for dtype in dtype_dict.values(): + self.bits += dtype.bits + Type.__init__(self, self.bits, 0) + + def __repr__(self): + return "Struct(" + str(self.dtype_dict) + ")" + + def __getattr__(self, key): + try: + return self.dtype_dict[key] + except KeyError: + raise DTypeError(key + " is not in struct") + + def __getitem__(self, key): + return self.__getattr__(key) + def dtype_to_str(dtype): """Convert a data type to string format. @@ -66,7 +91,8 @@ def dtype_to_str(dtype): if isinstance(dtype, Type): if isinstance(dtype, Int): return "int" + str(dtype.bits) - elif isinstance(dtype, UInt): + # struct is treated as uint + elif isinstance(dtype, (UInt, Struct)): return "uint" + str(dtype.bits) elif isinstance(dtype, Fixed): bits = dtype.bits diff --git a/python/heterocl/util.py b/python/heterocl/util.py index 704b774cb..5d9e9b5a8 100644 --- a/python/heterocl/util.py +++ b/python/heterocl/util.py @@ -75,7 +75,10 @@ def get_dtype(dtype, name=None): dtype_ = Scheme.current.dtype_dict.get(name) dtype = dtype if dtype_ is None else dtype_ dtype = config.init_dtype if dtype is None else dtype - return types.dtype_to_str(dtype) + return dtype + +def get_tvm_dtype(dtype, name=None): + return types.dtype_to_str(get_dtype(dtype, name)) def true(): return _make.UIntImm("uint1", 1) diff --git a/tests/test_dtype.py b/tests/test_dtype.py index 1d9d00279..7276e8a3d 100644 --- a/tests/test_dtype.py +++ b/tests/test_dtype.py @@ -205,3 +205,90 @@ def kernel(A): f(hcl_A, hcl_C) assert np.array_equal(np_A, hcl_C.asnumpy()) + +def test_dtype_struct(): + hcl.init() + A = hcl.placeholder((100,), dtype=hcl.Int(8)) + B = hcl.placeholder((100,), dtype=hcl.Fixed(13, 11)) + C = hcl.placeholder((100,), dtype=hcl.Float()) + + def kernel(A, B, C): + stype = hcl.Struct({"fa": hcl.Int(8), "fb": hcl.Fixed(13, 11), "fc": hcl.Float()}) + D = hcl.compute(A.shape, lambda x: (A[x], B[x], C[x]), dtype=stype) + E = hcl.compute(A.shape, lambda x: D[x].fa, dtype=hcl.Int(8)) + F = hcl.compute(A.shape, lambda x: D[x].fb, dtype=hcl.Fixed(13, 11)) + G = hcl.compute(A.shape, lambda x: D[x].fc, dtype=hcl.Float()) + return E, F, G + + s = hcl.create_schedule([A, B, C], kernel) + f = hcl.build(s) + np_A = np.random.randint(0, 500, size=100) - 250 + np_B = np.random.rand(100) - 0.5 + np_C = np.random.rand(100) - 0.5 + np_E = np.zeros(100) + np_F = np.zeros(100) + np_G = np.zeros(100) + hcl_A = hcl.asarray(np_A, dtype=hcl.Int(8)) + hcl_B = hcl.asarray(np_B, dtype=hcl.Fixed(13, 11)) + hcl_C = hcl.asarray(np_C, dtype=hcl.Float()) + hcl_E = hcl.asarray(np_E, dtype=hcl.Int(8)) + hcl_F = hcl.asarray(np_F, dtype=hcl.Fixed(13, 11)) + hcl_G = hcl.asarray(np_G, dtype=hcl.Float()) + f(hcl_A, hcl_B, hcl_C, hcl_E, hcl_F, hcl_G) + + assert np.allclose(hcl_A.asnumpy(), hcl_E.asnumpy()) + assert np.allclose(hcl_B.asnumpy(), hcl_F.asnumpy()) + assert np.allclose(hcl_C.asnumpy(), hcl_G.asnumpy()) + +def test_dtye_strcut_complex(): + hcl.init() + A = hcl.placeholder((100,)) + B = hcl.placeholder((100,)) + C = hcl.placeholder((100,)) + O = hcl.placeholder((100, 6)) + + def kernel(A, B, C, O): + dtype_xyz = hcl.Struct({"x": hcl.Int(), "y": hcl.Int(), "z": hcl.Int()}) + dtype_out = hcl.Struct({"v0": hcl.Int(), + "v1": hcl.Int(), + "v2": hcl.Int(), + "v3": hcl.Int(), + "v4": hcl.Int(), + "v5": hcl.Int()}) + + D = hcl.compute(A.shape, lambda x: (A[x], B[x], C[x]), dtype=dtype_xyz) + E = hcl.compute(A.shape, lambda x: (D[x].x * D[x].x, + D[x].y * D[x].y, + D[x].z * D[x].z, + D[x].x * D[x].y, + D[x].y * D[x].z, + D[x].x * D[x].z), dtype=dtype_out) + with hcl.Stage(): + with hcl.for_(0, 100) as i: + for j in range(0, 6): + O[i][j] = E[i].__getattr__("v" + str(j)) + + s = hcl.create_schedule([A, B, C, O], kernel) + f = hcl.build(s) + + np_A = np.random.randint(10, size=100) + np_B = np.random.randint(10, size=100) + np_C = np.random.randint(10, size=100) + np_O = np.zeros((100, 6)) + + np_G = np.zeros((100, 6)).astype("int") + for i in range(0, 100): + np_G[i][0] = np_A[i] * np_A[i] + np_G[i][1] = np_B[i] * np_B[i] + np_G[i][2] = np_C[i] * np_C[i] + np_G[i][3] = np_A[i] * np_B[i] + np_G[i][4] = np_B[i] * np_C[i] + np_G[i][5] = np_A[i] * np_C[i] + + hcl_A = hcl.asarray(np_A) + hcl_B = hcl.asarray(np_B) + hcl_C = hcl.asarray(np_C) + hcl_O = hcl.asarray(np_O) + f(hcl_A, hcl_B, hcl_C, hcl_O) + + assert np.array_equal(hcl_O.asnumpy(), np_G) diff --git a/tvm/HalideIR/src/ir/IR.cpp b/tvm/HalideIR/src/ir/IR.cpp index a604b6fd2..33ae584ce 100644 --- a/tvm/HalideIR/src/ir/IR.cpp +++ b/tvm/HalideIR/src/ir/IR.cpp @@ -963,6 +963,7 @@ Call::ConstString Call::shift_left = "shift_left"; Call::ConstString Call::shift_right = "shift_right"; Call::ConstString Call::abs = "abs"; Call::ConstString Call::absd = "absd"; +Call::ConstString Call::bitcast = "bitcast"; Call::ConstString Call::lerp = "lerp"; Call::ConstString Call::random = "random"; Call::ConstString Call::popcount = "popcount"; diff --git a/tvm/HalideIR/src/ir/IR.h b/tvm/HalideIR/src/ir/IR.h index e8a8835bf..232a8364d 100644 --- a/tvm/HalideIR/src/ir/IR.h +++ b/tvm/HalideIR/src/ir/IR.h @@ -707,6 +707,7 @@ struct Call : public ExprNode { shift_right, abs, absd, + bitcast, rewrite_buffer, random, lerp, diff --git a/tvm/src/codegen/codegen_c.cc b/tvm/src/codegen/codegen_c.cc index 006edf933..d28fac208 100644 --- a/tvm/src/codegen/codegen_c.cc +++ b/tvm/src/codegen/codegen_c.cc @@ -663,6 +663,38 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) PrintBinaryIntrinsitc(op, " << ", os, this); } else if (op->is_intrinsic(Call::shift_right)) { PrintBinaryIntrinsitc(op, " >> ", os, this); + } else if (op->is_intrinsic(Call::bitcast)) { + this->PrintIndent(); + std::string conv_name = GetUniqueName("_converter"); + int bits = op->args[0].type().bits(); + if (op->args[0].type().code() == Type::Float || + op->type.code() == Type::Float) { + CHECK(bits == 32 || bits == 64); + std::string ty_from = bits == 32 ? "float" : "double"; + std::string ty_to = bits == 32 ? "uint32_t" : "uint64_t"; + bool from_float = op->args[0].type().code() == Type::Float; + stream << "union { "; + if (from_float) stream << ty_from; + else stream << ty_to; + stream << " from; "; + if (from_float) stream << ty_to; + else stream << ty_from; + stream << " to;} " << conv_name << ";\n"; + this->PrintIndent(); + stream << conv_name << ".from = "; + this->PrintExpr(op->args[0], stream); + stream << ";\n"; + os << conv_name << ".to"; + } else { + this->PrintType(op->type, stream); + stream << " " << conv_name << ";\n"; + this->PrintIndent(); + stream << conv_name << "(" << bits-1 << ", 0) = "; + this->PrintExpr(op->args[0], stream); + stream << "(" << bits-1 << ", 0)"; + stream << ";\n"; + os << conv_name; + } } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { os << "("; PrintExpr(op->args[0], os); diff --git a/tvm/src/codegen/hlsc/codegen_vhls.cc b/tvm/src/codegen/hlsc/codegen_vhls.cc index f944bef83..ac4d9e899 100644 --- a/tvm/src/codegen/hlsc/codegen_vhls.cc +++ b/tvm/src/codegen/hlsc/codegen_vhls.cc @@ -98,6 +98,7 @@ void CodeGenVivadoHLS::AddFunction(LoweredFunc f, this->decl_stream << "#include \n"; this->decl_stream << "#include \n"; this->decl_stream << "#include \n\n"; + this->decl_stream << "#include \n\n"; CodeGenHLSC::AddFunction(f, map_arg_type); if (soda_header_.is_open()) soda_header_.close(); @@ -137,10 +138,11 @@ void CodeGenVivadoHLS::VisitStmt_(const Store* op) { Type t = op->value.type(); Expr new_index_left = ir::Simplify(ss->index_left - 1); std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string rhs = PrintExpr(ss->value); PrintIndent(); this->stream << ref << "(" << PrintExpr(new_index_left) << ", " << PrintExpr(ss->index_right) - << ") = " << PrintExpr(ss->value) << ";\n"; + << ") = " << rhs << ";\n"; } else if (const SetBit* sb = op->value.as()) { Type t = op->value.type(); std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); diff --git a/tvm/src/codegen/llvm/codegen_llvm.cc b/tvm/src/codegen/llvm/codegen_llvm.cc index a5c38154e..6c8d257e7 100644 --- a/tvm/src/codegen/llvm/codegen_llvm.cc +++ b/tvm/src/codegen/llvm/codegen_llvm.cc @@ -690,6 +690,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else { return builder_->CreateLShr(a, b_new); } + } else if (op->is_intrinsic(Call::bitcast)) { + llvm::Value* v = MakeValue(op->args[0]); + Type tv = op->args[0].type(); + Type to = op->type; + CHECK(tv.bits() == to.bits()); + return builder_->CreateBitCast(v, LLVMType(to)); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {