Skip to content

Commit

Permalink
[TIR][TARGET] Refactor Target codegen to use IRModule and PrimFunc.
Browse files Browse the repository at this point in the history
As part of the unified IR refactor.
This PR refactors the target codegen to use IRModule containing tir::PrimFuncs.

In order to break the refactor into several steps without breaking the codebase,
we built an conversion pass to convert Array<LoweredFunc> into IRModule.

The follow-up refactors will gradually move the passes covered by IRModule up
until we cover all the passes. Then we can remove the additional redundant
concepts such as LoweredFunc.
  • Loading branch information
tqchen committed Mar 20, 2020
1 parent 8607947 commit 3bb801e
Show file tree
Hide file tree
Showing 61 changed files with 997 additions and 408 deletions.
12 changes: 12 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <string>
#include <algorithm>
#include <limits>
#include <type_traits>

namespace tvm {

Expand Down Expand Up @@ -307,6 +308,17 @@ class Integer : public IntImm {
* \param other The other value.
*/
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
* \brief Constructor from enum
* \tparam Enum The enum type.
* \param value The enum value.
*/
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
}
/*!
* \brief Assign an expression to integer.
* \param other another expression.
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,27 @@ constexpr const char* kCallingConv = "calling_conv";
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";
} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
56 changes: 54 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class PrimTypeNode : public TypeNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};

/*!

/*
* \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode
*/
Expand All @@ -124,11 +125,53 @@ class PrimType : public Type {
* \brief Constructor
* \param dtype The corresponding dtype.
*/
TVM_DLL PrimType(runtime::DataType dtype);
TVM_DLL explicit PrimType(runtime::DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};


/*!
* \brief Low-level raw pointer type.
*
* PointerType represents type hints in the TIR to be
* passed to the final code generator.
*
* PointerType should not occur in the high-level analysis.
*
* \sa PointerType
*/
class PointerTypeNode : public TypeNode {
public:
/*!
* \brief The type of the element which the pointer points to.
*/
Type element_type;

void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
}

static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};

/*
* \brief Managed reference to PointerTypeNode.
* \sa PointerTypeNode
*/
class PointerType : public Type {
public:
/*!
* \brief Constructor
* \param element_type The type of the element which the pointer points to.
*/
TVM_DLL explicit PointerType(Type element_type);

TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};


/*! \brief Possible kinds of TypeVars. */
enum TypeKind : int {
kType = 0,
Expand Down Expand Up @@ -283,6 +326,15 @@ inline Type VoidType() {
return TupleType::Empty();
}

/*!
* \brief Check whether the tyep represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
auto* n = type.as<TupleTypeNode>();
return n && n->fields.size() == 0;
}

/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
Expand Down
47 changes: 25 additions & 22 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,43 +55,53 @@ namespace tir {
*/
class VarNode : public PrimExprNode {
public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);

/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
/*!
* \brief type annotaion of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
* \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
*/
Type type_annotation;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
v->Visit("type_annotation", &type_annotation);
}

static constexpr const char* _type_key = "Variable";
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

/*! \brief a named variable in TVM */
class Var : public PrimExpr {
public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor
/*!
* \brief Constructor
* \param name_hint variable name
* \param t data type
* \param dtype data type
*/
TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32));
DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->dtype);
}
TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand All @@ -116,28 +126,21 @@ class Var : public PrimExpr {
*/
class SizeVarNode : public VarNode {
public:
/*! \brief constructor */
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);

static constexpr const char* _type_key = "SizeVar";
static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};

/*! \brief a named variable represents a tensor index size */
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor
/*!
* \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
* Type: Integer
*/
constexpr const char* kNoAlias = "tir.noalias";

/*!
* \brief Mark the function as the entry function of
* the final generated runtime module.
*
* Type: Integer
*
* \note There can only be one entry function per module.
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";
} // namespace attr
} // namespace tir
} // namespace tvm
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/lowered_func.h>

#include <unordered_map>
Expand Down Expand Up @@ -515,6 +516,19 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);


/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemeneted in storage_rewrite.cc
* \param f The function to be trasnformed
* \return Transformed function.
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);

/*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
Expand Down
13 changes: 13 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,23 @@ namespace tvm {
* This function could return a more refined type than
* the runtime type provided by expr->dtype
*
* \param expr The input parameter.
* \return The result type.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL Type GetType(const PrimExpr& expr);

/*!
* \brief Get the implied DataType for storing values with type during runtime.
*
* \param type The input type.
* \return The result runtime::DataType.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ def _convert(item, _):
return item
return _convert

def _update_tir_var(new_name):
def _convert(item, _):
item["type_key"] = new_name
item["attrs"]["type_annotation"] = "0"
return item
return _convert

node_map = {
# Base IR
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
Expand All @@ -91,6 +99,9 @@ def _convert(item, _):
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
15 changes: 15 additions & 0 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TypeKind(IntEnum):
TypeData = 6


@tvm._ffi.register_object("PrimType")
class PrimType(Type):
"""Primitive data type in the low level IR
Expand All @@ -59,6 +60,20 @@ def __init__(self, dtype):
_ffi_api.PrimType, dtype)


@tvm._ffi.register_object("PointerType")
class PointerType(Type):
"""PointerType used in the low-level TIR.
Parameters
----------
element_type : tvm.ir.Type
The type of pointer's element.
"""
def __init__(self, element_type):
self.__init_handle_by_constructor__(
_ffi_api.PointerType, element_type)


@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class CmpExpr(PrimExprWithOp):
class LogicalExpr(PrimExprWithOp):
pass

@tvm._ffi.register_object("Variable")
@tvm._ffi.register_object("tir.Var")
class Var(PrimExprWithOp):
"""Symbolic variable.
Expand All @@ -297,15 +297,15 @@ class Var(PrimExprWithOp):
name : str
The name
dtype : str
dtype : Union[str, tvm.irType]
The data type
"""
def __init__(self, name, dtype):
self.__init_handle_by_constructor__(
_ffi_api.Var, name, dtype)


@tvm._ffi.register_object
@tvm._ffi.register_object("tir.SizeVar")
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero.
Expand Down
Loading

0 comments on commit 3bb801e

Please sign in to comment.