diff --git a/HalideIR b/HalideIR index e68ae61cd541..1a11a6c2522b 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9 +Subproject commit 1a11a6c2522b1d11a5ccdb9b4fe3976cbe7f9f27 diff --git a/Makefile b/Makefile index 97c0e1ed3d86..e7d8fd796383 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ ifndef config ifneq ("$(wildcard ./config.mk)","") - config = config.mk + config ?= config.mk else - config = make/config.mk + config ?= make/config.mk endif endif @@ -19,24 +19,16 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR) -ifneq ($(USE_CUDA_PATH), NONE) - NVCC=$(USE_CUDA_PATH)/bin/nvcc -endif - export LDFLAGS = -pthread -lm -export CFLAGS = -std=c++11 -Wall -O2\ - -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -export FRAMEWORKS= - -ifneq ($(ADD_CFLAGS), NONE) - CFLAGS += $(ADD_CFLAGS) -endif +export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\ + -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0 -ifneq ($(ADD_LDFLAGS), NONE) - LDFLAGS += $(ADD_LDFLAGS) +ifdef CUDA_PATH + NVCC=$(CUDA_PATH)/bin/nvcc + CFLAGS += -I$(CUDA_PATH)/include + LDFLAGS += -L$(CUDA_PATH)/lib64 endif - ifeq ($(USE_CUDA), 1) CFLAGS += -DTVM_CUDA_RUNTIME=1 LDFLAGS += -lcuda -lcudart -lnvrtc @@ -44,6 +36,7 @@ else CFLAGS += -DTVM_CUDA_RUNTIME=0 endif +FRAMEWORKS= ifeq ($(USE_OPENCL), 1) CFLAGS += -DTVM_OPENCL_RUNTIME=1 @@ -57,6 +50,24 @@ else CFLAGS += -DTVM_OPENCL_RUNTIME=0 endif +# llvm configuration +LLVM_CONFIG=llvm-config + +ifeq ($(USE_LLVM), 1) + LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3) + LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags)) + LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs) + CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION) +endif + +# ifdef takes a variable name, not a $(...) reference +ifdef ADD_CFLAGS + CFLAGS += $(ADD_CFLAGS) +endif + +ifdef ADD_LDFLAGS + LDFLAGS += $(ADD_LDFLAGS) +endif 
include tests/cpp/unittest.mk diff --git a/dmlc-core b/dmlc-core index 3a51614d39b6..8dd365636528 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6 +Subproject commit 8dd365636528175e785448cf8a9f4e494c8ee0e0 diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 7f8f5c0cc131..141e4b68fe6c 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -90,7 +90,7 @@ class BufferNode : public Node { Type dtype); static constexpr const char* _type_key = "Buffer"; - TVM_DECLARE_NODE_TYPE_INFO(BufferNode); + TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); }; inline const BufferNode* Buffer::operator->() const { diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index c1796be8e0ca..6865586a501e 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -31,6 +31,13 @@ PackedFunc BuildStackVM( LoweredFunc func, const std::unordered_map& device_funcs); +/*! + * \brief Build an LLVM VM function; this is still beta + * \param func The LoweredFunc to be built + * \return A packed function representing the func. + */ +PackedFunc BuildLLVM(LoweredFunc func); + /*! 
* \brief Build a CUDA function with NVRTC * diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 2c01d7acadbf..b7a6a458876f 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -36,7 +36,6 @@ using Halide::Internal::make_zero; using Halide::Internal::as_const_int; using Halide::Internal::as_const_uint; - inline Type TVMType2Type(TVMType t) { return Type(static_cast(t.code), t.bits, t.lanes); } @@ -182,7 +181,7 @@ class IterVarNode : public Node { static IterVar make(Range dom, Var var, std::string thread_tag); static constexpr const char* _type_key = "IterVar"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarNode); + TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); }; // inline implementations diff --git a/include/tvm/ir.h b/include/tvm/ir.h index e6aa692af379..29b70e654832 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -200,6 +200,8 @@ using Halide::Internal::Realize; using Halide::Internal::Block; using Halide::Internal::IfThenElse; using Halide::Internal::Evaluate; +// ir functions +using Halide::Internal::is_const_power_of_two_integer; } // namespace ir } // namespace tvm diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index d9c685d77957..2639db01b4b2 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -92,7 +92,7 @@ class LoweredFuncNode : public FunctionBaseNode { } static constexpr const char* _type_key = "LoweredFunc"; - TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode); + TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node); }; // Implementations of inline functions diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 85b289f5d220..662e098d8299 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -39,7 +39,7 @@ class PlaceholderOpNode : public OperationNode { Type dtype); static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode); + TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); }; /*! 
@@ -74,7 +74,7 @@ class ComputeOpNode : public OperationNode { Expr body); static constexpr const char* _type_key = "ComputeOp"; - TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode); + TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode); }; /*! @@ -123,7 +123,7 @@ class ScanOpNode : public OperationNode { Array state_placeholder); static constexpr const char* _type_key = "ScanOp"; - TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode); + TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); }; diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 94cfff26f1d0..e0970b2e4f19 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -33,7 +33,7 @@ struct NodeTypeChecker { // It can be turned off, but will make non strict checking. // TODO(tqchen) possibly find alternative to turn of RTTI using ContainerType = typename T::ContainerType; - return (dynamic_cast(sptr) != nullptr); + return sptr->derived_from(); } static inline void PrintName(std::ostringstream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 92eda1f9021c..14b3974420a2 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -152,6 +152,13 @@ typedef void* TVMRetValueHandle; /*! \brief the array handle */ typedef TVMArray* TVMArrayHandle; +/*! + * \brief Used for implementing C API function. + * Set last error message before return. + * \param msg The error message to be set. + */ +TVM_DLL void TVMAPISetLastError(const char* msg); + /*! * \brief return str message of the last error * all function in this file will return 0 when success @@ -287,10 +294,10 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, * \param num_args Number of arguments. * \param ret The return value handle. * \param resource_handle The handle additional resouce handle from fron-end. 
- * + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. * \sa TVMCFuncSetReturn */ -typedef void (*TVMPackedCFunc)( +typedef int (*TVMPackedCFunc)( TVMValue* args, int* type_codes, int num_args, diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index c6bbc65660c4..ce0cd3420d69 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -331,7 +331,7 @@ class StageNode : public Node { } static constexpr const char* _type_key = "Stage"; - TVM_DECLARE_NODE_TYPE_INFO(StageNode); + TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node); }; /*! \brief node container for schedule */ @@ -354,7 +354,7 @@ class ScheduleNode : public Node { } static constexpr const char* _type_key = "Schedule"; - TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode); + TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); }; /*! \brief node container for IterVar attr */ @@ -368,11 +368,14 @@ class IterVarAttrNode : public Node { } static constexpr const char* _type_key = "IterVarAttr"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode); + TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node); }; /*! \brief base node of iteration var */ class IterVarRelationNode : public Node { + public: + static constexpr const char* _type_key = "IterVarRelation"; + TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node); }; /*! @@ -402,7 +405,7 @@ class SplitNode : public IterVarRelationNode { IterVar inner, Expr factor); static constexpr const char* _type_key = "Split"; - TVM_DECLARE_NODE_TYPE_INFO(SplitNode); + TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode); }; /*! @@ -427,7 +430,7 @@ class FuseNode : public IterVarRelationNode { IterVar outer, IterVar inner, IterVar fused); static constexpr const char* _type_key = "Fuse"; - TVM_DECLARE_NODE_TYPE_INFO(FuseNode); + TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode); }; /*! 
@@ -450,7 +453,7 @@ class RebaseNode : public IterVarRelationNode { static IterVarRelation make(IterVar parent, IterVar rebased); static constexpr const char* _type_key = "Rebase"; - TVM_DECLARE_NODE_TYPE_INFO(RebaseNode); + TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode); }; diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 11766cd005d5..082ce7f0fdd2 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -153,7 +153,7 @@ class TensorNode : public Node { int value_index); static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_NODE_TYPE_INFO(TensorNode); + TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); }; /*! @@ -167,8 +167,6 @@ class OperationNode : public FunctionBaseNode { const std::string& func_name() const final { return name; } - /*! \return number of outputs of this op */ - virtual int num_outputs() const = 0; /*! \return the list of iteration variable at root */ virtual Array root_iter_vars() const = 0; /*! \return type of i-th output */ @@ -177,6 +175,8 @@ class OperationNode : public FunctionBaseNode { virtual Array output_shape(size_t i) const = 0; static constexpr const char* _type_key = "Operation"; + + TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node); }; // Implementations of inline functions diff --git a/make/config.mk b/make/config.mk index 26530827ea0e..f0680eb6b99c 100644 --- a/make/config.mk +++ b/make/config.mk @@ -40,10 +40,11 @@ USE_CUDA = 1 # whether use OpenCL during compile USE_OPENCL = 0 -# add the path to CUDA library to link and compile flag -# if you have already add them to environment variable, leave it as NONE -# USE_CUDA_PATH = /usr/local/cuda -USE_CUDA_PATH = NONE +# whether build with LLVM support +# This requires llvm-config to be in your PATH +# Requires LLVM version >= 4.0 +USE_LLVM = 0 -# whether use cuda runtime compiling for writing kernels in native language (i.e. 
Python) -USE_NVRTC = 0 +# add the path to the CUDA library to the link and compile flags; +# if you have already added them to your environment variables, leave CUDA_PATH unset. +# CUDA_PATH = /usr/local/cuda diff --git a/python/tvm/_ctypes/_function.py b/python/tvm/_ctypes/_function.py index 7e0927cac6a3..d32583366b47 100644 --- a/python/tvm/_ctypes/_function.py +++ b/python/tvm/_ctypes/_function.py @@ -56,6 +56,7 @@ def cfun(args, type_codes, num_args, ret, _): check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0]))) _ = temp_args _ = rv + return 0 handle = FunctionHandle() f = TVMPackedCFunc(cfun) diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py index 33d0e8fe0737..3a832d2b664a 100644 --- a/python/tvm/_ctypes/_types.py +++ b/python/tvm/_ctypes/_types.py @@ -96,7 +96,7 @@ class TVMByteArray(ctypes.Structure): TVMPackedCFunc = ctypes.CFUNCTYPE( - None, + ctypes.c_int, ctypes.POINTER(TVMValue), ctypes.POINTER(ctypes.c_int), ctypes.c_int, diff --git a/src/README.md b/src/README.md index 91cb47ece9ea..6224ee8f1595 100644 --- a/src/README.md +++ b/src/README.md @@ -6,3 +6,4 @@ - arithmetic Arithmetic expression and set simplification - pass The optimization pass on the IR structure - runtime Minimum runtime related codes. 
+- codegen The code generator diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 7161016f507f..f6018e6adeb1 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -37,6 +37,11 @@ TVM_REGISTER_API(_codegen_BuildStackVM) std::unordered_map()); }); +TVM_REGISTER_API(_codegen_BuildLLVM) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = BuildLLVM(args[0]); + }); + TVM_REGISTER_API(_codegen_BuildNVRTC) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = BuildNVRTC(args[0], args[1]); diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 0c34fd71fae5..2fc25f55be2d 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -20,7 +20,7 @@ enum SignType { }; // internal node container of int set. -class IntSetNode; +struct IntSetNode; /*! * \brief Integer set class, represent a set of integers in one dimension. @@ -104,6 +104,8 @@ class IntSet : public NodeRef { * \brief Base class of all IntSet containers. */ struct IntSetNode : public Node { + static constexpr const char* _type_key = "IntSet"; + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); }; using ExprIntSetMap = std::unordered_map strides; static constexpr const char* _type_key = "StrideSet"; - TVM_DECLARE_NODE_TYPE_INFO(StrideSet); + TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); }; } // namespace arith diff --git a/src/codegen/codegen_stack_vm.cc b/src/codegen/codegen_stack_vm.cc index d1fb0751a8ab..ad979f3b328e 100644 --- a/src/codegen/codegen_stack_vm.cc +++ b/src/codegen/codegen_stack_vm.cc @@ -272,9 +272,6 @@ inline void PushBinary(StackVM::OpCode op_int64, } } - - - inline void PushCast(Type dst, Type src, CodeGenStackVM* p) { @@ -496,7 +493,5 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) .set_dispatch([](const Call *op, CodeGenStackVM* p) { p->Push_(op); }); - - } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc new file mode 100644 index 000000000000..31b6b1c5d0d5 --- 
/dev/null +++ b/src/codegen/llvm/codegen_llvm.cc @@ -0,0 +1,915 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_llvm.cc + */ +#ifdef TVM_LLVM_VERSION + +#include +#include "./codegen_llvm.h" +#include "../../arithmetic/compute_expr.h" + +namespace tvm { +namespace codegen { + +void CodeGenLLVM::Init(const std::string& module_name, + llvm::LLVMContext* ctx) { + InitializeLLVM(); + static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); + static_assert(alignof(TVMValue) == alignof(double), "invariant"); + // clear maps + var_map_.clear(); + str_map_.clear(); + func_handle_map_.clear(); + // initialize types. + if (ctx_ != ctx) { + t_void_ = llvm::Type::getVoidTy(*ctx); + t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(); + t_int_ = llvm::Type::getIntNTy(*ctx, sizeof(int) * 8); + t_char_ = llvm::Type::getInt8Ty(*ctx); + t_int8_ = llvm::Type::getInt8Ty(*ctx); + t_int16_ = llvm::Type::getInt16Ty(*ctx); + t_int32_ = llvm::Type::getInt32Ty(*ctx); + t_float64_ = llvm::Type::getDoubleTy(*ctx); + t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8); + t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); + t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); + t_tvm_func_handle_ = t_void_p_; + t_tvm_array_ = llvm::StructType::create( + {t_void_p_, + t_tvm_index_->getPointerTo(), + t_tvm_index_->getPointerTo(), + t_tvm_index_, + t_tvm_type_, + t_tvm_context_}); + t_tvm_value_ = llvm::StructType::create({t_float64_}); + md_builder_.reset(new llvm::MDBuilder(*ctx)); + md_very_likely_branch_ = + md_builder_->createBranchWeights(1 << 30, 0); + md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa"); + } + ctx_ = ctx; + // initialize modules + module_.reset(new llvm::Module(module_name, *ctx)); + // initialize TVM runtime API + f_tvm_func_call_ = llvm::Function::Create( + llvm::FunctionType::get(t_int_, { + t_tvm_func_handle_, + t_tvm_value_->getPointerTo(), + t_int_->getPointerTo(), + t_int_, + 
t_tvm_value_->getPointerTo(), + t_int_->getPointerTo()}, false), + llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); + f_tvm_func_get_global_ = llvm::Function::Create( + llvm::FunctionType::get(t_int_, { + t_char_->getPointerTo(), + t_tvm_func_handle_->getPointerTo()}, false), + llvm::Function::ExternalLinkage, "TVMFuncGetGlobal", module_.get()); + f_tvm_api_set_last_error_ = llvm::Function::Create( + llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false), + llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); + // initialize builder + builder_.reset(new IRBuilder(*ctx)); +} + +void CodeGenLLVM::AddFunction(const LoweredFunc& f) { + var_map_.clear(); + CHECK(!module_->getFunction(f->name)) + << "Function " << f->name << "already exists in module"; + std::vector arg_type; + for (Var arg : f->args) { + Type t = arg.type(); + if (t.is_handle() && f->handle_data_type.count(arg)) { + arg_type.push_back( + LLVMType(f->handle_data_type[arg].type())->getPointerTo()); + } else { + arg_type.push_back(LLVMType(t)); + } + } + llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_type, false); + // setup the function. 
+ function_ = llvm::cast(module_->getOrInsertFunction(f->name, ftype)); + function_->setCallingConv(llvm::CallingConv::C); + size_t idx = 0; + + for (auto it = function_->arg_begin(); + it != function_->arg_end(); ++it, ++idx) { + llvm::Argument* v = &(*it); + var_map_[f->args[idx].get()] = v; + } + + llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_); + builder_->SetInsertPoint(block); + this->Visit(f->body); + builder_->CreateRet(ConstInt32(0)); +} + +class FPassManager : public llvm::legacy::FunctionPassManager { + public: + explicit FPassManager(llvm::Module* m) + : llvm::legacy::FunctionPassManager(m) {} + // override add to allow messaging + void add(llvm::Pass* p) final { + llvm::legacy::FunctionPassManager::add(p); + } +}; +class MPassManager : public llvm::legacy::PassManager { + public: + // override add to allow messaging + void add(llvm::Pass* p) final { + llvm::legacy::PassManager::add(p); + } +}; + + +void CodeGenLLVM::Optimize() { + // place optimization pass + llvm::PassManagerBuilder builder; + builder.OptLevel = 3; + builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0); + builder.LoopVectorize = true; + builder.SLPVectorize = true; + // pass manager + FPassManager fpass(module_.get()); + MPassManager mpass; + builder.populateFunctionPassManager(fpass); + builder.populateModulePassManager(mpass); + + fpass.doInitialization(); + for (auto it = module_->begin(); it != module_->end(); ++it) { + fpass.run(*it); + } + fpass.doFinalization(); + mpass.run(*module_); +} + +std::unique_ptr CodeGenLLVM::Finish() { + this->Optimize(); + var_map_.clear(); + str_map_.clear(); + func_handle_map_.clear(); + return std::move(module_); +} + +llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { + llvm::Type* ret = nullptr; + if (t.is_uint() || t.is_int()) { + ret = llvm::Type::getIntNTy(*ctx_, t.bits()); + } else if (t.is_float()) { + switch (t.bits()) { + case 16: ret = llvm::Type::getHalfTy(*ctx_); break; + case 
32: ret = llvm::Type::getFloatTy(*ctx_); break; + case 64: ret = llvm::Type::getDoubleTy(*ctx_); break; + default: LOG(FATAL) << "cannot handle " << t; + } + } else { + CHECK(t.is_handle()); + ret = t_void_p_; + } + if (t.lanes() != 1) { + ret = llvm::VectorType::get(ret, t.lanes()); + } + return ret; +} + +void CodeGenLLVM::Visit_(const Variable* op) { + value_ = GetVarValue(op); +} + +void CodeGenLLVM::Visit_(const Cast* op) { + value_ = CreateCast(op->value.type(), op->type, MakeValue(op->value)); +} + +void CodeGenLLVM::Visit_(const IntImm* op) { + value_ = llvm::ConstantInt::getSigned(LLVMType(op->type), op->value); +} + +void CodeGenLLVM::Visit_(const UIntImm* op) { + value_ = llvm::ConstantInt::get(LLVMType(op->type), op->value); +} + +void CodeGenLLVM::Visit_(const FloatImm* op) { + value_ = llvm::ConstantFP::get(LLVMType(op->type), op->value); +} + +void CodeGenLLVM::Visit_(const StringImm* op) { + value_ = GetConstString(op->value); +} + +#define DEFINE_CODEGEN_BINARY_OP(OP) \ + llvm::Value* CodeGenLLVM::Create ## OP( \ + Type t, llvm::Value* a, llvm::Value *b) { \ + if (t.is_float()) { \ + return builder_->CreateF ## OP (a, b); \ + } else if (t.is_int() && t.bits() >= 32) { \ + return builder_->CreateNSW ## OP (a, b); \ + } else { \ + return builder_->Create ## OP (a, b); \ + } \ + } \ + +DEFINE_CODEGEN_BINARY_OP(Add); +DEFINE_CODEGEN_BINARY_OP(Sub); +DEFINE_CODEGEN_BINARY_OP(Mul); + +void CodeGenLLVM::Visit_(const Add* op) { + value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +void CodeGenLLVM::Visit_(const Sub* op) { + value_ = CreateSub(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +void CodeGenLLVM::Visit_(const Mul* op) { + value_ = CreateMul(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +void CodeGenLLVM::Visit_(const Div* op) { + llvm::Value* a = MakeValue(op->a); + int shift; + if (op->type.is_float()) { + value_ = builder_->CreateFDiv(a, MakeValue(op->b)); + } else if ((op->type.is_int() || op->type.is_uint()) && 
+ is_const_power_of_two_integer(op->b, &shift)) { + value_ = builder_->CreateAShr(a, shift); + } else { + llvm::Value* b = MakeValue(op->b); + if (op->type.is_int()) { + value_ = builder_->CreateSDiv(a, b); + } else { + CHECK(op->type.is_uint()); + value_ = builder_->CreateUDiv(a, b); + } + } +} + +void CodeGenLLVM::Visit_(const Mod* op) { + CHECK(!op->type.is_float()) + << "Cannot do mod for float"; + if (op->type.is_int()) { + value_ = builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b)); + } else { + CHECK(op->type.is_uint()); + value_ = builder_->CreateURem(MakeValue(op->a), MakeValue(op->b)); + } +} + +void CodeGenLLVM::Visit_(const Min* op) { + llvm::Value* a = MakeValue(op->a); + llvm::Value* b = MakeValue(op->b); + llvm::Value* cond = CreateLT(op->a.type(), a, b); + value_ = builder_->CreateSelect(cond, a, b); +} + +void CodeGenLLVM::Visit_(const Max* op) { + llvm::Value* a = MakeValue(op->a); + llvm::Value* b = MakeValue(op->b); + llvm::Value* cond = CreateGT(op->a.type(), a, b); + value_ = builder_->CreateSelect(cond, a, b); +} + +#define DEFINE_CODEGEN_CMP_OP(OP) \ + llvm::Value* CodeGenLLVM::Create ## OP( \ + Type t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_float()) { \ + return builder_->CreateFCmpO ## OP (a, b); \ + } else if (t.is_int()) { \ + return builder_->CreateICmpS ## OP (a, b); \ + } else { \ + return builder_->CreateICmpU ## OP (a, b); \ + } \ + } \ + +DEFINE_CODEGEN_CMP_OP(LT); +DEFINE_CODEGEN_CMP_OP(LE); +DEFINE_CODEGEN_CMP_OP(GT); +DEFINE_CODEGEN_CMP_OP(GE); + +void CodeGenLLVM::Visit_(const LT* op) { + value_ = CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +void CodeGenLLVM::Visit_(const LE* op) { + value_ = CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +void CodeGenLLVM::Visit_(const GT* op) { + value_ = CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +void CodeGenLLVM::Visit_(const GE* op) { + value_ = CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} + +void 
CodeGenLLVM::Visit_(const EQ* op) { + if (op->a.type().is_float()) { + value_ = builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b)); + } else { + value_ = builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b)); + } +} + +void CodeGenLLVM::Visit_(const NE* op) { + if (op->a.type().is_float()) { + value_ = builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b)); + } else { + value_ = builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b)); + } +} + +void CodeGenLLVM::Visit_(const And* op) { + value_ = builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); +} + +void CodeGenLLVM::Visit_(const Or* op) { + value_ = builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); +} + +void CodeGenLLVM::Visit_(const Not* op) { + value_ = builder_->CreateNot(MakeValue(op->a)); +} + +void CodeGenLLVM::Visit_(const Select* op) { + value_ = builder_->CreateSelect( + MakeValue(op->condition), + MakeValue(op->true_value), + MakeValue(op->false_value)); +} + +void CodeGenLLVM::Visit_(const Let* op) { + llvm::Value* v = MakeValue(op->value); + CHECK(!var_map_.count(op->var.get())); + var_map_[op->var.get()] = v; + value_ = MakeValue(op->body); +} + +void CodeGenLLVM::Visit_(const Broadcast* op) { + value_ = CreateBroadcast(MakeValue(op->value), op->lanes); +} + +void CodeGenLLVM::Visit_(const Ramp* op) { + Type t = op->type; + llvm::Value* base = MakeValue(op->base); + llvm::Value* stride = MakeValue(op->stride); + llvm::Value* value = llvm::UndefValue::get(LLVMType(t)); + for (int i = 0; i < t.lanes(); ++i) { + if (i != 0) { + base = CreateAdd(t, base, stride); + } + value = builder_->CreateInsertElement( + value, base, llvm::ConstantInt::get(t_int32_, i)); + } + value_ = value; +} + +void CodeGenLLVM::Visit_(const Load* op) { + Type t = op->type; + CHECK(!t.is_vector()); + + if (t.is_scalar()) { + llvm::DataLayout layout(module_.get()); + uint64_t valign = layout.getTypeAllocSize(LLVMType(t)); + llvm::LoadInst* inst = builder_->CreateAlignedLoad( + CreateBufferPtr( 
+ t, + GetVarValue(op->buffer_var.get()), + MakeValue(op->index)), + valign); + AddAliasInfo(inst, op->buffer_var.get(), op->index); + value_ = inst; + } else { + LOG(FATAL) << "not yet supported"; + } +} + +void CodeGenLLVM::Visit_(const Store* op) { + llvm::Value* value = MakeValue(op->value); + Type t = op->value.type(); + CHECK(!t.is_vector()); + if (t.is_scalar()) { + llvm::DataLayout layout(module_.get()); + uint64_t valign = layout.getTypeAllocSize(value->getType()); + llvm::StoreInst* inst = builder_->CreateAlignedStore( + value, + CreateBufferPtr( + t, + GetVarValue(op->buffer_var.get()), + MakeValue(op->index)), + valign); + AddAliasInfo(inst, op->buffer_var.get(), op->index); + } else { + LOG(FATAL) << "not yet supported"; + } +} + +void CodeGenLLVM::Visit_(const Call* op) { + if (op->is_intrinsic(intrinsic::tvm_call_global) || + op->is_intrinsic(intrinsic::tvm_call_device)) { + value_ = CreateCallPacked(op); + } else if (op->call_type == Call::Intrinsic || + op->call_type == Call::PureIntrinsic) { + value_ = CreateIntrinstic(op); + } else { + CHECK(op->call_type == Call::Extern || + op->call_type == Call::PureExtern); + value_ = CreateCallExtern(op); + } +} + +llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { + if (op->is_intrinsic(Call::bitwise_and)) { + CHECK_EQ(op->args.size(), 2U); + return builder_->CreateAnd( + MakeValue(op->args[0]), MakeValue(op->args[1])); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + CHECK_EQ(op->args.size(), 2U); + return builder_->CreateXor( + MakeValue(op->args[0]), MakeValue(op->args[1])); + } else if (op->is_intrinsic(Call::bitwise_or)) { + CHECK_EQ(op->args.size(), 2U); + return builder_->CreateOr( + MakeValue(op->args[0]), MakeValue(op->args[1])); + } else if (op->is_intrinsic(Call::bitwise_not)) { + CHECK_EQ(op->args.size(), 1U); + return builder_->CreateNot(MakeValue(op->args[0])); + } else if (op->is_intrinsic(Call::shift_left)) { + CHECK_EQ(op->args.size(), 2U); + return builder_->CreateShl( + 
MakeValue(op->args[0]), MakeValue(op->args[1])); + } else if (op->is_intrinsic(Call::shift_right)) { + CHECK_EQ(op->args.size(), 2U); + if (op->type.is_int()) { + return builder_->CreateAShr( + MakeValue(op->args[0]), MakeValue(op->args[1])); + } else { + return builder_->CreateLShr( + MakeValue(op->args[0]), MakeValue(op->args[1])); + } + } else if (op->is_intrinsic(Call::address_of)) { + const Load *l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + return CreateBufferPtr( + l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)); + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + CHECK_EQ(op->args.size(), 1U); + llvm::Value* ptr = MakeValue(op->args[0]); + return builder_->CreateICmpEQ( + ptr, llvm::Constant::getNullValue(ptr->getType())); + } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { + CHECK_EQ(op->args.size(), 3U); + CHECK_EQ(op->type.lanes(), 1); + llvm::Value* args = builder_->CreatePointerCast( + MakeValue(op->args[0]), t_tvm_value_->getPointerTo()); + llvm::Value* ptr = builder_->CreateInBoundsGEP( + args, MakeValue(op->args[2])); + // always pass via 64 bit pointers + // For handle type, Handle(64) will simply become 32 bit void* + Type value_type = op->type.with_bits(64); + ptr = builder_->CreatePointerCast( + ptr, LLVMType(value_type)->getPointerTo()); + llvm::Value* value = builder_->CreateAlignedLoad(ptr, 8); + // cast to the desired type + if (value_type != op->type) { + value = CreateCast(value_type, op->type, value); + } + return value; + } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { + CHECK_EQ(op->args.size(), 2U); + llvm::Value* arr = builder_->CreatePointerCast( + MakeValue(op->args[0]), t_tvm_array_->getPointerTo()); + llvm::Constant* zero = ConstInt32(0); + llvm::Value* ret = nullptr; + switch (op->args[1].as()->value) { + case intrinsic::kData: { + ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break; + } + case intrinsic::kShape: { + ret = 
builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(1)}); break; + } + case intrinsic::kStrides: { + ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break; + } + case intrinsic::kNDim: { + ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(3)}); break; + } + case intrinsic::kTypeCode: { + ret = builder_->CreateInBoundsGEP( + arr, {zero, ConstInt32(4), ConstInt32(0)}); break; + } + case intrinsic::kTypeBits: { + ret = builder_->CreateInBoundsGEP( + arr, {zero, ConstInt32(4), ConstInt32(1)}); break; + } + case intrinsic::kTypeLanes: { + ret = builder_->CreateInBoundsGEP( + arr, {zero, ConstInt32(4), ConstInt32(2)}); break; + } + default: LOG(FATAL) << "unknown field code"; + } + return builder_->CreateLoad(ret); + } else { + LOG(FATAL) << "Unknown intrinstic " << op->name; + } + return nullptr; +} + +llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) { + // create emit codes that checks and load the function. + using llvm::BasicBlock; + BasicBlock* fail_block = BasicBlock::Create( + *ctx_, "call_fail", function_); + BasicBlock* end_block = BasicBlock::Create( + *ctx_, "call_end", function_); + llvm::Value* succ = builder_->CreateICmpEQ( + retcode, llvm::ConstantInt::get(t_int_, 0)); + builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); + builder_->SetInsertPoint(fail_block); + // return the code. + builder_->CreateRet(retcode); + // otherwise set it to be new end. 
+  builder_->SetInsertPoint(end_block);
+  return end_block;
+}
+
+// Lower a For loop into three blocks: for_head (test), for_body and for_end.
+// Only loops whose min is zero are handled; this is enforced with a CHECK.
+void CodeGenLLVM::Visit_(const For* op) {
+  using llvm::BasicBlock;
+  BasicBlock* for_head = BasicBlock::Create(
+      *ctx_, "for_head", function_);
+  BasicBlock* for_body = BasicBlock::Create(
+      *ctx_, "for_body", function_);
+  BasicBlock* for_end = BasicBlock::Create(
+      *ctx_, "for_end", function_);
+  BasicBlock* pre_block = builder_->GetInsertBlock();
+  CHECK(is_zero(op->min));
+  Type t = op->min.type();
+  // The start value and the step are built in the same LLVM type as the PHI
+  // node below: every incoming value of a PHI must have the PHI's type, and
+  // an i32 constant only satisfies that when t itself maps to i32.
+  llvm::Type* index_type = LLVMType(t);
+  llvm::Value* init = llvm::ConstantInt::get(index_type, 0);
+  llvm::Value* step = llvm::ConstantInt::get(index_type, 1);
+  llvm::Value* extent = MakeValue(op->extent);
+  builder_->CreateBr(for_head);
+
+  builder_->SetInsertPoint(for_head);
+  llvm::PHINode* index = builder_->CreatePHI(index_type, 2);
+  index->addIncoming(init, pre_block);
+  llvm::Value* cond = CreateLT(t, index, extent);
+  builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
+  // body of for
+  builder_->SetInsertPoint(for_body);
+  var_map_[op->loop_var.get()] = index;
+  this->Visit(op->body);
+  llvm::Value* next_index = CreateAdd(t, index, step);
+  // Visiting the body may have moved the builder to a new block, so the back
+  // edge is taken from the current insert block rather than from for_body.
+  index->addIncoming(next_index, builder_->GetInsertBlock());
+  builder_->CreateBr(for_head);
+  // end of for
+  builder_->SetInsertPoint(for_end);
+}
+
+// Lower an IfThenElse into if_then / (optional) if_else / if_end blocks.
+void CodeGenLLVM::Visit_(const IfThenElse* op) {
+  using llvm::BasicBlock;
+  const bool has_else = op->else_case.defined();
+  BasicBlock* then_block = BasicBlock::Create(
+      *ctx_, "if_then", function_);
+  // if_else is only created when it will be filled: a block that is appended
+  // to the function but never receives a terminator leaves invalid IR behind.
+  BasicBlock* else_block = has_else ?
+      BasicBlock::Create(*ctx_, "if_else", function_) : nullptr;
+  BasicBlock* end_block = BasicBlock::Create(
+      *ctx_, "if_end", function_);
+  // condition.
+  llvm::Value* cond = MakeValue(op->condition);
+  // Without an else case the false edge goes straight to if_end.
+  builder_->CreateCondBr(
+      cond, then_block, has_else ? else_block : end_block,
+      md_very_likely_branch_);
+  // then case.
+  builder_->SetInsertPoint(then_block);
+  this->Visit(op->then_case);
+  builder_->CreateBr(end_block);
+  // else case.
+  if (has_else) {
+    builder_->SetInsertPoint(else_block);
+    this->Visit(op->else_case);
+    builder_->CreateBr(end_block);
+  }
+  builder_->SetInsertPoint(end_block);
+}
+
+// Bind the buffer variable of an Allocate to a pointer, then generate code
+// for the statements inside the allocation's scope.
+void CodeGenLLVM::Visit_(const Allocate* op) {
+  CHECK(!is_zero(op->condition));
+  llvm::Value* buf = nullptr;
+  if (op->new_expr.defined()) {
+    CHECK_EQ(op->free_function, "nop");
+    buf = MakeValue(op->new_expr);
+  } else {
+    int32_t constant_size = op->constant_allocation_size();
+    CHECK_GT(constant_size, 0)
+        << "Can only handle constant size stack allocation for now";
+    buf = builder_->CreateAlloca(
+        LLVMType(op->type), ConstInt32(constant_size));
+  }
+  buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
+  CHECK(!var_map_.count(op->buffer_var.get()));
+  var_map_[op->buffer_var.get()] = buf;
+  // The statements that use the buffer live in the body. They have to be
+  // visited here, otherwise no code is emitted for them at all.
+  this->Visit(op->body);
+}
+
+void CodeGenLLVM::Visit_(const AttrStmt* op) {
+  this->Visit(op->body);
+}
+
+void CodeGenLLVM::Visit_(const AssertStmt* op) {
+  using llvm::BasicBlock;
+  llvm::Value* cond = MakeValue(op->condition);
+  std::ostringstream os;
+  os << "Assert fail: " << op->condition;
+  if (op->message.as()) {
+    os << ", " << op->message.as()->value;
+  }
+  llvm::Value* msg = GetConstString(os.str());
+  BasicBlock* fail_block = BasicBlock::Create(
+      *ctx_, "assert_fail", function_);
+  BasicBlock* end_block = BasicBlock::Create(
+      *ctx_, "assert_end", function_);
+  builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
+  // fail condition.
+  builder_->SetInsertPoint(fail_block);
+  builder_->CreateCall(f_tvm_api_set_last_error_, {msg});
+  builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1));
+  // otherwise set it to be new end.
+  builder_->SetInsertPoint(end_block);
+}
+
+void CodeGenLLVM::Visit_(const LetStmt* op) {
+  llvm::Value* v = MakeValue(op->value);
+  CHECK(!var_map_.count(op->var.get()));
+  var_map_[op->var.get()] = v;
+  this->Visit(op->body);
+}
+
+// Attach "tbaa" metadata to inst. The metadata names the buffer and, when
+// the index is a compile-time constant (scalar or ramp), the aligned
+// power-of-two range of that buffer the instruction touches.
+void CodeGenLLVM::AddAliasInfo(
+    llvm::Instruction* inst, const Variable* buffer, Expr index) {
+  // Start offset and size of the accessed range; width == 0 means the range
+  // is unknown and only the per-buffer node is attached.
+  int base = 0, width = 0;
+  // create meta-data for alias analysis
+  // Use a group of binary tree ranges.
+  const Ramp* ramp = index.as<Ramp>();
+  if (ramp) {
+    // The ramp base is read into the function-level `base`. Declaring a
+    // second `base` in this scope would shadow it, and the tree below would
+    // then be built as if every vector access started at offset 0.
+    int stride;
+    if (arith::GetConstInt(ramp->base, &base) &&
+        arith::GetConstInt(ramp->stride, &stride)) {
+      int xwith = ramp->lanes * stride;
+      // Smallest power of two that is at least lanes * stride.
+      width = 1;
+      while (width < xwith) {
+        width *= 2;
+      }
+      // Round base down and widen until base is a multiple of width.
+      while (base % width) {
+        base -= base % width;
+        width *= 2;
+      }
+    }
+  } else {
+    if (arith::GetConstInt(index, &base)) width = 1;
+  }
+
+  llvm::MDNode* meta = md_tbaa_root_;
+  std::ostringstream buffer_addr;
+  buffer_addr << buffer;
+  meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
+  // create a tree-shape access structure.
+  if (width != 0) {
+    for (int w = 1024; w >= width; w /= 2) {
+      int b = (base / w) * w;
+      std::stringstream os;
+      os << buffer << ".w" << w << ".b" << b;
+      meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
+    }
+  }
+  inst->setMetadata(
+      "tbaa",
+      md_builder_->createTBAAStructTagNode(meta, meta, 0));
+}
+
+llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
+  llvm::Constant* init = llvm::UndefValue::get(
+      llvm::VectorType::get(value->getType(), lanes));
+  llvm::Constant* zero = ConstInt32(0);
+  value = builder_->CreateInsertElement(init, value, zero);
+  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
+  return builder_->CreateShuffleVector(value, init, mask);
+}
+
+llvm::Value* CodeGenLLVM::CreateBufferPtr(
+    Type t, llvm::Value* buffer, llvm::Value* index) {
+  llvm::Type* elem_type = buffer->getType();
+  unsigned address_space = elem_type->getPointerAddressSpace();
+  llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
+
+  if (load_type != elem_type) {
+    buffer = builder_->CreatePointerCast(buffer, load_type);
+  }
+  llvm::Constant* cindex = llvm::dyn_cast(index);
+  if (cindex && cindex->isZeroValue()) {
+    return buffer;
+  }
+  return builder_->CreateInBoundsGEP(buffer, index);
+}
+
+// Convert value from TVM type `from` to TVM type `to`, picking the LLVM cast
+// instruction from the handle / int / uint / float kind of the two types.
+llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
+  llvm::Type * target = LLVMType(to);
+  if (value->getType() == target) return value;
+  // A bitcast is only chosen when BOTH sides are handles; testing `from`
+  // twice would also route handle -> non-handle casts through this branch.
+  if (from.is_handle() && to.is_handle()) {
+    return builder_->CreateBitCast(value, target);
+  } else if (!from.is_float() && !to.is_float()) {
+    return builder_->CreateIntCast(value, target, from.is_int());
+  } else if (from.is_float() && to.is_int()) {
+    return builder_->CreateFPToSI(value, target);
+  } else if (from.is_float() && to.is_uint()) {
+    if (to.bits() < 8) {
+      value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
+      return builder_->CreateIntCast(value, target, false);
+    } else {
+      return builder_->CreateFPToUI(value, target);
+    }
+  } else if (from.is_int() && to.is_float()) {
+    return builder_->CreateSIToFP(value, target);
+  } else if (from.is_uint() && to.is_float()) {
+    return builder_->CreateUIToFP(value, target);
+  } else {
+    CHECK(from.is_float() && to.is_float());
+    return builder_->CreateFPCast(value, target);
+  }
+}
+
+llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
+    const std::string& fname, bool global) {
+  using llvm::BasicBlock;
+  // We will store the packed function handle in global space.
+  // Initialize it during the first call.
+  llvm::DataLayout layout(module_.get());
+  uint64_t halign = layout.getTypeAllocSize(t_tvm_func_handle_);
+  auto it = func_handle_map_.find(fname);
+  llvm::GlobalVariable* hptr;
+  if (it == func_handle_map_.end()) {
+    // create global location for the handle
+    // create the function handle
+    hptr = new llvm::GlobalVariable(
+        *module_, t_tvm_func_handle_, false,
+        llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
+    hptr->setAlignment(halign);
+    hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
+    func_handle_map_[fname] = hptr;
+  } else {
+    hptr = it->second;
+  }
+  // create emit codes that checks and load the function.
+  BasicBlock* pre_block = builder_->GetInsertBlock();
+  BasicBlock* init_block = BasicBlock::Create(
+      *ctx_, "handle_init", function_);
+  BasicBlock* end_block = BasicBlock::Create(
+      *ctx_, "handle_init_end", function_);
+  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, halign);
+  llvm::Value* handle_not_null = builder_->CreateICmpNE(
+      handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
+  builder_->CreateCondBr(
+      handle_not_null, end_block, init_block, md_very_likely_branch_);
+  // loaded handle, if created by call.
+  llvm::Value* loaded_handle = nullptr;
+  // Then block.
+  // We do not do lock here, so unlike static variable initialization
+  // This clause might be executed multiple times, but it is safe to do so.
+ builder_->SetInsertPoint(init_block); + if (global) { + llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_); + llvm::Value* retcode = builder_->CreateCall( + f_tvm_func_get_global_, {GetConstString(fname), out}); + init_block = CheckPackedCallSuccess(retcode); + loaded_handle = builder_->CreateAlignedLoad(out, halign); + } else { + LOG(FATAL) << "not yet supported"; + } + builder_->CreateBr(end_block); + // end block + builder_->SetInsertPoint(end_block); + llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2); + phi->addIncoming(handle, pre_block); + phi->addIncoming(loaded_handle, init_block); + return phi; +} + +llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) { + CHECK_GE(op->args.size(), 1U); + std::string func_name = op->args[0].as()->value; + CHECK(!op->is_intrinsic(intrinsic::tvm_call_device)) + << "not implemented for now"; + llvm::Value* handle = GetPackedFuncHandle( + func_name, op->is_intrinsic(intrinsic::tvm_call_global)); + + // call the function + unsigned nargs = static_cast(op->args.size() - 1); + llvm::Value* targs = builder_->CreateAlloca( + t_tvm_value_, ConstInt32(nargs)); + llvm::Value* tcodes = builder_->CreateAlloca( + t_int_, ConstInt32(nargs)); + for (unsigned i = 0; i < nargs; ++i) { + Expr expr = op->args[i + 1]; + Type t = expr.type(); + CHECK_EQ(t.lanes(), 1); + // Always pass via 64 bit value. + // For handle type, Handle(64) maps to 32 bit void* in 32bit platform. 
+ Type api_type = t.with_bits(64); + llvm::Value* value = CreateCast(t, api_type, MakeValue(expr)); + llvm::Value* store_ptr = builder_->CreatePointerCast( + builder_->CreateInBoundsGEP(targs, ConstInt32(i)), + LLVMType(api_type)->getPointerTo()); + builder_->CreateAlignedStore(value, store_ptr, 8); + builder_->CreateAlignedStore( + ConstInt32(t.code()), + builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4); + } + llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_); + llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_); + CheckPackedCallSuccess( + builder_->CreateCall( + f_tvm_func_call_, + {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode})); + Type r_type = op->type; + Type r_api_type = op->type.with_bits(64); + llvm::Value* rvalue = + builder_->CreateAlignedLoad( + builder_->CreatePointerCast( + ret_value, LLVMType(r_api_type)->getPointerTo()), 8); + rvalue = CreateCast(r_api_type, r_type, rvalue); + return rvalue; +} + +llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { + std::vector arg_values(op->args.size()); + for (size_t i = 0; i < op->args.size(); ++i) { + arg_values[i] = MakeValue(op->args[i]); + } + if (op->type.is_scalar()) { + llvm::Function* f = module_->getFunction(op->name); + if (f) { + return builder_->CreateCall(f, arg_values); + } else { + LOG(FATAL) << "cannot find function " << op->name; + } + } else { + llvm::Function* f = module_->getFunction(op->name); + if (f) { + return CreateScalarizedCall(op, f, arg_values); + } else { + LOG(FATAL) << "cannot find function " << op->name; + } + } + return nullptr; +} + +llvm::Value* CodeGenLLVM::CreateScalarizedCall( + const Call* op, llvm::Function* f, const std::vector& args) { + llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type)); + for (int i = 0; i < op->type.lanes(); ++i) { + std::vector sargs(args.size()); + for (size_t j = 0; j < args.size(); ++j) { + if (args[j]->getType()->isVectorTy()) { + sargs[j] = builder_->CreateExtractElement(args[j], 
ConstInt32(i)); + } else { + sargs[j] = args[j]; + } + } + llvm::CallInst* call = builder_->CreateCall(f, sargs); + if (op->is_pure()) { + call->setDoesNotAccessMemory(); + } + call->setDoesNotThrow(); + if (!call->getType()->isVoidTy()) { + value = builder_->CreateInsertElement(value, call, ConstInt32(i)); + } + } + return value; +} + +llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { + auto it = var_map_.find(v); + CHECK(it != var_map_.end()) + << "Cannot find " << v->name_hint << " in the var map"; + return it->second; +} + +llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { + auto it = str_map_.find(str); + if (it == str_map_.end()) { + llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); + llvm::GlobalVariable *global = new llvm::GlobalVariable( + *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); + global->setAlignment(1); + global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); + // useful constant value + llvm::Constant* zero = ConstInt32(0); + llvm::Constant* indices[] = {zero, zero}; + llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr( + type, global, indices); + str_map_[str] = sptr; + return sptr; + } else { + return it->second; + } +} + +} // namespace codegen +} // namespace tvm +#endif // TVM_LLVM_VERSION diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h new file mode 100644 index 000000000000..a6d9e1bfa57a --- /dev/null +++ b/src/codegen/llvm/codegen_llvm.h @@ -0,0 +1,186 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_llvm.h + * \brief Common base class for generating into LLVM IR + */ +#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_ +#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_ +#ifdef TVM_LLVM_VERSION + +#include +#include +#include +#include +#include +#include +#include "./llvm_common.h" + +namespace tvm { +namespace codegen { + +using namespace ir; + +/*! + * \brief A base class to generate a LLVM. 
+ */ +class CodeGenLLVM : public IRVisitor { + public: + /*! + * \brief Initialize the code generator with given context + * \param module_name The name of the module. + * \param ctx The context. + */ + void Init(const std::string& module_name, llvm::LLVMContext* ctx); + /*! + * \brief Compile and add function f to the current module. + * \param f The function to be added. + */ + void AddFunction(const LoweredFunc& f); + /*! + * \brief Finish current pass of codegen, get the module. + * \return the created module. + */ + std::unique_ptr Finish(); + /*! + * \brief Create Value for expression e + * \param e The expression to be created value for. + * \return created value. + */ + llvm::Value* MakeValue(const Expr& e) { + value_ = nullptr; + this->Visit(e); + CHECK(value_ != nullptr); + return value_; + } + // Short hande code to get a constant int 32 + llvm::Constant* ConstInt32(unsigned value) const { + return llvm::ConstantInt::get(t_int32_, value); + } + // override codegen + void Visit_(const Variable* op) final; + void Visit_(const Cast* op) final; + void Visit_(const IntImm* op) final; + void Visit_(const UIntImm* op) final; + void Visit_(const FloatImm* op) final; + void Visit_(const StringImm* op) final; + void Visit_(const Add* op) final; + void Visit_(const Sub* op) final; + void Visit_(const Mul* op) final; + void Visit_(const Div* op) final; + void Visit_(const Mod* op) final; + void Visit_(const Min* op) final; + void Visit_(const Max* op) final; + void Visit_(const LT* op) final; + void Visit_(const LE* op) final; + void Visit_(const GT* op) final; + void Visit_(const GE* op) final; + void Visit_(const EQ* op) final; + void Visit_(const NE* op) final; + void Visit_(const And* op) final; + void Visit_(const Or* op) final; + void Visit_(const Not* op) final; + void Visit_(const Select* op) final; + void Visit_(const Let* op) final; + void Visit_(const Load* op) final; + void Visit_(const Call* op) final; + void Visit_(const Ramp* op) final; + void 
Visit_(const Broadcast* op) final; + // stmt + void Visit_(const Store* op) final; + void Visit_(const For* op) final; + void Visit_(const IfThenElse* op) final; + void Visit_(const Allocate* op) final; + void Visit_(const AttrStmt* op) override; + void Visit_(const AssertStmt* op) final; + void Visit_(const LetStmt* op) final; + // create intrinstic given call + virtual llvm::Value* CreateIntrinstic(const Call* op); + // create extern function call + virtual llvm::Value* CreateCallExtern(const Call* op); + // create call into tvm packed function. + virtual llvm::Value* CreateCallPacked(const Call* op); + + protected: + /*! + * \param t The original type. + * \return LLVM type of t + */ + llvm::Type* LLVMType(const Type& t) const; + // do a scalarize call with f + llvm::Value* CreateScalarizedCall( + const Call* op, llvm::Function* f, const std::vector& args); + // apply optimization on the module. + virtual void Optimize(); + // The IRBuilder. + using IRBuilder = llvm::IRBuilder; + // The current function + llvm::Function* function_; + // Internal builder + std::unique_ptr builder_; + // The module to be returned; + std::unique_ptr module_; + // Internal metabuilder + std::unique_ptr md_builder_; + // llvm context + llvm::LLVMContext* ctx_{nullptr}; + // helpful data types + llvm::Type* t_void_{nullptr}; + llvm::Type* t_void_p_{nullptr}; + llvm::Type* t_int_{nullptr}; + llvm::Type* t_char_{nullptr}; + llvm::Type* t_int8_{nullptr}; + llvm::Type* t_int16_{nullptr}; + llvm::Type* t_int32_{nullptr}; + llvm::Type* t_int64_{nullptr}; + llvm::Type* t_float64_{nullptr}; + // branch + llvm::MDNode* md_very_likely_branch_{nullptr}; + llvm::MDNode* md_tbaa_root_{nullptr}; + // TVM related data types + llvm::Type* t_tvm_index_{nullptr}; + llvm::Type* t_tvm_func_handle_{nullptr}; + llvm::StructType* t_tvm_context_{nullptr}; + llvm::StructType* t_tvm_type_{nullptr}; + llvm::StructType* t_tvm_array_{nullptr}; + llvm::StructType* t_tvm_value_{nullptr}; + // tvm api functions + 
llvm::Function* f_tvm_func_call_{nullptr}; + llvm::Function* f_tvm_func_get_global_{nullptr}; + llvm::Function* f_tvm_api_set_last_error_{nullptr}; + // The acting body + llvm::BasicBlock* block_{nullptr}; + // Last value returned codegen call. + llvm::Value* value_{nullptr}; + + private: + // comparison op + llvm::Value* GetVarValue(const Variable* v) const; + llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b); + llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); + llvm::Value* GetConstString(const std::string& str); + llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); + llvm::Value* CreateCast(Type from, Type to, llvm::Value* value); + llvm::Value* GetPackedFuncHandle(const std::string& str, bool global); + // Check if the call to packed function is successful + // if not directly finalize function and pass on return code. + // return the end block after the check + llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode); + // add alias information. + void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index); + // The definition of local variable. 
+  std::unordered_map var_map_;
+  // global strings
+  std::unordered_map str_map_;
+  // global to packed function handle
+  std::unordered_map func_handle_map_;
+};
+}  // namespace codegen
+}  // namespace tvm
+#endif  // TVM_LLVM_VERSION
+#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
diff --git a/src/codegen/llvm/llvm_common.cc b/src/codegen/llvm/llvm_common.cc
new file mode 100644
index 000000000000..7dbcf13906a1
--- /dev/null
+++ b/src/codegen/llvm/llvm_common.cc
@@ -0,0 +1,39 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file llvm_common.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include
+#include
+#include "./llvm_common.h"
+
+namespace tvm {
+namespace codegen {
+
+struct LLVMEnv {
+  std::mutex mu;
+  bool native_initialized{false};
+
+  static LLVMEnv* Global() {
+    static LLVMEnv inst;
+    return &inst;
+  }
+};
+
+void InitializeLLVM() {
+  LLVMEnv* e = LLVMEnv::Global();
+  // Named guard: keeps mu locked until return (an unnamed temporary unlocks
+  // at once). native_initialized is only read and written under the lock.
+  std::lock_guard<std::mutex> lock(e->mu);
+  if (!e->native_initialized) {
+    e->native_initialized = true;
+    llvm::InitializeNativeTarget();
+    llvm::InitializeNativeTargetAsmPrinter();
+    llvm::InitializeNativeTargetAsmParser();
+  }
+}
+
+}  // namespace codegen
+}  // namespace tvm
+#endif  // TVM_LLVM_VERSION
diff --git a/src/codegen/llvm/llvm_common.h b/src/codegen/llvm/llvm_common.h
new file mode 100644
index 000000000000..e10e8430e14c
--- /dev/null
+++ b/src/codegen/llvm/llvm_common.h
@@ -0,0 +1,52 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file llvm_common.h
+ * \brief Common utilities for llvm initialization.
+ */
+#ifndef TVM_CODEGEN_LLVM_LLVM_COMMON_H_
+#define TVM_CODEGEN_LLVM_LLVM_COMMON_H_
+#ifdef TVM_LLVM_VERSION
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+extern "C" {
+// Function signature for LLVM generated packed function.
+typedef int (*LLVMPackedCFunc)(void* args, + int* type_codes, + int num_args); +} // extern "C" + +namespace tvm { +namespace codegen { + +/*! + * \brief Initialize LLVM on this process, + * can be called multiple times. + */ +void InitializeLLVM(); + +} // namespace codegen +} // namespace tvm +#endif // TVM_LLVM_VERSION +#endif // TVM_CODEGEN_LLVM_LLVM_COMMON_H_ diff --git a/src/codegen/llvm/llvm_exec_engine.cc b/src/codegen/llvm/llvm_exec_engine.cc new file mode 100644 index 000000000000..b31505253900 --- /dev/null +++ b/src/codegen/llvm/llvm_exec_engine.cc @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file llvm_exec_engine.cc + */ +#include +#include +#include "./llvm_common.h" +#include "./codegen_llvm.h" + +namespace tvm { +namespace codegen { + +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; + +#ifdef TVM_LLVM_VERSION +// Environment to keep jit resources alive. +struct LLVMJITEnv { + std::shared_ptr ctx; + llvm::ExecutionEngine* ee{nullptr}; + // constructor + LLVMJITEnv(std::shared_ptr ctx, + llvm::ExecutionEngine* ee) + : ctx(ctx), ee(ee) { + } + // destructor + ~LLVMJITEnv() { + if (ee != nullptr) { + ee->runStaticConstructorsDestructors(true); + delete ee; + } + } +}; + + +PackedFunc JITCompile(std::unique_ptr module, + std::shared_ptr ctx, + const std::string& func_name) { + llvm::EngineBuilder builder(std::move(module)); + builder.setEngineKind(llvm::EngineKind::JIT); + builder.setOptLevel(llvm::CodeGenOpt::Aggressive); + std::shared_ptr env = std::make_shared( + ctx, builder.create()); + CHECK(env->ee != nullptr); + auto* faddr = reinterpret_cast( + env->ee->getFunctionAddress(func_name)); + env->ee->runStaticConstructorsDestructors(false); + return PackedFunc([env, faddr](TVMArgs args, TVMRetValue* rv) { + int ret = (*faddr)( + (void*)args.values, // NOLINT(*) + (int*)args.type_codes, // NOLINT(*) + args.num_args); + CHECK(ret == 0) << TVMGetLastError(); + }); +} + +PackedFunc 
BuildLLVM(LoweredFunc func) { + InitializeLLVM(); + // use one context per function. + std::shared_ptr ctx = + std::make_shared(); + CodeGenLLVM cg; + cg.Init(func->name, ctx.get()); + cg.AddFunction(func); + std::unique_ptr m = cg.Finish(); + return JITCompile(std::move(m), ctx, func->name); +} + +#else +PackedFunc BuildLLVM(LoweredFunc func) { + LOG(FATAL) << "LLVM is not enabled"; + return PackedFunc(); +} +#endif // TVM_LLVM_VERSION +} // namespace codegen +} // namespace tvm diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index 31fb8ede2e57..b47024adca91 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -21,8 +21,6 @@ */ #define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) -void TVMAPISetLastError(const char* msg); - /*! * \brief handle exception throwed out * \param e the exception diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 4b7c7f886d45..11b6d354dfab 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -274,8 +274,9 @@ Stmt MakeLoop(const Stage& s, bound_state[iv] = false; } PassUpBoundCheck(s, dom_map, &bound_state); - auto nest = MakeLoopNest(s, dom_map, 0, false, - bound_state, {}, &value_map); + auto nest = MakeLoopNest( + s, dom_map, 0, false, + bound_state, {{}}, &value_map); provide = Substitute(provide, value_map); if (init.defined()) { diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 5d48ed8b453d..b22cbfad4a59 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -2,7 +2,6 @@ import numpy as np def test_add_pipeline(): - """Not yet working, mock design""" n = tvm.Var('n') A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') diff --git a/tests/python/unittest/test_runtime_stack_vm.py 
b/tests/python/unittest/test_codegen_stack_llvm.py similarity index 51% rename from tests/python/unittest/test_runtime_stack_vm.py rename to tests/python/unittest/test_codegen_stack_llvm.py index 2de2da544036..addf2e699f29 100644 --- a/tests/python/unittest/test_runtime_stack_vm.py +++ b/tests/python/unittest/test_codegen_stack_llvm.py @@ -6,6 +6,16 @@ def tvm_call_global(*args): return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0) +def run_jit(fapi, check): + for target in ["stackvm"]: + if target == "llvm": + f = tvm.codegen.BuildLLVM(fapi) + else: + f = tvm.codegen.BuildStackVM(fapi) + check(f) + + + def test_stack_vm_basic(): a = tvm.nd.array(np.zeros(10, dtype='float32')) @tvm.register_func @@ -17,8 +27,7 @@ def tvm_call_back_get_shape(shape0): Ab = tvm.Buffer((n, ), tvm.float32) stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1) - f = tvm.codegen.BuildStackVM(fapi) - f(a) + run_jit(fapi, lambda f: f(a)) @tvm.register_func @@ -42,8 +51,10 @@ def test_stack_vm_loop(): fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1) f = tvm.codegen.BuildStackVM(fapi) a = tvm.nd.array(np.zeros(10, dtype=dtype)) - f(a) - np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0])) + def check(f): + f(a) + np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0])) + run_jit(fapi, check) def test_stack_vm_cond(): @@ -61,15 +72,46 @@ def test_stack_vm_cond(): tvm.make.Store(Ab.data, tvm.make.Load(dtype, Ab.data, i) + 2, i + 1))) fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1) - f = tvm.codegen.BuildStackVM(fapi) - a = tvm.nd.array(np.zeros(10, dtype=dtype)) - f(a) - y = np.arange(a.shape[0]) * 2 - y[5:] -= 1 - np.testing.assert_equal(a.asnumpy(), y) + def check(f): + a = tvm.nd.array(np.zeros(10, dtype=dtype)) + f(a) + y = np.arange(a.shape[0]) * 2 + y[5:] -= 1 + np.testing.assert_equal(a.asnumpy(), y) + run_jit(fapi, check) +def test_llvm_add_pipeline(): + n = tvm.Var('n') + A = 
tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.Schedule(C.op) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + Ab = tvm.Buffer(A.shape, A.dtype, name='A') + Bb = tvm.Buffer(B.shape, B.dtype, name='B') + Cb = tvm.Buffer(C.shape, C.dtype, name='C') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + stmt = tvm.ir_pass.Simplify(stmt) + fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3) + + def check_llvm(): + # build and invoke the kernel. + f = tvm.codegen.BuildLLVM(fapi) + ctx = tvm.cpu(0) + # launch the kernel. + n = 1027 + a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx) + c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx) + f(a, b, c) + np.testing.assert_allclose( + c.asnumpy(), a.asnumpy() + b.asnumpy()) + #check_llvm() + if __name__ == "__main__": test_stack_vm_cond() - test_stack_vm_loop() test_stack_vm_basic() + test_stack_vm_loop() + test_llvm_add_pipeline()