From 88d9852dffda3d6af61bc14254ebc31b89b0c773 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Mon, 16 Mar 2020 10:04:46 -0700
Subject: [PATCH] [TIR][TARGET] Refactor Target codegen to use IRModule and PrimFunc.

As part of the unified IR refactor, this PR refactors the target codegen to use
IRModule containing tir::PrimFuncs. In order to break the refactor into several
steps without breaking the codebase, we built a conversion pass to convert
Array<LoweredFunc> into IRModule. The follow-up refactors will gradually move the
passes covered by IRModule up until we cover all the passes. Then we can remove
the additional redundant concepts such as LoweredFunc.
---
 include/tvm/ir/expr.h                      | 12 ++
 include/tvm/ir/function.h                  | 21 ++
 include/tvm/ir/type.h                      | 56 ++++++++
 include/tvm/tir/expr.h                     | 47 +++----
 include/tvm/tir/function.h                 | 10 ++
 include/tvm/tir/ir_pass.h                  | 14 +++
 include/tvm/tir/op.h                       | 13 ++
 python/tvm/ir/__init__.py                  |  2 +-
 python/tvm/ir/json_compact.py              | 11 ++
 python/tvm/ir/type.py                      | 15 +++
 python/tvm/tir/expr.py                     |  6 +-
 src/arith/domain_touched.cc                |  2 +-
 src/arith/ir_mutator_with_analyzer.cc      |  4 +-
 src/driver/driver_api.cc                   |  2 +-
 src/ir/type.cc                             | 21 ++++
 src/runtime/module.cc                      |  2 +-
 src/target/build_common.h                  | 28 +++++
 src/target/codegen.cc                      | 75 +++++++++--
 src/target/llvm/codegen_amdgpu.cc          | 26 ++--
 src/target/llvm/codegen_cpu.cc             | 16 ++-
 src/target/llvm/codegen_cpu.h              |  2 +-
 src/target/llvm/codegen_llvm.cc            | 137 ++++++++++++---
 src/target/llvm/codegen_llvm.h             | 38 ++++--
 src/target/llvm/codegen_nvptx.cc           | 24 ++--
 src/target/llvm/codegen_x86_64.cc          |  5 +-
 src/target/llvm/llvm_module.cc             | 47 +++++--
 src/target/opt/build_cuda_on.cc            | 16 ++-
 src/target/source/codegen_aocl.cc          | 20 ++-
 src/target/source/codegen_c.cc             | 66 +++++++---
 src/target/source/codegen_c.h              | 38 ++++--
 src/target/source/codegen_c_host.cc        | 70 +++--------
 src/target/source/codegen_c_host.h         |  4 +-
 src/target/source/codegen_cuda.cc          | 10 +-
 src/target/source/codegen_cuda.h           |  2 +-
 src/target/source/codegen_metal.cc         | 71 +++++++----
 src/target/source/codegen_metal.h          |  7 +-
 src/target/source/codegen_opencl.cc        | 54 +++-----
 src/target/source/codegen_opencl.h         |  7 +-
 src/target/source/codegen_opengl.cc        | 43 ++++---
 src/target/source/codegen_opengl.h         |  4 +-
 src/target/source/codegen_vhls.cc          | 41 ++++--
 src/target/source/codegen_vhls.h           | 11 +-
 src/target/spirv/build_vulkan.cc           | 26 +++-
 src/target/spirv/codegen_spirv.cc          | 23 ++--
 src/target/spirv/codegen_spirv.h           |  2 +-
 src/target/stackvm/codegen_stackvm.cc      | 38 ++++--
 src/target/stackvm/codegen_stackvm.h       |  2 +-
 src/te/operation/cross_thread_reduction.cc |  4 +-
 src/te/operation/extern_op.cc              |  5 +-
 src/te/operation/hybrid_op.cc              |  3 +-
 src/te/operation/op_util.cc                |  2 +-
 src/te/operation/scan_op.cc                |  4 +-
 src/te/schedule/schedule_ops.cc            | 14 +--
 src/tir/ir/expr.cc                         | 50 ++++++--
 src/tir/ir/lowered_func.cc                 |  2 +
 src/tir/ir/op.cc                           | 24 +++-
 src/tir/pass/simple_passes.cc              | 26 ++++
 src/tir/pass/storage_rewrite.cc            | 41 ++++++
 tests/python/relay/test_json_compact.py    | 24 ++++
 tests/python/unittest/test_tir_nodes.py    | 11 ++
 60 files changed, 995 insertions(+), 406 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index c8b1a3fd2eab2..85b3937509138 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include <type_traits>

 namespace tvm {

@@ -307,6 +308,17 @@ class Integer : public IntImm {
   * \param other The other value.
   */
  Integer(IntImm other) : IntImm(std::move(other)) {}  // NOLINT(*)
+  /*!
+ * \brief Constructor from enum + * \tparam Enum The enum type. + * \param value The enum value. + */ + template::value>::type> + explicit Integer(ENum value) : Integer(static_cast(value)) { + static_assert(std::is_same::type>::value, + "declare enum to be enum int to use visitor"); + } /*! * \brief Assign an expression to integer. * \param other another expression. diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index db7f4465f2a3a..ecf7c1978d07a 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -213,6 +213,27 @@ constexpr const char* kCallingConv = "calling_conv"; * \sa tvm::Target */ constexpr const char* kTarget = "target"; + +/*! + * \brief Global linker symbol of the function in generated code. + * + * This option forces the code generator to name the + * function with the given. + * + * For example, we could set a global_symbol of a function + * early to make sure that we can always refer to it by + * the symbol name in the generated DLL. + * + * We should not set the attribute for local functions, + * so that the compiler can freely rename them. + * + * A unique global symbol will be automatically assigned + * to each function in the module before the target code + * generation phase. + * + * Type: String + */ +constexpr const char* kGlobalSymbol = "global_symbol"; } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index a9475a1e4e91a..c23626e4de7f6 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -114,7 +114,8 @@ class PrimTypeNode : public TypeNode { TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; -/*! + +/* * \brief Managed reference to PrimTypeNode. * \sa PrimTypeNode */ @@ -124,11 +125,53 @@ class PrimType : public Type { * \brief Constructor * \param dtype The corresponding dtype. */ - TVM_DLL PrimType(runtime::DataType dtype); + TVM_DLL explicit PrimType(runtime::DataType dtype); TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; + +/*! + * \brief Low-level raw pointer type. + * + * PointerType represents type hints in the TIR to be + * passed to the final code generator. + * + * PointerType should not occur in the high-level analysis. + * + * \sa PointerType + */ +class PointerTypeNode : public TypeNode { + public: + /*! + * \brief The type of the element which the pointer points to. + */ + Type element_type; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("element_type", &element_type); + } + + static constexpr const char* _type_key = "PointerType"; + TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); +}; + +/* + * \brief Managed reference to PointerTypeNode. + * \sa PointerTypeNode + */ +class PointerType : public Type { + public: + /*! + * \brief Constructor + * \param element_type The type of the element which the pointer points to. + */ + TVM_DLL explicit PointerType(Type element_type); + + TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); +}; + + /*! \brief Possible kinds of TypeVars. */ enum TypeKind : int { kType = 0, @@ -283,6 +326,15 @@ inline Type VoidType() { return TupleType::Empty(); } +/*! + * \brief Check whether the tyep represents void. + * \return The check result. + */ +inline bool IsVoidType(const Type& type) { + auto* n = type.as(); + return n && n->fields.size() == 0; +} + /*! * \brief Potential Constraints in a function. 
* \sa TypeConstraint diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 7d497890dd137..9efc5d433230e 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -55,22 +55,27 @@ namespace tir { */ class VarNode : public PrimExprNode { public: - /*! \brief constructor */ - VarNode() {} - VarNode(DataType dtype, std::string name_hint); - /*! * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ std::string name_hint; + /*! + * \brief type annotaion of the variable. + * + * It is an optional field that provides a refined type of the variable than dtype. + * + * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. + */ + Type type_annotation; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("name", &name_hint); + v->Visit("type_annotation", &type_annotation); } - static constexpr const char* _type_key = "Variable"; + static constexpr const char* _type_key = "tir.Var"; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); }; @@ -78,20 +83,25 @@ class VarNode : public PrimExprNode { class Var : public PrimExpr { public: explicit Var(ObjectPtr n) : PrimExpr(n) {} - /*! \brief constructor + /*! + * \brief Constructor * \param name_hint variable name - * \param t data type + * \param dtype data type */ TVM_DLL explicit Var(std::string name_hint = "v", - DataType t = DataType::Int(32)); + DataType dtype = DataType::Int(32)); + /*! + * \brief Constructor which provides a more detailed type annotation. + * \param name_hint variable name. + * \param type_annotation The type annotation. + */ + TVM_DLL explicit Var(std::string name_hint, Type type_annotation); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ - Var copy_with_suffix(const std::string& suffix) const { - return Var((*this)->name_hint + suffix, (*this)->dtype); - } + TVM_DLL Var copy_with_suffix(const std::string& suffix) const; /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. @@ -116,15 +126,7 @@ class Var : public PrimExpr { */ class SizeVarNode : public VarNode { public: - /*! \brief constructor */ - SizeVarNode() {} - /*! \brief constructor - * \param dtype data type - * \param name_hint variable name - */ - SizeVarNode(DataType dtype, std::string name_hint); - - static constexpr const char* _type_key = "SizeVar"; + static constexpr const char* _type_key = "tir.SizeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); }; @@ -132,12 +134,13 @@ class SizeVarNode : public VarNode { class SizeVar : public Var { public: explicit SizeVar(ObjectPtr n) : Var(n) {} - /*! \brief constructor + /*! + * \brief constructor * \param name_hint variable name * \param t data type */ TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); + DataType t = DataType::Int(32)); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 06802671db91f..63a8630a92124 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -171,6 +171,16 @@ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; * Type: Integer */ constexpr const char* kNoAlias = "tir.noalias"; + +/*! + * \brief Mark the function as the entry function of + * the final generated runtime module. + * + * Type: Integer + * + * \note There can only be one entry function per module. 
+ */ +constexpr const char* kIsEntryFunc = "tir.is_entry_func"; } // namespace attr } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 3a8d62c956fe4..6e9a631fab4d5 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -515,6 +516,19 @@ LoweredFunc CombineContextCall(LoweredFunc f); */ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); + +/*! + * \brief Rewrite the pointer content type of arguments, + * as well as Alloc internal to the function to use + * the most frequently accessed type for load/store + * to avoid pointer casting in backend when possible. + * + * \note implemented in storage_rewrite.cc + * \param f The function to be transformed + * \return Transformed function. + */ +PrimFunc PointerValueTypeRewrite(PrimFunc f); + /*! * \brief Lower attached storage access information on device. * Do this pass after all storage access analysis finish. diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 6ee506350ba7f..afdc5fc9bf14e 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -52,10 +52,23 @@ namespace tvm { * This function could return a more refined type than * the runtime type provided by expr->dtype + * + * \param expr The input parameter. + * \return The result type. + * * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. */ TVM_DLL Type GetType(const PrimExpr& expr); +/*! + * \brief Get the implied DataType for storing values with type during runtime. + * + * \param type The input type. + * \return The result runtime::DataType. + * + * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + */ +TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. 
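
Taken together, the PrimType/PointerType nodes, the Var type_annotation field, and the GetType/GetRuntimeDataType helpers added above let TIR carry a refined type alongside the runtime dtype. The following is a minimal sketch of the intended round-trip, assuming only the declarations added in this patch; the function and variable names are illustrative and not part of the change, and the handle-dtype behavior of an annotated Var is an expectation rather than something shown in this excerpt.

#include <tvm/ir/type.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

namespace tvm {

void TypeAnnotationSketch() {
  // A handle parameter annotated as "pointer to float32" via the new type nodes.
  tir::Var buf("buf", PointerType(PrimType(runtime::DataType::Float(32))));
  // The runtime dtype of such a Var is expected to stay a plain handle ...
  CHECK(buf.dtype().is_handle());
  // ... while GetType recovers the refined pointer annotation.
  Type annotated = GetType(buf);
  // GetRuntimeDataType maps a PrimType back to the dtype used for storage.
  runtime::DataType f32 = GetRuntimeDataType(PrimType(runtime::DataType::Float(32)));
  (void)annotated;
  (void)f32;
}

}  // namespace tvm

The code generators changed below (CodeGenLLVM::GetLLVMType, CodeGenC::PrintType) rely on exactly this annotation to recover the pointee type, instead of the old per-function handle_data_type map.
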
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 8418d63aed2c0..1e114469d9868 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -17,7 +17,7 @@ # pylint: disable=unused-import """Common data structures across all IR variants.""" from .base import SourceName, Span, Node, EnvFunc, load_json, save_json -from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType +from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 10ecbaa161c5f..8878108b598e9 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -72,7 +72,15 @@ def _convert(item, _): return item return _convert + def _update_tir_var(new_name): + def _convert(item, nodes): + item["type_key"] = new_name + item["attrs"]["type_annotation"] = "0" + return item + return _convert + node_map = { + # Base IR "relay.TypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var, "relay.Type": _rename("Type"), @@ -91,6 +99,9 @@ def _convert(item, _): "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequantial": _rename("transform.Sequantial"), + # TIR + "Variable": _update_tir_var("tir.Var"), + "SizeVar": _update_tir_var("tir.SizeVar"), } return create_updater(node_map, "0.6", "0.7") diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index ebbb629911fe9..a61c6e4d58d16 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -46,6 +46,7 @@ class TypeKind(IntEnum): TypeData = 6 +@tvm._ffi.register_object("PrimType") class PrimType(Type): """Primitive data type in the low level IR @@ -59,6 +60,20 @@ def __init__(self, dtype): _ffi_api.PrimType, dtype) +@tvm._ffi.register_object("PointerType") +class PointerType(Type): + """PointerType used in the low-level TIR. + + Parameters + ---------- + element_type : tvm.ir.Type + The type of pointer's element. + """ + def __init__(self, element_type): + self.__init_handle_by_constructor__( + _ffi_api.PointerType, element_type) + + @tvm._ffi.register_object("TypeVar") class TypeVar(Type): """Type parameter in functions. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index bcf596787cd43..deb8d3446fc1e 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -288,7 +288,7 @@ class CmpExpr(PrimExprWithOp): class LogicalExpr(PrimExprWithOp): pass -@tvm._ffi.register_object("Variable") +@tvm._ffi.register_object("tir.Var") class Var(PrimExprWithOp): """Symbolic variable. @@ -297,7 +297,7 @@ class Var(PrimExprWithOp): name : str The name - dtype : str + dtype : Union[str, tvm.irType] The data type """ def __init__(self, name, dtype): @@ -305,7 +305,7 @@ def __init__(self, name, dtype): _ffi_api.Var, name, dtype) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.SizeVar") class SizeVar(Var): """Symbolic variable to represent a tensor index size which is greater or equal to zero. 
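
The remaining changes rewire code generation around IRModule. The core of that flow lives in src/target/codegen.cc further down: legacy LoweredFuncs are converted into PrimFuncs whose metadata becomes function attributes, and the per-target build functions are re-registered under "target.build.<name>". Below is a condensed sketch of that flow; BuildSketch is an illustrative name, and the ToIRModule/BuildForIRModule helpers from this patch are assumed to be visible at the call site.

#include <string>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/tir/lowered_func.h>

namespace tvm {
namespace codegen {

runtime::Module BuildSketch(const Array<tir::LoweredFunc>& funcs,
                            const std::string& target) {
  // 1. Convert every legacy LoweredFunc into a tir::PrimFunc: the function
  //    name becomes the global_symbol attribute, thread axes become
  //    tir.device_thread_axis, the restrict flag becomes tir.noalias, and the
  //    first function is tagged tir.is_entry_func. The results are packed
  //    into an IRModule keyed by GlobalVars.
  IRModule mod = ToIRModule(funcs);
  // 2. Dispatch to the per-target build function, now registered as
  //    "target.build.<target_name>" instead of "codegen.build_<name>".
  return BuildForIRModule(mod, Target::Create(target));
}

}  // namespace codegen
}  // namespace tvm
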
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 4eecabdb6d8cb..bda70fb67cba0 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -68,7 +68,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { /* TODO: Thread extent unitest not generated.*/ void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); const VarNode* var = thread_axis->var.get(); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 32c732c21740e..6e653cec3c3bc 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -92,8 +92,8 @@ VisitStmt_(const IfThenElseNode* op) { Stmt IRMutatorWithAnalyzer:: VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == tir::attr::thread_extent || + op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0802845bc4a1e..774c47666b17b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -40,7 +40,7 @@ using runtime::PackedFunc; using tir::LoweredFunc; bool LLVMEnabled() { - const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm"); + const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); return pf != nullptr; } diff --git a/src/ir/type.cc b/src/ir/type.cc index b8f0a5c47a568..5b038218c1272 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -45,6 +45,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); +PointerType::PointerType(Type element_type) { + ObjectPtr n = make_object(); + n->element_type = std::move(element_type); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PointerTypeNode); + +TVM_REGISTER_GLOBAL("ir.PointerType") +.set_body_typed([](Type element_type) { + return PointerType(element_type); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->element_type); + p->stream << '*'; +}); + + TypeVar::TypeVar(std::string name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 81efbfcb90f13..ab2cb67cfa4d1 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -139,7 +139,7 @@ bool RuntimeEnabled(const std::string& target) { } else if (target == "vulkan") { f_name = "device_api.vulkan"; } else if (target == "stackvm") { - f_name = "codegen.build_stackvm"; + f_name = "target.build.stackvm"; } else if (target == "rpc") { f_name = "device_api.rpc"; } else if (target == "micro_dev") { diff --git a/src/target/build_common.h b/src/target/build_common.h index 5565f40430617..47ec8f032c404 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -26,6 +26,9 @@ #include #include +#include +#include +#include #include #include #include @@ -51,6 +54,31 @@ ExtractFuncInfo(const Array& funcs) { } return fmap; } + +inline std::unordered_map +ExtractFuncInfo(const IRModule& mod) { + std::unordered_map fmap; + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + + runtime::FunctionInfo 
info; + for (size_t i = 0; i < f->params.size(); ++i) { + info.arg_types.push_back(f->params[i].dtype()); + } + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + if (thread_axis.defined()) { + for (size_t i = 0; i < thread_axis.size(); ++i) { + info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); + } + } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + fmap[static_cast(global_symbol)] = info; + } + return fmap; +} } // namespace codegen } // namespace tvm #endif // TVM_TARGET_BUILD_COMMON_H_ diff --git a/src/target/codegen.cc b/src/target/codegen.cc index ee5e6a62b6462..7dc23b6f9fd1c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -23,7 +23,12 @@ */ #include #include + +#include #include +#include + +#include #include #include #include @@ -37,6 +42,63 @@ namespace tvm { namespace codegen { +// The new build function. +// adapt the old function to the new one +runtime::Module BuildForIRModule(const IRModule& module, + const Target& target) { + std::string build_f_name = "target.build." + target->target_name; + // the build function. + const PackedFunc* bf = runtime::Registry::Get(build_f_name); + CHECK(bf != nullptr) + << "target.build." << target << " is not enabled"; + return (*bf)(module, target->str()); +} + + + +// convert legacy LoweredFunc to PrimFunc. +tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { + // remap args to attach type annotations. + Array args; + Map remap_vars; + + for (auto var : from->args) { + if (from->handle_data_type.count(var)) { + tir::Var new_var(var->name_hint, + PointerType(PrimType(var->dtype))); + args.push_back(new_var); + remap_vars.Set(var, new_var); + } else { + args.push_back(var); + } + } + tir::PrimFunc func(args, Substitute(from->body, remap_vars)); + + func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name)); + func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis); + if (from->func_type == tir::LoweredFuncType::kDeviceFunc) { + func = WithAttr(std::move(func), + attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); + } + if (from->is_restricted) { + func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1)); + } + return func; +} + +IRModule ToIRModule(const Array& funcs) { + Map functions; + for (size_t i = 0; i < funcs.size(); ++i) { + auto f = funcs[i]; + tir::PrimFunc pf = ToPrimFunc(f); + if (i == 0) { + pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1)); + } + functions.Set(GlobalVar(f->name), pf); + } + return IRModule(functions); +} + runtime::Module Build(const Array& funcs, const std::string& target) { std::string mode = target; @@ -51,15 +113,10 @@ runtime::Module Build(const Array& funcs, transformed_funcs.push_back(func); } } - std::string build_f_name = "codegen.build_" + mode; - // the build function. - const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) - << "Target " << target << " is not enabled"; - runtime::Module m = transformed_funcs.empty() ? - (*bf)(funcs, target) : - (*bf)(transformed_funcs, target); - return m; + + return BuildForIRModule( + transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs), + Target::Create(target)); } /*! 
\brief Helper class to serialize module */ diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 961ff418e2a53..3d1654c24e4f5 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -59,7 +59,7 @@ static inline int DetectROCMmaxThreadsPerBlock() { // AMDGPU code generator. class CodeGenAMDGPU : public CodeGenLLVM { public: - void AddFunction(const LoweredFunc& f) final { + void AddFunction(const PrimFunc& f) final { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); @@ -91,7 +91,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->dtype), ConstInt32(constant_size)); + DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -106,7 +106,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get( + DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable *global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", @@ -120,7 +121,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } } buf = builder_->CreatePointerCast( - buf, LLVMType(op->dtype)->getPointerTo( + buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; @@ -170,7 +171,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Additional optimization hook to tweak the builder. 
} - unsigned GetGlobalAddressSpace() { + unsigned GetGlobalAddressSpace() const final { return 1; } @@ -205,7 +206,7 @@ inline int DetectROCMComputeVersion(const std::string& target) { return 900; } -runtime::Module BuildAMDGPU(Array funcs, std::string target) { +runtime::Module BuildAMDGPU(IRModule mod, std::string target) { #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see @@ -222,8 +223,13 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr cg(new CodeGenAMDGPU()); std::unique_ptr ctx(new llvm::LLVMContext()); - cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false); - for (LoweredFunc f : funcs) { + + cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false); + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); cg->AddFunction(f); } @@ -306,10 +312,10 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { std::string hsaco = (*f)(arr); std::string ll(data_ll.begin(), data_ll.end()); - return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly); + return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("codegen.build_rocm") +TVM_REGISTER_GLOBAL("target.build.rocm") .set_body_typed(BuildAMDGPU); } // namespace codegen diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 88ca6b6da4990..ba3dee73dbe59 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -122,11 +122,15 @@ void CodeGenCPU::Init(const std::string& module_name, this->InitGlobalContext(dynamic_lookup); } -void CodeGenCPU::AddFunction(const LoweredFunc& f) { +void CodeGenCPU::AddFunction(const PrimFunc& f) { CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_))); + std::make_pair(global_symbol.operator std::string(), + builder_->CreatePointerCast(function_, t_void_p_))); } AddDebugInformation(function_); } @@ -328,7 +332,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { arg_types.push_back(v->getType()); } llvm::FunctionType* ftype = llvm::FunctionType::get( - LLVMType(op->dtype), arg_types, false); + GetLLVMType(GetRef(op)), arg_types, false); // Check if it is available in global function table as injected function. 
auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { @@ -693,8 +697,8 @@ CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, ret_value, *ret_tcode})); DataType r_api_type = tir::APIType(r_type); *rvalue = builder_->CreateAlignedLoad( - builder_->CreatePointerCast(ret_value, - LLVMType(r_api_type)->getPointerTo()), + builder_->CreatePointerCast( + ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), 8); *rvalue = CreateCast(r_api_type, r_type, *rvalue); return end_block; @@ -873,7 +877,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { this->CreateStaticInit(op->value.as()->value, op->body); } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); - } else if (attr::IsPragmaKey(op->attr_key)) { + } else if (tir::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { CHECK(parallel_env_.penv != nullptr) << "Pragma parallel_stride_pattern only valid in parallel launch"; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 5838735711468..aa8371c39a5ca 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -42,7 +42,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) override; - void AddFunction(const LoweredFunc& f) override; + void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; void VisitStmt_(const AssertStmtNode* op) override; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9fe6cd8cc73da..68d004cc481b7 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -24,6 +24,7 @@ // Part of the code are adapted from Halide's CodeGen_LLVM #include #include +#include #include @@ -94,7 +95,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::AddFunction(const LoweredFunc& f) { +void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } @@ -107,41 +108,43 @@ void CodeGenLLVM::InitFuncState() { } -void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { +void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { this->InitFuncState(); - std::vector arg_types; - is_restricted_ = f->is_restricted; - for (Var arg : f->args) { - DataType t = arg.dtype(); - if (t.is_handle()) { - auto it = f->handle_data_type.find(arg); - if (it != f->handle_data_type.end()) { - arg_types.push_back(LLVMType((*it).second.dtype()) - ->getPointerTo(GetGlobalAddressSpace())); - } else { - arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace())); - } - if (!is_restricted_) { - alias_var_set_.insert(arg.get()); - } - } else { - arg_types.push_back(LLVMType(arg.dtype())); + + CHECK_EQ(f->buffer_map.size(), 0U) + << "Cannot codegen function with buffer_map, please lower them first"; + + std::vector param_types; + is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias); + for (Var param : f->params) { + param_types.push_back(GetLLVMType(param)); + if (!is_restricted_ && param.dtype().is_handle()) { + alias_var_set_.insert(param.get()); } } + // TODO(tvm-team): + // Update the function type to respect the ret_type field of f. + // Once we allow more flexibility in the PrimFunc. llvm::FunctionType* ftype = llvm::FunctionType::get( - ret_void ? 
t_void_ : t_int_, arg_types, false); - CHECK(module_->getFunction(f->name) == nullptr) - << "Function " << f->name << " already exist in module"; + ret_void ? t_void_ : t_int_, param_types, false); + + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; + CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) + << "Function " << global_symbol << " already exist in module"; + function_ = llvm::Function::Create( ftype, llvm::Function::ExternalLinkage, - f->name, module_.get()); + global_symbol.operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + // set var map and align information auto arg_it = function_->arg_begin(); - for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) { + for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) { llvm::Argument* v = &(*arg_it); - const Var& var = f->args[i]; + const Var& var = f->params[i]; var_map_[var.get()] = v; if (is_restricted_) { if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) { @@ -157,6 +160,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(entry); this->VisitStmt(f->body); + if (ret_void) { builder_->CreateRetVoid(); } else { @@ -295,33 +299,51 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co return native_vector_bits_; } -unsigned CodeGenLLVM::GetGlobalAddressSpace() { +unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } -llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const { - if (t.is_handle()) { - CHECK_EQ(t.lanes(), 1); +llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { + if (dtype.is_handle()) { + CHECK_EQ(dtype.lanes(), 1); return t_void_p_; } llvm::Type* etype = nullptr; - if (t.is_int() || t.is_uint()) { - etype = llvm::Type::getIntNTy(*ctx_, t.bits()); - } else if (t.is_float()) { - switch (t.bits()) { + if (dtype.is_int() || dtype.is_uint()) { + etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); + } else if (dtype.is_float()) { + switch (dtype.bits()) { case 16: etype = llvm::Type::getHalfTy(*ctx_); break; case 32: etype = llvm::Type::getFloatTy(*ctx_); break; case 64: etype = llvm::Type::getDoubleTy(*ctx_); break; - default: LOG(FATAL) << "do not support " << t; + default: LOG(FATAL) << "do not support " << dtype; } } - if (t.lanes() != 1) { - return llvm::VectorType::get(etype, t.lanes()); + if (dtype.lanes() != 1) { + return llvm::VectorType::get(etype, dtype.lanes()); } else { return etype; } } +llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { + if (auto* ptr = type.as()) { + return DTypeToLLVMType(ptr->dtype); + } else if (auto* ptr = type.as()) { + // TODO(tvm-team) consider put storage scope into the pointer type. 
+ return GetLLVMType(ptr->element_type)->getPointerTo(GetGlobalAddressSpace()); + } else if (IsVoidType(type)) { + return t_void_; + } else { + LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type"; + return t_void_; + } +} + +llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { + return GetLLVMType(GetType(expr)); +} + // Add tbaa alias information for load // // use a binary tree typed system to declare information @@ -471,7 +493,8 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - llvm::Value* mask = llvm::UndefValue::get(LLVMType(DataType::Int(32, target_lanes))); + llvm::Value* mask = llvm::UndefValue::get( + DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = static_cast(vec->getType()->getVectorNumElements()); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); @@ -552,16 +575,16 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { - llvm::Type * target = LLVMType(to); + llvm::Type * target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); } else if (to.is_uint() && to.bits() == 1) { if (from.is_float()) { - llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.); + llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.); return builder_->CreateFCmpONE(value, zero); } else { - llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0); + llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0); return builder_->CreateICmpNE(value, zero); } } else if (!from.is_float() && !to.is_float()) { @@ -570,7 +593,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va return builder_->CreateFPToSI(value, target); } else if (from.is_float() && to.is_uint()) { if (to.bits() < 8) { - value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8))); + value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8))); return builder_->CreateIntCast(value, target, false); } else { return builder_->CreateFPToUI(value, target); @@ -610,7 +633,7 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( CHECK_EQ(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); - llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace()); + llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); if (btype != ptype) { buffer = builder_->CreatePointerCast(buffer, ptype); } @@ -623,7 +646,8 @@ llvm::Value* CodeGenLLVM::CreateBufferVecPtr( CHECK_GT(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); - llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace()); + llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo( + btype->getAddressSpace()); if (btype != ptype) { buffer = builder_->CreatePointerCast(buffer, ptype); } @@ -644,7 +668,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { arg_type.push_back(arg_value.back()->getType()); } llvm::FunctionType* ftype = llvm::FunctionType::get( - LLVMType(op->dtype), arg_type, false); + GetLLVMType(GetRef(op)), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { f = llvm::Function::Create( @@ -669,7 +693,7 @@ llvm::Value* 
CodeGenLLVM::CreateIntrinsic(const CallNode* op) { sig_type.push_back(arg_value.back()->getType()); } } - llvm::Type *return_type = LLVMType(op->dtype); + llvm::Type *return_type = GetLLVMType(GetRef(op)); if (sig_type.size() > 0 && return_type != sig_type[0]) { sig_type.insert(sig_type.begin(), return_type); } @@ -722,7 +746,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; - return llvm::ConstantInt::get(LLVMType(op->dtype), val); + return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; @@ -748,7 +772,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->is_intrinsic(CallNode::reinterpret)) { - llvm::Type * target = LLVMType(op->dtype); + llvm::Type * target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); } else if (op->is_intrinsic(CallNode::isnan)) { // TODO(hgt312): set fast math flag @@ -802,11 +826,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { - return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); + return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { - return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); + return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { @@ -970,7 +994,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { CHECK_EQ(ramp->lanes, t.lanes()); llvm::Value* ptr = CreateBufferPtr( t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace)); + ptr = builder_->CreatePointerCast( + ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); AddAliasInfo(load, op->buffer_var.get(), op->index, t); return load; @@ -979,7 +1004,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } // scalarized load. 
int basic_align = t.bits() / 8; - llvm::Value* ret = llvm::UndefValue::get(LLVMType(t)); + llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); llvm::LoadInst* load = builder_->CreateAlignedLoad( @@ -1007,7 +1032,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { - llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype)); + llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), @@ -1066,7 +1091,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { CHECK_EQ(ramp->lanes, t.lanes()); llvm::Value* ptr = CreateBufferPtr( t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace)); + ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype()); return; @@ -1147,7 +1172,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->dtype), ConstInt32(constant_size)); + DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -1160,7 +1185,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { buf = alloca; } buf = builder_->CreatePointerCast( - buf, LLVMType(op->dtype)->getPointerTo( + buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; @@ -1168,7 +1193,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { } void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index de94bf3b2f881..6249aa4f74bc5 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -25,12 +25,17 @@ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #ifdef TVM_LLVM_VERSION +#include +#include #include #include #include #include +#include #include #include + + #include #include #include @@ -78,7 +83,7 @@ class CodeGenLLVM : * \brief Compile and add function f to the current module. * \param f The function to be added. */ - virtual void AddFunction(const LoweredFunc& f); + virtual void AddFunction(const PrimFunc& f); /*! * \brief Add main function as the entry name * \param entry_func_name The name of entry function to be added. @@ -167,7 +172,7 @@ class CodeGenLLVM : * \return The result. */ template - inline llvm::AllocaInst* WithFunctionEntry(F falloca) { + llvm::AllocaInst* WithFunctionEntry(F falloca) { llvm::BasicBlock* current = builder_->GetInsertBlock(); llvm::BasicBlock* entry = &(function_->getEntryBlock()); builder_->SetInsertPoint(entry, entry->begin()); @@ -198,18 +203,35 @@ class CodeGenLLVM : // Get the maximim storage align bits of buffer pointer given storage scope. 
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const; // Get correct address space depending on the backend - virtual unsigned GetGlobalAddressSpace(); - - void AddFunctionInternal(const LoweredFunc& f, bool ret_void); + virtual unsigned GetGlobalAddressSpace() const; + void AddFunctionInternal(const PrimFunc& f, bool ret_void); // Create extern call llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, const std::vector& value); /*! - * \param t The original type. - * \return LLVM type of t + * \brief Get the LLVM Type for a given runtime type. + * \param dtype The runtime dtype. + * + * \note Only use this function for dealing with PrimTypes. + * For Call and Var that could have more refined types, + * use GetLLVMType instead. + * + * \return LLVM type of dtype + */ + llvm::Type* DTypeToLLVMType(const DataType& dtype) const; + /*! + * \brief Get the LLVM Type for a given type. + * \param dtype The runtime dtype. + * \param type The corresponding TVM Type. + */ + llvm::Type* GetLLVMType(const Type& type) const; + /*! + * \brief Get the LLVM Type for a given type. + * \param dtype The runtime dtype. + * \param type The corresponding TVM Type. */ - llvm::Type* LLVMType(const DataType& t) const; + llvm::Type* GetLLVMType(const PrimExpr& expr) const; // initialize the function state. void InitFuncState(); // Get alignment given index. diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 821232ded1705..48c7968fb12cc 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -34,7 +34,7 @@ namespace codegen { // NVPTX code generator. class CodeGenNVPTX : public CodeGenLLVM { public: - void AddFunction(const LoweredFunc& f) final { + void AddFunction(const PrimFunc& f) final { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function @@ -68,7 +68,7 @@ class CodeGenNVPTX : public CodeGenLLVM { // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca( - LLVMType(op->dtype), ConstInt32(constant_size)); + DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 @@ -83,7 +83,8 @@ class CodeGenNVPTX : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get( + DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 llvm::GlobalVariable *global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", @@ -97,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } } buf = builder_->CreatePointerCast( - buf, LLVMType(op->dtype)->getPointerTo( + buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; @@ -190,7 +191,7 @@ inline int DetectCUDAComputeVersion() { } } -runtime::Module BuildNVPTX(Array funcs, std::string target) { +runtime::Module BuildNVPTX(IRModule mod, std::string target) { InitializeLLVM(); CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx"); @@ -202,8 +203,13 @@ runtime::Module BuildNVPTX(Array funcs, std::string target) { std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr cg(new CodeGenNVPTX()); std::unique_ptr ctx(new llvm::LLVMContext()); - cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false); - for (LoweredFunc f : funcs) { + + cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false); + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); cg->AddFunction(f); } @@ -249,10 +255,10 @@ runtime::Module BuildNVPTX(Array funcs, std::string target) { #endif pass.run(*module); std::string ptx(data_ptx.begin(), data_ptx.end()); - return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll); + return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("codegen.build_nvptx") +TVM_REGISTER_GLOBAL("target.build.nvptx") .set_body_typed(BuildNVPTX); } // namespace codegen diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 05467633ce5cd..cae4cf590736d 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -88,7 +88,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, - LLVMType(DataType::Float(32, from.lanes())), + DTypeToLLVMType(DataType::Float(32, from.lanes())), { MakeValue(tir::CallNode::make( DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, @@ -103,7 +103,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin( - ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), + ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, + DTypeToLLVMType(DataType::Float(32, from.lanes())), {MakeValue(tir::CallNode::make( DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, tir::CallNode::PureIntrinsic))}); diff --git a/src/target/llvm/llvm_module.cc 
b/src/target/llvm/llvm_module.cc index c04b257279723..6cf9f112444c3 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include "llvm_common.h" @@ -192,21 +193,39 @@ class LLVMModuleNode final : public runtime::ModuleNode { return ""; } - void Init(const Array& funcs, std::string target) { + void Init(const IRModule& mod, std::string target) { InitializeLLVM(); tm_ = GetLLVMTargetMachine(target); bool system_lib = (target.find("-system-lib") != std::string::npos); - CHECK_NE(funcs.size(), 0U); ctx_ = std::make_shared(); std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); - entry_func_ = funcs[0]->name; - cg->Init(funcs[0]->name, tm_.get(), ctx_.get(), system_lib, system_lib); - for (LoweredFunc f : funcs) { + + std::vector funcs; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; + auto f = Downcast(kv.second); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()); + entry_func_ = global_symbol; + } + funcs.push_back(f); + } + CHECK_NE(funcs.size(), 0U); + // TODO(tqchen): remove the entry function behavior as it does not + // makes sense when we start to use multiple modules. + cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib); + + for (const auto& f : funcs) { cg->AddFunction(f); } - cg->AddMainFunction(funcs[0]->name); - module_ = cg->Finish(); + if (entry_func_.length() != 0) { + cg->AddMainFunction(entry_func_); + } + + module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, target)); module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); @@ -349,12 +368,14 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } -TVM_REGISTER_GLOBAL("codegen.build_llvm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->Init(args[0].operator Array(), args[1].operator std::string()); - *rv = runtime::Module(n); - }); + +TVM_REGISTER_GLOBAL("target.build.llvm") +.set_body_typed([](IRModule mod, std::string target) { + auto n = make_object(); + n->Init(mod, target); + return runtime::Module(n); +}); + TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 67aa09344c072..634fb9a57f27a 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -127,15 +127,23 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { return ptx; } -runtime::Module BuildCUDA(Array funcs) { +runtime::Module BuildCUDA(IRModule mod) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenCUDA cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenCUDA: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { @@ -151,10 +159,10 @@ runtime::Module BuildCUDA(Array 
funcs) { } else { ptx = NVRTCCompile(code, cg.need_include_path()); } - return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code); + return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("codegen.build_cuda") +TVM_REGISTER_GLOBAL("target.build.cuda") .set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 876b1002e05e5..c6011cd4dc87f 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -31,16 +31,26 @@ namespace tvm { namespace codegen { -runtime::Module BuildAOCL(Array funcs, std::string target_str, +runtime::Module BuildAOCL(IRModule mod, + std::string target_str, bool emulation) { // Get code. using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodegenOpenCL: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) { code = (*f)(code).operator std::string(); @@ -68,15 +78,15 @@ runtime::Module BuildAOCL(Array funcs, std::string target_str, std::string aocxbin; runtime::LoadBinaryFromFile("aocl.aocx", &aocxbin); - return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(funcs), code); + return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("codegen.build_aocl") +TVM_REGISTER_GLOBAL("target.build.aocl") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildAOCL(args[0], args[1], false); }); -TVM_REGISTER_GLOBAL("codegen.build_aocl_sw_emu") +TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildAOCL(args[0], args[1], true); }); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 91020555e5c8a..0cb47427ff1c1 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -35,7 +35,7 @@ void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } -void CodeGenC::InitFuncState(LoweredFunc f) { +void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); handle_data_type_.clear(); CodeGenSourceBase::ClearFuncState(); @@ -72,39 +72,46 @@ void CodeGenC::ReserveKeywordsAsUnique() { GetUniqueName("return"); } -void CodeGenC::AddFunction(LoweredFunc f) { +void CodeGenC::AddFunction(const PrimFunc& f) { // clear previous generated state. this->InitFuncState(f); // reserve keywords ReserveKeywordsAsUnique(); - // add to alloc buffer type. 
- for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.dtype()); - } - this->stream << "void " << f->name << "("; - for (size_t i = 0; i < f->args.size(); ++i) { - Var v = f->args[i]; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + + this->PrintFuncPrefix(); + this->stream << " " << static_cast(global_symbol) << "("; + + for (size_t i = 0; i < f->params.size(); ++i) { + tir::Var v = f->params[i]; std::string vid = AllocVarID(v.get()); if (i != 0) stream << ", "; if (v.dtype().is_handle()) { auto it = alloc_storage_scope_.find(v.get()); - if (it != alloc_storage_scope_.end()) + if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); - stream << ' '; + stream << ' '; + } - if (handle_data_type_.count(v.get())) { - PrintType(handle_data_type_.at(v.get()), stream); - } else { - stream << "void"; + PrintType(GetType(v), stream); + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + if (auto* ptr = v->type_annotation.as()) { + if (auto* prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } } - stream << "*"; - if (f->is_restricted && restrict_keyword_.length() != 0) { + if (no_alias && restrict_keyword_.length() != 0) { stream << ' ' << restrict_keyword_; } } else { - PrintType(v.dtype(), stream); + PrintType(GetType(v), stream); } stream << ' ' << vid; } @@ -112,11 +119,19 @@ void CodeGenC::AddFunction(LoweredFunc f) { this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); + this->PrintFinalReturn(); this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n\n"; } +void CodeGenC::PrintFuncPrefix() { + stream << "void"; +} + +void CodeGenC::PrintFinalReturn() { +} + std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } @@ -275,7 +290,6 @@ std::string CodeGenC::GetStructRef( } } - bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) return false; @@ -370,6 +384,20 @@ void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } +void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) + if (auto* ptr = type.as()) { + return PrintType(ptr->dtype, os); + } else if (auto* ptr = type.as()) { + PrintType(ptr->element_type, os); + os << '*'; + } else if (IsVoidType(type)) { + os << "void"; + } else { + LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + } +} + + inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->dtype == DataType::Int(32)) { std::ostringstream temp; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index c6da1c4ceb9f7..a9da780876122 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -26,9 +26,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -62,8 +64,9 @@ class CodeGenC : /*! * \brief Add the function to the generated module. * \param f The function to be compiled. + * \param whether to append return 0 in the end. */ - void AddFunction(LoweredFunc f); + void AddFunction(const PrimFunc& f); /*! * \brief Finalize the compilation and return the code. * \return The code. 
@@ -92,16 +95,26 @@ class CodeGenC : return os.str(); } // The following parts are overloadable print operations. + /*! + * \brief Print the function header before the argument list + * + * Example: stream << "void"; + */ + virtual void PrintFuncPrefix(); // NOLINT(*) + /*! + * \brief Print the final return at the end of the function. + */ + virtual void PrintFinalReturn(); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. */ - virtual void PreFunctionBody(LoweredFunc f) {} + virtual void PreFunctionBody(const PrimFunc& f) {} /*! * \brief Initialize codegen state for generating f. * \param f The function to be compiled. */ - virtual void InitFuncState(LoweredFunc f); + virtual void InitFuncState(const PrimFunc& f); // expression void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) @@ -148,6 +161,12 @@ class CodeGenC : * \param os The stream to print the ctype into */ virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + /*! + * Print the Type representation of type. + * \param type The type representation. + * \param os The stream to print the ctype into + */ + virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; @@ -223,12 +242,6 @@ class CodeGenC : // override void PrintSSAAssign( const std::string& target, const std::string& src, DataType t) final; - /*! \brief restrict keyword */ - std::string restrict_keyword_{""}; - /*! \brief the storage scope of allocation */ - std::unordered_map alloc_storage_scope_; - /*! \brief the data type of allocated buffers */ - std::unordered_map handle_data_type_; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); @@ -237,6 +250,13 @@ class CodeGenC : return volatile_buf_.count(buf_var) != 0; } + /*! \brief restrict keyword */ + std::string restrict_keyword_{""}; + /*! \brief the storage scope of allocation */ + std::unordered_map alloc_storage_scope_; + /*! \brief the data type of allocated buffers */ + std::unordered_map handle_data_type_; + private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 64783f6b31bae..cbdec62017425 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -41,59 +41,16 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { CodeGenC::Init(output_ssa); } -void CodeGenCHost::AddFunction(LoweredFunc f) { - // clear previous generated state. - this->InitFuncState(f); - // reserve keywords - ReserveKeywordsAsUnique(); - // add to alloc buffer type.
- for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.dtype()); - } - - this->stream << "#ifdef __cplusplus\n"; - this->stream << "extern \"C\"\n"; - this->stream << "#endif\n"; - this->stream << "TVM_DLL int32_t " << f->name << "("; - for (size_t i = 0; i < f->args.size(); ++i) { - Var v = f->args[i]; - std::string vid = AllocVarID(v.get()); - if (i != 0) stream << ", "; - if (v.dtype().is_handle()) { - auto it = alloc_storage_scope_.find(v.get()); - if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); - } - stream << ' '; - - if (handle_data_type_.count(v.get())) { - PrintType(handle_data_type_.at(v.get()), stream); - } else { - stream << "void"; - } - stream << "*"; - - if (f->is_restricted && restrict_keyword_.length() != 0) { - stream << ' ' << restrict_keyword_; - } - } else { - PrintType(v.dtype(), stream); - } - stream << ' ' << vid; - } - stream << ") {\n"; - this->PreFunctionBody(f); - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->PrintIndent(); - this->stream << "return 0;\n"; - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; +void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*) + stream << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n" + << "TVM_DLL int32_t"; } -std::string CodeGenCHost::Finish() { - return CodeGenC::Finish(); +void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) + this->PrintIndent(); + stream << "return 0;\n"; } void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) @@ -277,20 +234,25 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, << "? (" << a_id << ") : (" << b_id << "))"; } -runtime::Module BuildCHost(Array funcs) { +runtime::Module BuildCHost(IRModule mod) { using tvm::runtime::Registry; bool output_ssa = false; bool emit_asserts = false; CodeGenCHost cg; cg.Init(output_ssa, emit_asserts); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodegenCHost: Can only take PrimFunc"; + auto f = Downcast(kv.second); cg.AddFunction(f); } + std::string code = cg.Finish(); return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_GLOBAL("codegen.build_c") +TVM_REGISTER_GLOBAL("target.build.c") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildCHost(args[0]); }); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index a29730e8629a6..4f9a0a74511fc 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -36,10 +36,10 @@ class CodeGenCHost final : public CodeGenC { public: CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts); - void AddFunction(LoweredFunc f); - std::string Finish(); void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintFuncPrefix() final; // NOLINT(*) + void PrintFinalReturn() final; // NOLINT(*) // overload visitor functions void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 24f655b01c85c..2cc7b926287f8 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -43,9 +43,9 @@ void CodeGenCUDA::Init(bool output_ssa) { CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } -void CodeGenCUDA::AddFunction(LoweredFunc f) { - this->stream << "extern \"C\" __global__ "; - CodeGenC::AddFunction(f); + +void CodeGenCUDA::PrintFuncPrefix() { + stream << 
"extern \"C\" __global__ void"; } std::string CodeGenCUDA::Finish() { @@ -424,11 +424,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { } void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::fragment_shape) { + if (op->attr_key == tir::attr::fragment_shape) { const VarNode* buffer = op->node.as(); const StringImmNode* shape_str = op->value.as(); fragment_shapes[buffer] = shape_str->value; - } else if (op->attr_key == attr::fragment_layout) { + } else if (op->attr_key == tir::attr::fragment_layout) { const VarNode* buffer = op->node.as(); const StringImmNode* layout_str = op->value.as(); fragment_layouts[buffer] = layout_str->value; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index a634c107f966c..c31bdf5f2d595 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -37,12 +37,12 @@ class CodeGenCUDA final : public CodeGenC { public: CodeGenCUDA(); void Init(bool output_ssa); - void AddFunction(LoweredFunc f); std::string Finish(); bool need_include_path() { return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior + void PrintFuncPrefix() final; void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 1358a6a280389..2f31a3e3adf10 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -31,10 +31,10 @@ namespace tvm { namespace codegen { -void CodeGenMetal::InitFuncState(LoweredFunc f) { +void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); // analyze the data; - for (Var arg : f->args) { + for (Var arg : f->params) { if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } @@ -49,48 +49,55 @@ CodeGenMetal::CodeGenMetal() { << "};\n\n"; } -void CodeGenMetal::AddFunction(LoweredFunc f) { +void CodeGenMetal::AddFunction(const PrimFunc& f) { // clear previous generated state. this->InitFuncState(f); // skip the first underscore, so SSA variable starts from _1 GetUniqueName("_"); + // add to alloc buffer type. - for (const auto & kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.dtype()); - } + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + // Function header. 
- this->stream << "kernel void " << f->name << "(\n"; + this->stream << "kernel void " << static_cast(global_symbol) << "("; + // Buffer arguments size_t num_buffer = 0; - for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) { - Var v = f->args[i]; + for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { + Var v = f->params[i]; if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); - CHECK(it != alloc_storage_scope_.end()); - PrintStorageScope(it->second, stream); + if (it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, stream); + } stream << ' '; - if (handle_data_type_.count(v.get())) { - PrintType(handle_data_type_.at(v.get()), stream); - stream << "*"; - } else { - PrintType(v.dtype(), stream); + PrintType(GetType(v), stream); + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + if (auto* ptr = v->type_annotation.as()) { + if (auto* prim = ptr->element_type.as()) { + RegisterHandleType(v.get(), prim->dtype); + } } stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. - size_t nargs = f->args.size() - num_buffer; + size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = f->name + "_args_t"; + std::string arg_buf_type = static_cast(global_symbol) + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; - for (size_t i = num_buffer; i < f->args.size(); ++i) { - Var v = f->args[i]; + for (size_t i = num_buffer; i < f->params.size(); ++i) { + Var v = f->params[i]; CHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; @@ -113,7 +120,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - for (IterVar iv : f->thread_axis) { + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + CHECK(thread_axis.defined()); + + for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); work_dim = std::max(work_dim, scope.dim_index + 1); } @@ -127,7 +137,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } // bind thread axis - for (IterVar iv : f->thread_axis) { + for (IterVar iv : thread_axis) { CHECK(!var_idmap_.count(iv->var.get())); std::string vname = iv->thread_tag; if (work_dim <= 1) { @@ -257,14 +267,23 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } } -runtime::Module BuildMetal(Array funcs) { +runtime::Module BuildMetal(IRModule mod) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenMetal cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenMetal: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + std::string code = cg.Finish(); std::string fmt = "metal"; std::string source = ""; @@ 
-273,10 +292,10 @@ runtime::Module BuildMetal(Array funcs) { code = (*f)(code).operator std::string(); fmt = "metallib"; } - return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); + return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); } -TVM_REGISTER_GLOBAL("codegen.build_metal") +TVM_REGISTER_GLOBAL("target.build.metal") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildMetal(args[0]); }); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 291e95c0f2286..644c962ab2d68 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -34,10 +34,10 @@ namespace codegen { class CodeGenMetal final : public CodeGenC { public: CodeGenMetal(); - void AddFunction(LoweredFunc f); // override print thread tag. void PrintArgUnionDecl(); - void InitFuncState(LoweredFunc f) final; + void AddFunction(const PrimFunc& f); // NOLINT(*) + void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) @@ -50,9 +50,10 @@ class CodeGenMetal final : public CodeGenC { const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - // overload visitor void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + // reuse parent's function. + using CodeGenC::PrintType; private: int thread_index_bits_{32}; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 437ff6def4baa..67761c17680a8 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -35,18 +35,17 @@ CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; } -void CodeGenOpenCL::InitFuncState(LoweredFunc f) { +void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); - for (Var arg : f->args) { + for (Var arg : f->params) { if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } } -void CodeGenOpenCL::AddFunction(LoweredFunc f) { - this->stream << "__kernel "; - CodeGenC::AddFunction(f); +void CodeGenOpenCL::PrintFuncPrefix() { + stream << "__kernel void"; } std::string CodeGenOpenCL::Finish() { @@ -239,50 +238,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO } } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, - std::ostream& os, - CodeGenOpenCL* p) { - if (op->dtype.lanes() == 1) { - os << opstr << "(("; - p->PrintType(op->a->dtype, os); - os << ")"; - p->PrintExpr(op->a, os); - os << ", ("; - p->PrintType(op->b->dtype, os); - os << ")"; - p->PrintExpr(op->b, os); - os << ')'; - } else { - p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); - } -} - -void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) { - PrintBinaryExpr(op, "min", os, this); -} - -void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) { - PrintBinaryExpr(op, "max", os, this); -} - -runtime::Module BuildOpenCL(Array funcs) { +runtime::Module BuildOpenCL(IRModule mod) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenOpenCL: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto 
calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + std::string code = cg.Finish(); if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) { code = (*f)(code).operator std::string(); } - return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code); + return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("codegen.build_opencl") +TVM_REGISTER_GLOBAL("target.build.opencl") .set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 9f1c7f4c3044b..cc1fe994739f3 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -34,11 +34,11 @@ namespace codegen { class CodeGenOpenCL final : public CodeGenC { public: CodeGenOpenCL(); - void AddFunction(LoweredFunc f); std::string Finish(); // override print thread tag. - void InitFuncState(LoweredFunc f) final; + void InitFuncState(const PrimFunc& f) final; + void PrintFuncPrefix() final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) @@ -56,9 +56,6 @@ class CodeGenOpenCL final : public CodeGenC { // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) - // overload min and max to avoid ambiguous call errors - void VisitExpr_(const MinNode *op, std::ostream& os) final; - void VisitExpr_(const MaxNode *op, std::ostream& os) final; private: // whether enable fp16 and fp64 extension diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 49f464de8912d..474859977fcbb 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -37,7 +37,7 @@ namespace codegen { CodeGenOpenGL::CodeGenOpenGL() : output_(nullptr), output_iter_var_(nullptr) {} -void CodeGenOpenGL::InitFuncState(LoweredFunc f) { +void CodeGenOpenGL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); output_ = nullptr; inputs_.clear(); @@ -47,7 +47,7 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) { this->stream.str(""); } -void CodeGenOpenGL::AddFunction(LoweredFunc f) { +void CodeGenOpenGL::AddFunction(const PrimFunc& f) { // clear previous generated state. this->InitFuncState(f); @@ -56,15 +56,17 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { // skip the first underscore, so SSA variable starts from _1 GetUniqueName("_"); - // add to alloc buffer type. - for (const auto& kv : f->handle_data_type) { - RegisterHandleType(kv.first.get(), kv.second.dtype()); - } // Allocate argument names. Store in `var_idmap_`. - for (auto arg : f->args) { + for (auto arg : f->params) { auto arg_name = GetUniqueName(arg.get()->name_hint); var_idmap_[arg.get()] = arg_name; + + if (auto* ptr = arg->type_annotation.as()) { + if (auto* prim = ptr->element_type.as()) { + RegisterHandleType(arg.get(), prim->dtype); + } + } } thread_extent_var_ = GetUniqueName("thread_extent"); @@ -80,7 +82,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { this->stream << "}\n\n"; // Declare arguments. 
- for (auto arg : f->args) { + for (auto arg : f->params) { if (this->inputs_.find(arg.get()) != this->inputs_.cend()) { // Declare input texture. // Format: @@ -138,7 +140,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { std::vector arg_names; std::vector arg_kinds; - for (auto arg : f->args) { + for (auto arg : f->params) { std::string name = GetVarID(arg.get()); runtime::OpenGLArgKind kind; @@ -154,7 +156,11 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { arg_kinds.push_back(kind); } - shaders_[f->name] = runtime::OpenGLShader( + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; + + shaders_[static_cast(global_symbol)] = runtime::OpenGLShader( this->decl_stream.str() + this->stream.str(), std::move(arg_names), std::move(arg_kinds), this->thread_extent_var_); @@ -283,18 +289,27 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n"; } -runtime::Module BuildOpenGL(Array funcs) { +runtime::Module BuildOpenGL(IRModule mod) { bool output_ssa = false; CodeGenOpenGL cg; cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenOpenGL: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + auto shaders = cg.Finish(); - return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs)); + return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod)); } -TVM_REGISTER_GLOBAL("codegen.build_opengl") +TVM_REGISTER_GLOBAL("target.build.opengl") .set_body_typed(BuildOpenGL); } // namespace codegen diff --git a/src/target/source/codegen_opengl.h b/src/target/source/codegen_opengl.h index 0b29c28fe144f..954806bbca59a 100644 --- a/src/target/source/codegen_opengl.h +++ b/src/target/source/codegen_opengl.h @@ -37,10 +37,10 @@ namespace codegen { class CodeGenOpenGL final : public CodeGenC { public: CodeGenOpenGL(); - void AddFunction(LoweredFunc f); std::unordered_map Finish(); - void InitFuncState(LoweredFunc f) final; + void AddFunction(const PrimFunc& f); + void InitFuncState(const PrimFunc& f) final; void BindThreadIndex(const IterVar& iv) final; void VisitStmt_(const StoreNode* op) final; std::string TexelFetch(const VarNode* buffer, PrimExpr index); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 1e6e7601d02c0..6c1c3b9d22fc4 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -68,14 +68,13 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { } } -void CodeGenVivadoHLS::AddFunction(LoweredFunc f) { - this->stream << "extern \"C\" "; - CodeGenC::AddFunction(f); +void CodeGenVivadoHLS::PrintFuncPrefix() { + stream << "extern \"C\" void"; } -void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) { - for (size_t i = 0; i < f->args.size(); ++i) { - Var v = f->args[i]; +void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { + for (size_t i = 0; i < f->params.size(); ++i) { + Var v = f->params[i]; std::string vid = GetVarID(v.get()); if (v.dtype().is_handle()) { this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n"; @@ -126,21 
+125,34 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOL } -runtime::Module BuildSDAccel(Array funcs, std::string target_str) { +runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenVivadoHLS cg; // Generate source code for get_source(). cg.Init(output_ssa); - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenVHLS: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenVHLS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } + std::string whole_code = cg.Finish(); // Generate source code for compilation. Array > kernel_info; - for (LoweredFunc f : funcs) { + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenVHLS: Can only take PrimFunc"; + auto f = Downcast(kv.second); CodeGenVivadoHLS cg; cg.Init(output_ssa); cg.AddFunction(f); @@ -148,7 +160,12 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { code = (*f)(code).operator std::string(); } - kernel_info.push_back(Array({f->name, code})); + + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + std::string func_name = global_symbol; + kernel_info.push_back(Array({func_name, code})); } std::string xclbin; @@ -158,10 +175,10 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { } else { LOG(FATAL) << "Cannot compile Vivado HLS code."; } - return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code); + return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code); } -TVM_REGISTER_GLOBAL("codegen.build_sdaccel") +TVM_REGISTER_GLOBAL("target.build.sdaccel") .set_body_typed(BuildSDAccel); } // namespace codegen diff --git a/src/target/source/codegen_vhls.h b/src/target/source/codegen_vhls.h index fc14c93f3564d..10f9ea7679b6d 100644 --- a/src/target/source/codegen_vhls.h +++ b/src/target/source/codegen_vhls.h @@ -15,7 +15,7 @@ * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. - */ + */ /*!
* \file codegen_vhls.h @@ -37,10 +37,11 @@ class CodeGenVivadoHLS final : public CodeGenC { public: void Init(bool output_ssa); void PrintType(DataType t, std::ostream& os); - void AddFunction(LoweredFunc f); - void PreFunctionBody(LoweredFunc f); - void VisitExpr_(const MinNode *op, std::ostream& os); - void VisitExpr_(const MaxNode *op, std::ostream& os); + + void PrintFuncPrefix() final; + void PreFunctionBody(const PrimFunc& f) final; + void VisitExpr_(const MinNode *op, std::ostream& os) final; + void VisitExpr_(const MaxNode *op, std::ostream& os) final; }; } // namespace codegen diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index c90b4c7eeb483..b6f9b86fbdb35 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -70,7 +70,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(Array funcs) { +runtime::Module BuildSPIRV(IRModule mod) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -81,8 +81,21 @@ runtime::Module BuildSPIRV(Array funcs) { const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); CodeGenSPIRV cg; - for (LoweredFunc f : funcs) { - f = PointerValueTypeRewrite(f); + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenSPIRV: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; + + std::string f_name = global_symbol; + f = PointerValueTypeRewrite(std::move(f)); VulkanShader shader; shader.data = cg.BuildFunction(f); @@ -97,13 +110,14 @@ runtime::Module BuildSPIRV(Array funcs) { reinterpret_cast(dmlc::BeginPtr(shader.data))); } code_data << spirv_tools.BinaryToText(shader.data); - smap[f->name] = std::move(shader); + smap[f_name] = std::move(shader); } + return runtime::VulkanModuleCreate( - smap, ExtractFuncInfo(funcs), code_data.str()); + smap, ExtractFuncInfo(mod), code_data.str()); } -TVM_REGISTER_GLOBAL("codegen.build_vulkan") +TVM_REGISTER_GLOBAL("target.build.vulkan") .set_body_typed(BuildSPIRV); } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 4021b17d72431..0241e2218d71e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include "codegen_spirv.h" #include "../../arith/compute_expr.h" @@ -30,18 +31,20 @@ namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { +std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { this->InitFuncState(); - CHECK(f->is_restricted) + CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t num_buffer = 0; - for (Var arg : f->args) { + + for (Var arg : f->params) { DataType t = arg.dtype(); if (t.is_handle()) { - auto it = f->handle_data_type.find(arg); - if (it != f->handle_data_type.end()) { - DataType value_type = (*it).second.dtype(); + if (auto* ptr = arg->type_annotation.as()) { + auto* prim = ptr->element_type.as(); + CHECK(prim); + DataType value_type = prim->dtype; spirv::Value arg_value = 
builder_->BufferArgument( builder_->GetSType(value_type), 0, num_buffer); storage_info_[arg.get()].UpdateContentType(value_type); @@ -75,7 +78,11 @@ std::vector CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpFunctionEnd); - builder_->CommitKernelFunction(func_ptr, f->name); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; + + builder_->CommitKernelFunction(func_ptr, static_cast(global_symbol)); return builder_->Finalize(); } @@ -607,7 +614,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { } void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 3e970032090f4..a5ccd549633db 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -53,7 +53,7 @@ class CodeGenSPIRV: * \param f The function to be added. * \return The final spirv module. */ - virtual std::vector BuildFunction(const LoweredFunc& f); + virtual std::vector BuildFunction(const PrimFunc& f); /*! * \brief Create Value for expression e * \param e The expression to be created value for. diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index bce878a58abe8..af8b34142ec90 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -21,7 +21,10 @@ * \file codegen_stackvm.cc */ #include +#include +#include #include +#include #include #include #include "codegen_stackvm.h" @@ -54,9 +57,9 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) { return StackVM::kArrData; } -StackVM CodeGenStackVM::Compile(LoweredFunc f) { - for (size_t i = 0; i < f->args.size(); ++i) { - Var v = f->args[i]; +StackVM CodeGenStackVM::Compile(const PrimFunc& f) { - for (size_t i = 0; i < f->params.size(); ++i) { + Var v = f->params[i]; int vid = AllocVarID(v.get()); CHECK_EQ(static_cast(vid), i); } @@ -525,19 +528,32 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->body); } -runtime::Module BuildStackVM(const Array& funcs) { - CHECK_NE(funcs.size(), 0U); +runtime::Module BuildStackVM(const IRModule& mod) { std::unordered_map fmap; + std::string entry_func; + + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) + << "CodeGenStackVM: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; + std::string f_name = global_symbol; StackVM vm = codegen::CodeGenStackVM().Compile(f); - CHECK(!fmap.count(f->name)) - << "Function name " << f->name << "already exist in list"; - fmap[f->name] = std::move(vm); + CHECK(!fmap.count(f_name)) + << "Function name " << f_name << " already exists in the list"; + fmap[f_name] = std::move(vm); + + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + entry_func = f_name; + } } - return runtime::StackVMModuleCreate(fmap, funcs[0]->name); + + return runtime::StackVMModuleCreate(fmap, entry_func); } -TVM_REGISTER_GLOBAL("codegen.build_stackvm") +TVM_REGISTER_GLOBAL("target.build.stackvm") .set_body_typed(BuildStackVM); } //
namespace codegen } // namespace tvm diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 10226429f34e2..041c7a7225cf6 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -56,7 +56,7 @@ class CodeGenStackVM * \note Only call compile once, * create a new codegen object each time. */ - StackVM Compile(LoweredFunc f); + StackVM Compile(const PrimFunc& f); /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 180ee12929b8a..705d2317940cb 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -91,7 +91,7 @@ Stmt MakeCrossThreadReduction( freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmtNode::make( reduces[0]->combiner, - attr::reduce_scope, + tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); std::vector assigns(size); @@ -109,7 +109,7 @@ Stmt MakeCrossThreadReduction( body = AllocateNode::make( res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); body = AttrStmtNode::make( - res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body); + res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); } body = Substitute(body, value_map); return MergeNest(nest, body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 62c8dfd30d490..9d95e329c8f29 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -165,7 +165,8 @@ Stmt ExternOpNode::BuildProvide( const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); + Stmt ret = AttrStmtNode::make( + make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; @@ -176,7 +177,7 @@ Stmt ExternOpNode::BuildProvide( tuple.push_back(buffer->shape[k]); } ret = AttrStmtNode::make( - bind_spec, attr::buffer_bind_scope, + bind_spec, tir::attr::buffer_bind_scope, CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 70abf34523b98..dcd09f9f1fa80 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -186,7 +186,8 @@ Stmt HybridOpNode::BuildProvide( const std::unordered_map &dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); + Stmt ret = AttrStmtNode::make( + make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 8bc35e33abfc0..3714f439bd2b6 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -168,7 +168,7 @@ MakeLoopNest(const Stage& stage, // annotate the extent of the IterVar if (!new_loop_var) { nest[i + 1].emplace_back( - AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op)); + AttrStmtNode::make(iv, 
tir::attr::loop_scope, iv->var, no_op)); } } // message passing to get offset of root iter vars. diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 956a297f5b3c6..2ee5b273d4f60 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -287,10 +287,10 @@ Stmt ScanOpNode::BuildProvide( bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); Stmt provide = AttrStmtNode::make( - stage->op, attr::scan_update_scope, this->scan_axis->var, + stage->op, tir::attr::scan_update_scope, this->scan_axis->var, EvaluateNode::make(0)); Stmt init = AttrStmtNode::make( - stage->op, attr::scan_init_scope, 0, + stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0)); size_t begin_scan = 0; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index a110bc458fe9b..bec677a032283 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -85,7 +85,7 @@ class InjectAttach : public StmtMutator { auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmtNode* op = stmt.as(); if (op != nullptr && - op->attr_key == attr::loop_scope) { + op->attr_key == tir::attr::loop_scope) { if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) { CHECK(!found_attach) @@ -131,8 +131,8 @@ class InjectScanStep : public StmtMutator { // update const AttrStmtNode* op = stmt.as(); if (op != nullptr && - ((op->attr_key == attr::scan_update_scope && !is_init_) || - (op->attr_key == attr::scan_init_scope && is_init_))) { + ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || + (op->attr_key == tir::attr::scan_init_scope && is_init_))) { if (op->node.same_as(scan_op_)) { found_attach = true; stmt = AttrStmtNode::make( @@ -187,15 +187,15 @@ class SchedulePostProc : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::loop_scope || - op->attr_key == attr::scan_init_scope) { + if (op->attr_key == tir::attr::loop_scope || + op->attr_key == tir::attr::scan_init_scope) { return this->VisitStmt(op->body); - } else if (op->attr_key == attr::scan_update_scope) { + } else if (op->attr_key == tir::attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); CHECK(scan); var_value_[scan->scan_axis->var.get()] = op->value; return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent) { + } else if (op->attr_key == tir::attr::thread_extent) { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 7572f8d6bbb04..07759b3f126e6 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -32,25 +32,51 @@ namespace tvm { namespace tir { -Var::Var(std::string name_hint, DataType t) - : Var(make_object(t, name_hint)) {} +Var::Var(std::string name_hint, DataType dtype) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = std::move(dtype); + data_ = std::move(n); +} -VarNode::VarNode(DataType t, std::string name_hint) { - this->dtype = t; - this->name_hint = std::move(name_hint); +Var::Var(std::string name_hint, Type type_annotation) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = GetRuntimeDataType(type_annotation); + n->type_annotation = std::move(type_annotation); + data_ = std::move(n); } -SizeVar::SizeVar(std::string name_hint, DataType t) - : 
SizeVar(make_object(t, name_hint)) {} -SizeVarNode::SizeVarNode(DataType t, std::string name_hint) - : VarNode(t, std::move(name_hint)) {} +Var Var::copy_with_suffix(const std::string& suffix) const { + const VarNode* node = get(); + ObjectPtr new_ptr; + if (auto* ptr = this->as()) { + new_ptr = make_object(*ptr); + } else { + new_ptr = make_object(*node); + } + new_ptr->name_hint += suffix; + + return Var(new_ptr); +} + +SizeVar::SizeVar(std::string name_hint, DataType dtype) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = std::move(dtype); + data_ = std::move(n); +} TVM_REGISTER_GLOBAL("tir.Var") -.set_body_typed([](std::string s, DataType t) { - return Var(s, t); - }); +.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { + if (type.IsObjectRef()) { + return Var(name_hint, type.operator Type()); + } else { + return Var(name_hint, type.operator DataType()); + } +}); TVM_REGISTER_GLOBAL("tir.SizeVar") .set_body_typed([](std::string s, DataType t) { diff --git a/src/tir/ir/lowered_func.cc b/src/tir/ir/lowered_func.cc index c1331fbd4c1f5..8790f2b12e396 100644 --- a/src/tir/ir/lowered_func.cc +++ b/src/tir/ir/lowered_func.cc @@ -31,5 +31,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(LoweredFuncNode); + + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index b0736435607df..c11fb2a27eff1 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -33,14 +33,34 @@ namespace tvm { using namespace tir; +runtime::DataType GetRuntimeDataType(const Type& type) { + if (auto * n = type.as()) { + return n->dtype; + } else if (type.as()) { + return DataType::Handle(); + } else { + LOG(FATAL) << "Type " << type + << " does not have a corresponding runtime::DataType"; + return DataType::Handle(); + } +} + Type GetType(const PrimExpr& expr) { + // TODO(tqchen): add recursive type inference for Call here + // once we introduced the corresponding fields to the IR. + if (auto* ptr = expr.as()) { + // If Var has a more refined type annotation, + // return the type anotation + if (ptr->type_annotation.defined()) { + return ptr->type_annotation; + } + } + // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); // These types already implies the specific type. if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { return PrimType(dtype); } - // TODO(tqchen): add recursive type inference for Var and Call here - // once we introduced the corresponding fields to the IR. 
return PrimType(dtype); } diff --git a/src/tir/pass/simple_passes.cc b/src/tir/pass/simple_passes.cc index 81145c89b6a00..93d17ba347fc4 100644 --- a/src/tir/pass/simple_passes.cc +++ b/src/tir/pass/simple_passes.cc @@ -68,6 +68,32 @@ class IRSubstitue : public StmtExprMutator { } } + PrimExpr VisitExpr_(const LoadNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + auto it = smap_.find(op->buffer_var.get()); + if (it != smap_.end()) { + return LoadNode::make( + op->dtype, Downcast(it->second), op->index, op->predicate); + } else { + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + auto it = smap_.find(op->buffer_var.get()); + if (it != smap_.end()) { + return StoreNode::make( + Downcast(it->second), op->value, op->index, op->predicate); + } else { + return ret; + } + } + private: const std::unordered_map& smap_; }; diff --git a/src/tir/pass/storage_rewrite.cc b/src/tir/pass/storage_rewrite.cc index 98410336db614..39f71dd629b4d 100644 --- a/src/tir/pass/storage_rewrite.cc +++ b/src/tir/pass/storage_rewrite.cc @@ -1016,6 +1016,47 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { return LoweredFunc(n); } +PrimFunc PointerValueTypeRewrite(PrimFunc f) { + auto* n = f.CopyOnWrite(); + VectorAllocRewriter rewriter; + n->body = rewriter(n->body); + + Array args; + Map remap_vars; + + for (Var var : f->params) { + if (var.dtype().is_handle()) { + const auto& tvec = rewriter.acc_map_[var.get()]; + + if (tvec.size() == 1) { + tir::Var new_var(var->name_hint, + PointerType(PrimType(tvec[0]))); + args.push_back(new_var); + remap_vars.Set(var, new_var); + + } else { + // always set data type to be non vectorized so + // load/store can still work via scalarization + if (tvec.size() != 0 && !var->type_annotation.defined()) { + tir::Var new_var(var->name_hint, + PointerType(PrimType(tvec[0].with_lanes(1)))); + args.push_back(new_var); + remap_vars.Set(var, new_var); + } else { + args.push_back(var); + } + } + } else { + args.push_back(var); + } + } + + CHECK_EQ(args.size(), n->params.size()); + n->params = args; + n->body = Substitute(n->body, remap_vars); + return f; +} + Stmt StorageRewrite(Stmt stmt) { stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); return VectorAllocRewriter()(std::move(stmt)); diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index d58ddd5a2183f..54812be62d9b9 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -108,8 +108,32 @@ def test_global_var(): assert isinstance(tvar, tvm.ir.GlobalVar) +def test_tir_var(): + nodes = [ + {"type_key": ""}, + {"type_key": "Variable", + "attrs": {"dtype": "int32", "name": "x"}}, + {"type_key": "SizeVar", + "attrs": {"dtype": "int32", "name": "y"}}, + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + x = tvm.ir.load_json(json.dumps(data)) + assert isinstance(x, tvm.tir.Var) + assert x.name == "x" + data["root"] = 2 + y = tvm.ir.load_json(json.dumps(data)) + assert isinstance(y, tvm.tir.SizeVar) + assert y.name == "y" + + if __name__ == "__main__": test_type_var() test_incomplete_type() test_func_tuple_type() test_global_var() + test_tir_var() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py 
index 3a7985dde3fea..7e2c8b55a69b2 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -265,7 +265,18 @@ def test_prim_func(): assert func.attrs is None +def test_vars(): + x = tvm.tir.Var("xyz", "int8") + assert x.dtype == "int8" + ptype = tvm.ir.PointerType(tvm.ir.PrimType("float")) + x = tvm.tir.Var("xyz", ptype) + assert x.dtype == "handle" + assert x.type_annotation == ptype + assert isinstance(ptype.element_type, tvm.ir.PrimType) + + if __name__ == "__main__": + test_vars() test_prim_func() test_cast() test_attr()
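
As a quick illustration of the interfaces exercised by the tests above, the following minimal Python sketch (not part of this patch) creates a handle Var with an explicit PointerType annotation and looks up a backend builder under the renamed "target.build.*" registry namespace. It assumes a TVM build that already includes this refactor; the variable names are illustrative only.

import tvm

# A handle Var can now carry a PointerType annotation instead of a bare
# "handle" dtype, mirroring test_vars() above.
ptype = tvm.ir.PointerType(tvm.ir.PrimType("float32"))
buf = tvm.tir.Var("buf", ptype)
assert buf.dtype == "handle"
assert buf.type_annotation == ptype

# Backend builders are registered as "target.build.<backend>" and take an
# IRModule of PrimFuncs rather than an Array of LoweredFuncs.
build_c = tvm.get_global_func("target.build.c", allow_missing=True)
if build_c is not None:
    print("C source codegen is registered")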