diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 6e871623a79d..67cfb8d67a18 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/relay/adt.h + * \file tvm/ir/adt.h * \brief Algebraic data type definitions. * * We adopt relay's ADT definition as a unified class diff --git a/include/tvm/node/env_func.h b/include/tvm/ir/env_func.h similarity index 94% rename from include/tvm/node/env_func.h rename to include/tvm/ir/env_func.h index c2ea2b449c04..f5b17bb1db08 100644 --- a/include/tvm/node/env_func.h +++ b/include/tvm/ir/env_func.h @@ -18,21 +18,24 @@ */ /*! - * \file tvm/node/env_func.h - * \brief Serializable global function. + * \file tvm/ir/env_func.h + * \brief Serializable global function used in IR. */ -#ifndef TVM_NODE_ENV_FUNC_H_ -#define TVM_NODE_ENV_FUNC_H_ +#ifndef TVM_IR_ENV_FUNC_H_ +#define TVM_IR_ENV_FUNC_H_ #include #include #include - namespace tvm { /*! - * \brief Node container of EnvFunc + * \brief A serializable function backed by TVM's global environment. + * + * This is a wrapper to enable serializable global PackedFunc. + * An EnvFunc is saved by its name in the global registry + * under the assumption that the same function is registered during load. * \sa EnvFunc */ class EnvFuncNode : public Object { @@ -53,11 +56,8 @@ class EnvFuncNode : public Object { }; /*! - * \brief A serializable function backed by TVM's global environment. - * - * This is a wrapper to enable serializable global PackedFunc. - * An EnvFunc is saved by its name in the global registry - * under the assumption that the same function is registered during load. + * \brief Managed reference to EnvFuncNode. + * \sa EnvFuncNode */ class EnvFunc : public ObjectRef { public: @@ -140,4 +140,4 @@ class TypedEnvFunc : public ObjectRef { }; } // namespace tvm -#endif // TVM_NODE_ENV_FUNC_H_ +#endif // TVM_IR_ENV_FUNC_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 7b42678ee103..12b34dd26398 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -198,7 +198,7 @@ class GlobalVarNode : public RelayExprNode { /*! \brief The name of the variable, this only acts as a hint. */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 735a5f50de2f..8f922c0d42f7 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -48,13 +48,13 @@ class IRModule; class IRModuleNode : public Object { public: /*! \brief A map from ids to all global functions. */ - tvm::Map functions; + Map functions; /*! \brief A map from global type vars to ADT type data. */ - tvm::Map type_definitions; + Map type_definitions; IRModuleNode() {} - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("functions", &functions); v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); @@ -146,7 +146,7 @@ class IRModuleNode : public Object { * \brief Collect all global vars defined in this module. * \returns An array of global vars */ - TVM_DLL tvm::Array GetGlobalVars() const; + TVM_DLL Array GetGlobalVars() const; /*! * \brief Look up a global function by its name. @@ -159,7 +159,7 @@ class IRModuleNode : public Object { * \brief Collect all global type vars defined in this module. * \returns An array of global type vars */ - TVM_DLL tvm::Array GetGlobalTypeVars() const; + TVM_DLL Array GetGlobalTypeVars() const; /*! * \brief Look up a global function by its variable. @@ -235,12 +235,12 @@ class IRModuleNode : public Object { /*! \brief A map from string names to global variables that * ensures global uniqueness. */ - tvm::Map global_var_map_; + Map global_var_map_; /*! \brief A map from string names to global type variables (ADT names) * that ensures global uniqueness. */ - tvm::Map global_type_var_map_; + Map global_type_var_map_; /*! \brief A map from constructor tags to constructor objects * for convenient access @@ -266,8 +266,8 @@ class IRModule : public ObjectRef { * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module */ - TVM_DLL explicit IRModule(tvm::Map functions, - tvm::Map type_definitions = {}, + TVM_DLL explicit IRModule(Map functions, + Map type_definitions = {}, std::unordered_set import_set = {}); /*! \brief default constructor */ IRModule() {} @@ -296,8 +296,8 @@ class IRModule : public ObjectRef { */ TVM_DLL static IRModule FromExpr( const RelayExpr& expr, - const tvm::Map& global_funcs = {}, - const tvm::Map& type_definitions = {}); + const Map& global_funcs = {}, + const Map& type_definitions = {}); /*! * \brief Parse text format source file into an IRModule. diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 19c5a51162a3..f5d0639eb3f2 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -91,7 +91,7 @@ class OpNode : public RelayExprNode { */ int32_t support_level = 10; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("op_type", &op_type); v->Visit("description", &description); @@ -476,7 +476,7 @@ inline OpRegistry& OpRegistry::add_type_rel( std::string input_name_prefix = "in"; for (int i = 0; i < get()->num_inputs; i++) { auto name = input_name_prefix + std::to_string(i); - auto param = TypeVarNode::make(name, TypeKind::kType); + auto param = TypeVar(name, TypeKind::kType); type_params.push_back(param); arg_types.push_back(param); } @@ -484,7 +484,7 @@ inline OpRegistry& OpRegistry::add_type_rel( Array ty_call_args = arg_types; // Add output type. - auto out_param = TypeVarNode::make("out", TypeKind::kType); + auto out_param = TypeVar("out", TypeKind::kType); type_params.push_back(out_param); // this will trigger copy on write. ty_call_args.push_back(out_param); @@ -498,13 +498,13 @@ inline OpRegistry& OpRegistry::add_type_rel( // A common example is sum(x, axis), where the choice of axis // can affect the type of the function. TypeConstraint type_rel = - TypeRelationNode::make(env_type_rel_func, + TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs()); auto func_type = - FuncTypeNode::make(arg_types, out_param, type_params, {type_rel}); + FuncType(arg_types, out_param, type_params, {type_rel}); get()->op_type = func_type; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index b9d9d855475b..424472ee7fdc 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -84,13 +84,13 @@ class PassContextNode : public Object { int fallback_device{static_cast(kDLCPU)}; /*! \brief The list of required passes. */ - tvm::Array required_pass; + Array required_pass; /*! \brief The list of disabled passes. */ - tvm::Array disabled_pass; + Array disabled_pass; PassContextNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); @@ -118,7 +118,7 @@ class PassContextNode : public Object { class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} + explicit PassContext(ObjectPtr n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -158,7 +158,7 @@ class PassContext : public ObjectRef { // Classes to get the Python `with` like syntax. friend class Internal; - friend class tvm::With; + friend class With; }; /*! @@ -174,11 +174,11 @@ class PassInfoNode : public Object { std::string name; /*! \brief The passes that are required to perform the current pass. */ - tvm::Array required; + Array required; PassInfoNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); @@ -202,7 +202,7 @@ class PassInfo : public ObjectRef { */ TVM_DLL PassInfo(int opt_level, std::string name, - tvm::Array required); + Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -241,7 +241,7 @@ class PassNode : public Object { virtual IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const = 0; - void VisitAttrs(tvm::AttrVisitor* v) {} + void VisitAttrs(AttrVisitor* v) {} static constexpr const char* _type_key = "relay.Pass"; TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); @@ -289,7 +289,7 @@ class Sequential : public Pass { * \param passes The passes to apply. * \param pass_info The pass metadata. */ - TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); + TVM_DLL Sequential(Array passes, PassInfo pass_info); /*! * \brief The constructor of `Sequential`. @@ -299,10 +299,10 @@ class Sequential : public Pass { * This allows users to only provide a list of passes and execute them * under a given context. */ - TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + TVM_DLL Sequential(Array passes, std::string name = "sequential"); Sequential() = default; - explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} + explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = Sequential; @@ -322,7 +322,7 @@ Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required); + const Array& required); } // namespace transform } // namespace tvm diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ddabd0fbbcc0..e143588ee4a3 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -50,15 +50,26 @@ #define TVM_IR_TYPE_H_ #include +#include #include -#include #include #include #include namespace tvm { -/*! \brief Base type of all the types. */ +/*! + * \brief Type is the base type of all types. + * + * Relay's type system contains following subclasses: + * + * - PrimType: type of primitive type values used in the low-level IR. + * - FuncType: type of a function. + * - TensorType: type of certain Tensor values in the expression. + * + * There are also advanced types to support generic(polymorphic types). + * \sa Type + */ class TypeNode : public Object { public: /*! @@ -72,29 +83,58 @@ class TypeNode : public Object { }; /*! - * \brief Type is the base type of all types. - * - * Relay's type system contains following two key concepts: - * - * - PrimitiveType: type of primitive type values used in the low-level IR. - * - TensorType: type of certain Tensor values in the expression. - * - FunctionType: the type of the function. - * - * There are also advanced types to support generic(polymorphic types), - * which can be ignored when first reading the code base. + * \brief Managed reference to TypeNode. + * \sa TypeNode */ class Type : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode); }; +/*! + * \brief Primitive data types used in the low-level IR. + * + * PrimType represents POD-values and handles that are + * not automatically managed by the runtime. + * + * \sa PrimType + */ +class PrimTypeNode : public TypeNode { + public: + /*! + * \brief The corresponding dtype field. + */ + runtime::DataType dtype; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relay.PrimType"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); +}; + +/*! + * \brief Managed reference to PrimTypeNode. + * \sa PrimTypeNode + */ +class PrimType : public Type { + public: + /*! + * \brief Constructor + * \param dtype The corresponding dtype. + */ + TVM_DLL PrimType(runtime::DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); +}; + /*! \brief Possible kinds of TypeVars. */ enum TypeKind : int { kType = 0, /*! \brief Template variable in shape expression. */ kShapeVar = 1, kBaseType = 2, - kShape = 3, kConstraint = 4, kAdtHandle = 5, kTypeData = 6 @@ -115,10 +155,8 @@ enum TypeKind : int { * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] * * \endcode - * \sa TypeVarNode The actual container class of TypeVar + * \sa TypeVar, TypeKind */ -class TypeVar; -/*! \brief TypeVar container node */ class TypeVarNode : public TypeNode { public: /*! @@ -130,28 +168,36 @@ class TypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ TypeKind kind; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("kind", &kind); v->Visit("span", &span); } - TVM_DLL static TypeVar make(std::string name, TypeKind kind); - static constexpr const char* _type_key = "relay.TypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); }; +/*! + * \brief Managed reference to TypeVarNode + * \sa TypeVarNode + */ class TypeVar : public Type { public: + /*! + * \brief Constructor + * \param name_hint The name of the type var. + * \param kind The kind of the type var. + */ + TVM_DLL TypeVar(std::string name_hint, TypeKind kind); + TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); }; /*! * \brief A global type variable that is used for defining new types or type aliases. + * \sa GlobalTypeVar */ -class GlobalTypeVar; -/*! \brief GlobalTypeVar container node */ class GlobalTypeVarNode : public TypeNode { public: /*! @@ -163,47 +209,98 @@ class GlobalTypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ TypeKind kind; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("kind", &kind); } - TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind); - static constexpr const char* _type_key = "relay.GlobalTypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); }; +/*! + * \brief Managed reference to GlobalTypeVarNode + * \sa GlobalTypeVarNode + */ class GlobalTypeVar : public Type { public: + /*! + * \brief Constructor + * \param name_hint The name of the type var. + * \param kind The kind of the type var. + */ + TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind); + TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); }; /*! - * \brief Potential Constraints in the type. - * \note This is reserved for future use. + * \brief The type of tuple values. + * \sa TupleType + */ +class TupleTypeNode : public TypeNode { + public: + /*! \brief The type of each field in the tuple. */ + Array fields; + + TupleTypeNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("fields", &fields); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.TupleType"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); +}; + +/*! + * \brief Managed reference to TupleTypeNode. + * \sa TupleTypeNode. + */ +class TupleType : public Type { + public: + /*! + * \brief Constructor + * \param fields Fields in the tuple. + */ + TVM_DLL explicit TupleType(Array fields); + + /*! + * \brief Create an empty tuple type that constains nothing. + * \return A empty tuple type. + */ + TVM_DLL TupleType static Empty(); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); +}; + +/*! + * \brief Potential Constraints in a function. + * \sa TypeConstraint */ -class TypeConstraint; -/*! \brief TypeConstraint container node. */ class TypeConstraintNode : public TypeNode { public: static constexpr const char* _type_key = "relay.TypeConstraint"; TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); }; +/*! + * \brief Managed reference to TypeConstraintNode. + * \sa TypeConstraintNode, TypeRelation + */ class TypeConstraint : public Type { public: TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); }; -class FuncType; /*! - * \brief Function type in Relay. + * \brief Function type. * - * Relay support polymorphic function type. + * We support polymorphic function type. * This can be roughly viewed as template function in C++. * - * \sa TypeVar, TypeConstraint + * \sa FuncType, TypeVar, TypeConstraint */ class FuncTypeNode : public TypeNode { public: @@ -221,7 +318,7 @@ class FuncTypeNode : public TypeNode { */ Array type_constraints; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("arg_types", &arg_types); v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); @@ -229,17 +326,29 @@ class FuncTypeNode : public TypeNode { v->Visit("span", &span); } - TVM_DLL static FuncType make(Array arg_types, - Type ret_type, - Array type_params, - Array type_constraints); - static constexpr const char* _type_key = "relay.FuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; +/*! + * \brief Managed reference to FuncTypeNode. + * \sa FuncTypeNode + */ class FuncType : public Type { public: + /*! + * \brief Constructor + * \param arg_types The types of the arguments. + * \param ret_type The type of the return value. + * \param type_params The type parameters. + * \param type_constraints The type constraints. + * \sa FuncTypeNode for more docs about these fields. + */ + TVM_DLL FuncType(Array arg_types, + Type ret_type, + Array type_params, + Array type_constraints); + TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); }; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index db3582ec69f4..333c538ccf4a 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -19,19 +19,56 @@ /*! * \file tvm/ir/type_relation.h - * \brief Type relation function for type checking. + * \brief Type relation and function for type inference(checking). */ #ifndef TVM_IR_TYPE_RELATION_H_ #define TVM_IR_TYPE_RELATION_H_ #include +#include +#include #include namespace tvm { -// TODO(tqchen): remove after migrate Module to ir. -class IRModule; +/*! + * \brief Type function application. + * \sa TypeCall + */ +class TypeCallNode : public TypeNode { + public: + /*! + * \brief The type-level function (ADT that takes type params). + */ + Type func; + /*! \brief The arguments. */ + Array args; + void VisitAttrs(AttrVisitor* v) { + v->Visit("func", &func); + v->Visit("args", &args); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.TypeCall"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); +}; + +/*! + * \brief Managed reference to TypeCallNode. + * \sa TypeCallNode + */ +class TypeCall : public Type { + public: + /*! + * \brief Constructor + * \param func The type function to apply. + * \param args The arguments to the type function. + */ + TVM_DLL TypeCall(Type func, Array args); + + TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode); +}; /*! * \brief reporter that reports back to the @@ -78,7 +115,7 @@ class TypeReporterNode : public Object { TVM_DLL virtual IRModule GetModule() = 0; // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) {} + void VisitAttrs(AttrVisitor* v) {} static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); @@ -91,7 +128,7 @@ class TypeReporterNode : public Object { class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { + explicit TypeReporter(ObjectPtr n) : ObjectRef(n) { } TypeReporterNode* operator->() const { return const_cast( @@ -127,12 +164,11 @@ using TypeRelationFn = /*! * \brief User defined type relation, is an input-output relation on types. - */ -class TypeRelation; -/*! - * \brief TypeRelation container. - * \note This node is not directly serializable. - * The type function need to be lookedup in the module. + * + * TypeRelation is more generalized than type call as it allows inference + * of both inputs and outputs. + * + * \sa TypeRelation */ class TypeRelationNode : public TypeConstraintNode { public: @@ -143,13 +179,13 @@ class TypeRelationNode : public TypeConstraintNode { */ TypeRelationFn func; /*! \brief The type arguments to the type function. */ - tvm::Array args; + Array args; /*! \brief Number of inputs arguments */ int num_inputs; /*! \brief Attributes to the relation function */ Attrs attrs; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("args", &args); v->Visit("num_inputs", &num_inputs); @@ -157,17 +193,29 @@ class TypeRelationNode : public TypeConstraintNode { v->Visit("span", &span); } - TVM_DLL static TypeRelation make(TypeRelationFn func, - Array args, - int num_args, - Attrs attrs); - static constexpr const char* _type_key = "relay.TypeRelation"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); }; +/*! + * \brief Managed reference to TypeRelationNode. + * \sa TypeRelationNode + */ class TypeRelation : public TypeConstraint { public: + /*! + * \brief Constructor + * \param func The relation function. + * \param args The arguments to the type relation. + * \param num_inputs Number of inputs. + * \param attrs Attributes to the relation function. + * \sa TypeRelationNode for more docs about these fields. + */ + TVM_DLL TypeRelation(TypeRelationFn func, + Array args, + int num_inputs, + Attrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); }; } // namespace tvm diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 7748bd108dfb..17b8b57dbbf5 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -39,6 +39,8 @@ namespace tvm { namespace relay { +// namespace update for backward compact +// will be removed later. using Any = tvm::ir::AnyNode; using Kind = TypeKind; using Type = tvm::Type; @@ -47,10 +49,14 @@ using TypeVar = tvm::TypeVar; using TypeVarNode = tvm::TypeVarNode; using GlobalTypeVar = tvm::GlobalTypeVar; using GlobalTypeVarNode = tvm::GlobalTypeVarNode; +using TupleType = tvm::TupleType; +using TupleTypeNode = tvm::TupleTypeNode; using TypeConstraint = tvm::TypeConstraint; using TypeConstraintNode = tvm::TypeConstraintNode; using FuncType = tvm::FuncType; using FuncTypeNode = tvm::FuncTypeNode; +using TypeCall = tvm::TypeCall; +using TypeCallNode = tvm::TypeCallNode; using TypeRelation = tvm::TypeRelation; using TypeRelationNode = tvm::TypeRelationNode; using TypeRelationFn = tvm::TypeRelationFn; @@ -118,37 +124,6 @@ class TensorType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); }; -/*! - * \brief Type application. - */ -class TypeCall; -/*! \brief TypeCall container node */ -class TypeCallNode : public TypeNode { - public: - /*! - * \brief The type-level function (ADT that takes type params). - */ - Type func; - /*! \brief The arguments. */ - tvm::Array args; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("args", &args); - v->Visit("span", &span); - } - - TVM_DLL static TypeCall make(Type func, tvm::Array args); - - static constexpr const char* _type_key = "relay.TypeCall"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); -}; - -class TypeCall : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode); -}; - /*! * \brief IncompleteType. * This is intermediate values that is used during type inference. @@ -180,36 +155,6 @@ class IncompleteType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; -/*! - * \brief The type of tuple values. - */ -class TupleType; -/*! - * \brief TupleType container. - */ -class TupleTypeNode : public TypeNode { - public: - /*! \brief The type of each field in the tuple. */ - tvm::Array fields; - - TupleTypeNode() {} - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("fields", &fields); - v->Visit("span", &span); - } - - TVM_DLL static TupleType make(tvm::Array fields); - - static constexpr const char* _type_key = "relay.TupleType"; - TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); -}; - -class TupleType : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); -}; - /*! * \brief The type of reference values. */ diff --git a/src/api/api_test.cc b/src/api/api_test.cc index 6f01c7a4fb2d..7ded78b26c35 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include diff --git a/src/node/env_func.cc b/src/ir/env_func.cc similarity index 98% rename from src/node/env_func.cc rename to src/ir/env_func.cc index f6603ea57e2a..8118d036bcd6 100644 --- a/src/node/env_func.cc +++ b/src/ir/env_func.cc @@ -20,7 +20,7 @@ /*! * \file env_func.cc */ -#include +#include #include #include diff --git a/src/ir/type.cc b/src/ir/type.cc index 465d1ea110d2..4ba160758e34 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -27,18 +27,38 @@ namespace tvm { -TypeVar TypeVarNode::make(std::string name, TypeKind kind) { +PrimType::PrimType(runtime::DataType dtype) { + ObjectPtr n = make_object(); + n->dtype = dtype; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrimTypeNode); + +TVM_REGISTER_GLOBAL("relay._make.PrimType") +.set_body_typed([](runtime::DataType dtype) { + return PrimType(dtype); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->dtype; +}); + + +TypeVar::TypeVar(std::string name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); - return TypeVar(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_GLOBAL("relay._make.TypeVar") .set_body_typed([](std::string name, int kind) { - return TypeVarNode::make(name, static_cast(kind)); + return TypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) @@ -48,18 +68,19 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << node->kind << ")"; }); -GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) { + +GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); - return GlobalTypeVar(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") .set_body_typed([](std::string name, int kind) { - return GlobalTypeVarNode::make(name, static_cast(kind)); + return GlobalTypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) @@ -69,22 +90,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << node->kind << ")"; }); -FuncType FuncTypeNode::make(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { +FuncType::FuncType(tvm::Array arg_types, + Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->type_constraints = std::move(type_constraints); - return FuncType(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_GLOBAL("relay._make.FuncType") -.set_body_typed(FuncTypeNode::make); +.set_body_typed([](tvm::Array arg_types, + Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + return FuncType(arg_types, ret_type, type_params, type_constraints); +}); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& ref, NodePrinter* p) { @@ -94,4 +120,27 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << node->type_constraints << ")"; }); +TupleType::TupleType(Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} + +TupleType TupleType::Empty() { + return TupleType(Array()); +} + +TVM_REGISTER_NODE_TYPE(TupleTypeNode); + +TVM_REGISTER_GLOBAL("relay._make.TupleType") +.set_body_typed([](Array fields) { + return TupleType(fields); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleTypeNode(" << node->fields << ")"; +}); + } // namespace tvm diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index cc5ceef7dd3e..06d665e8aced 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -27,22 +27,49 @@ #include namespace tvm { -TypeRelation TypeRelationNode::make(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { + +TypeCall::TypeCall(Type func, tvm::Array args) { + ObjectPtr n = make_object(); + n->func = std::move(func); + n->args = std::move(args); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TypeCallNode); + +TVM_REGISTER_GLOBAL("relay._make.TypeCall") +.set_body_typed([](Type func, Array type) { + return TypeCall(func, type); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeCallNode(" << node->func << ", " + << node->args << ")"; +}); + +TypeRelation::TypeRelation(TypeRelationFn func, + Array args, + int num_inputs, + Attrs attrs) { ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); n->num_inputs = num_inputs; n->attrs = std::move(attrs); - return TypeRelation(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_GLOBAL("relay._make.TypeRelation") -.set_body_typed(TypeRelationNode::make); +.set_body_typed([](TypeRelationFn func, + Array args, + int num_inputs, + Attrs attrs) { + return TypeRelation(func, args, num_inputs, attrs); +}); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) .set_dispatch([](const ObjectRef& ref, NodePrinter* p) { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index d4a7cb1f2ad1..00c40b2565bb 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -246,7 +246,7 @@ class ScheduleGetter : new_fields.push_back(field); } } - call_node_type = TupleTypeNode::make(new_fields); + call_node_type = TupleType(new_fields); } CHECK(call_node->op.as()) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 239a33ea642c..f66cce6b7b82 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -135,7 +135,7 @@ FuncType FunctionNode::func_type_annotation() const { Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteTypeNode::make(Kind::kType); - return FuncTypeNode::make(param_types, ret_type, this->type_params, {}); + return FuncType(param_types, ret_type, this->type_params, {}); } bool FunctionNode::IsPrimitive() const { diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 099b8013c895..5680a789544c 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -63,25 +63,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeCall TypeCallNode::make(Type func, tvm::Array args) { - ObjectPtr n = make_object(); - n->func = std::move(func); - n->args = std::move(args); - return TypeCall(n); -} - -TVM_REGISTER_NODE_TYPE(TypeCallNode); - -TVM_REGISTER_GLOBAL("relay._make.TypeCall") -.set_body_typed(TypeCallNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeCallNode(" << node->func << ", " - << node->args << ")"; -}); - IncompleteType IncompleteTypeNode::make(Kind kind) { auto n = make_object(); n->kind = std::move(kind); @@ -102,23 +83,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); -TupleType TupleTypeNode::make(Array fields) { - ObjectPtr n = make_object(); - n->fields = std::move(fields); - return TupleType(n); -} - -TVM_REGISTER_NODE_TYPE(TupleTypeNode); - -TVM_REGISTER_GLOBAL("relay._make.TupleType") -.set_body_typed(TupleTypeNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleTypeNode(" << node->fields << ")"; -}); - RefType RefTypeNode::make(Type value) { ObjectPtr n = make_object(); n->value = std::move(value); diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 3ac0564e7ec5..0180a0c64ba0 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -154,7 +154,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { changed = changed || !new_ret_type.same_as(op->ret_type); if (!changed) return GetRef(op); - return FuncTypeNode::make(new_args, + return FuncType(new_args, new_ret_type, type_params, type_constraints); @@ -165,7 +165,7 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { if (new_fields.same_as(op->fields)) { return GetRef(op); } else { - return TupleTypeNode::make(new_fields); + return TupleType(new_fields); } } @@ -178,7 +178,7 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { if (new_args.same_as(type_rel->args)) { return GetRef(type_rel); } else { - return TypeRelationNode::make(type_rel->func, + return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs); @@ -195,7 +195,7 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) { if (new_args.same_as(op->args) && new_func.same_as(op->func)) { return GetRef(op); } else { - return TypeCallNode::make(new_func, new_args); + return TypeCall(new_func, new_args); } } diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 161ca1c34e0a..888d431579a8 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -55,7 +55,7 @@ bool TopKRel(const Array& types, auto values_ty = TensorTypeNode::make(out_shape, data->dtype); auto indices_ty = TensorTypeNode::make(out_shape, param->dtype); if (param->ret_type == "both") { - reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty})); + reporter->Assign(types[1], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { reporter->Assign(types[1], values_ty); } else if (param->ret_type == "indices") { diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index bd8fb4272a48..bd3b543659ae 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -65,7 +65,7 @@ bool AllocStorageRel(const Array& types, int num_inputs, const Attrs& attr auto mod = reporter->GetModule(); CHECK(mod.defined()); auto storage_name = mod->GetGlobalTypeVar("Storage"); - auto storage = TypeCallNode::make(storage_name, {}); + auto storage = TypeCall(storage_name, {}); reporter->Assign(types[2], storage); return true; } @@ -136,7 +136,7 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs auto mod = reporter->GetModule(); CHECK(mod.defined()); auto storage_name = mod->GetGlobalTypeVar("Storage"); - auto storage = relay::TypeCallNode::make(storage_name, {}); + auto storage = relay::TypeCall(storage_name, {}); reporter->Assign(types[0], storage); // Second argument should be shape tensor. auto tt = types[1].as(); @@ -196,15 +196,15 @@ bool InvokeTVMOPRel(const Array& types, int num_inputs, const Attrs& attrs << "internal invariant violated: invoke_tvm_op outputs must be a tuple"; Type ex_output; if (func_type->ret_type.as()) { - ex_output = TupleTypeNode::make({func_type->ret_type}); + ex_output = TupleType({func_type->ret_type}); } else { CHECK(func_type->ret_type.as()) << "should be tuple type"; ex_output = func_type->ret_type; } - auto ex_input = TupleTypeNode::make(func_type->arg_types); + auto ex_input = TupleType(func_type->arg_types); reporter->Assign(ex_input, GetRef(input_type)); reporter->Assign(ex_output, GetRef(output_type)); - reporter->Assign(types[3], TupleTypeNode::make({})); + reporter->Assign(types[3], TupleType::Empty()); return true; } @@ -236,7 +236,7 @@ bool KillRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2u); // TODO(@jroesch): should only support tensors. - reporter->Assign(types[1], TupleTypeNode::make({})); + reporter->Assign(types[1], TupleType::Empty()); return true; } @@ -297,7 +297,7 @@ bool ShapeFuncRel(const Array& types, int num_inputs, const Attrs& attrs, auto func_type = types[0].as(); CHECK(func_type != nullptr); - auto tuple = TupleTypeNode::make(func_type->arg_types); + auto tuple = TupleType(func_type->arg_types); auto in_types = FlattenType(tuple); auto out_types = FlattenType(func_type->ret_type); @@ -318,12 +318,12 @@ bool ShapeFuncRel(const Array& types, int num_inputs, const Attrs& attrs, shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64))); } - auto input_type = TupleTypeNode::make(shape_func_ins); - auto output_type = TupleTypeNode::make(shape_func_outs); + auto input_type = TupleType(shape_func_ins); + auto output_type = TupleType(shape_func_outs); reporter->Assign(types[1], input_type); reporter->Assign(types[2], output_type); - reporter->Assign(types[3], TupleTypeNode::make({})); + reporter->Assign(types[3], TupleType::Empty()); return true; } diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 35d6cba2747c..f1d711146a99 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -586,7 +586,7 @@ bool DropoutRel(const Array& types, // dropout returns the original tensor with dropout applied // and a mask tensor (1.0 where element not dropped, 0.0 where dropped) auto ret_type = TensorTypeNode::make(data->shape, data->dtype); - reporter->Assign(types[1], TupleTypeNode::make(Array({ret_type, ret_type}))); + reporter->Assign(types[1], TupleType(Array({ret_type, ret_type}))); return true; } @@ -674,7 +674,7 @@ bool BatchNormRel(const Array& types, fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); fields.push_back(vec_ty); fields.push_back(vec_ty); - reporter->Assign(types[5], TupleTypeNode::make(Array(fields))); + reporter->Assign(types[5], TupleType(Array(fields))); return true; } diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 75aefe276f59..f2be18202410 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -109,7 +109,7 @@ bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& a output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype)); output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype)); - reporter->Assign(types[3], TupleTypeNode::make(Array(output_types))); + reporter->Assign(types[3], TupleType(Array(output_types))); return true; } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index aa643c4d8bc6..4d3a4b9589ee 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2150,7 +2150,7 @@ bool SplitRel(const Array& types, auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); } - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + reporter->Assign(types[1], TupleType(Array(fields))); } else { auto indices = param->indices_or_sections.as()->data; auto begin = IndexExpr(make_zero(DataType::Int(32))); @@ -2170,7 +2170,7 @@ bool SplitRel(const Array& types, oshape[axis] = data->shape[axis] - begin; auto vec_type = TensorTypeNode::make(oshape, data->dtype); fields.push_back(vec_type); - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + reporter->Assign(types[1], TupleType(Array(fields))); } return true; } diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index d9ec21f8ddf2..6a1b34d092d2 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -125,7 +125,7 @@ bool MultiBoxTransformLocRel(const Array& types, fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32))); // assign output type - reporter->Assign(types[3], TupleTypeNode::make(Array(fields))); + reporter->Assign(types[3], TupleType(Array(fields))); return true; } diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 307b2a4d0fa6..452477928593 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -44,7 +44,7 @@ bool GetValidCountRel(const Array& types, fields.push_back(TensorTypeNode::make(data->shape, data->dtype)); // assign output type - reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + reporter->Assign(types[1], TupleType(Array(fields))); return true; } diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index 3cfed1b9788b..fc7f820e6d86 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) { public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { - TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind); + TypeVar ret = TypeVar(tv->name_hint, tv->kind); type_rename_[tv] = ret; return ret; } diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 3d7ed9390c8c..8dece3fa3528 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -42,7 +42,7 @@ class TypeVarReplacer : public TypeMutator { Type VisitType_(const TypeVarNode* type_var_node) final { const auto type_var = GetRef(type_var_node); if (replace_map_.find(type_var) == replace_map_.end()) { - replace_map_[type_var] = TypeVarNode::make("A", Kind::kType); + replace_map_[type_var] = TypeVar("A", Kind::kType); } return replace_map_[type_var]; } @@ -109,7 +109,7 @@ class EtaExpander : public ExprMutator { type_params.push_back(type_var_replacer_.VisitType(type_var)); } Expr body = CallNode::make(cons, params, Attrs()); - Type ret_type = TypeCallNode::make(cons->belong_to, type_params); + Type ret_type = TypeCall(cons->belong_to, type_params); return FunctionNode::make( Downcast>(params), diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 78f17bcc6eeb..e236de72be54 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -73,10 +73,10 @@ Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; - return FuncTypeNode::make(ty->arg_types, - TupleTypeNode::make({ + return FuncType(ty->arg_types, + TupleType({ ty->ret_type, - TupleTypeNode::make(ty->arg_types)}), {}, {}); + TupleType(ty->arg_types)}), {}, {}); } //! \brief if the expression is a GlobalVar, transform to it's expression. @@ -219,7 +219,7 @@ Type GradRetType(const Function& f) { vt.push_back(p->type_annotation); } - return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); + return TupleType({f->ret_type, TupleType(vt)}); } Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { @@ -265,7 +265,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { Type t = GetRef(ttn); - return TupleTypeNode::make({t, RefTypeNode::make(t)}); + return TupleType({t, RefTypeNode::make(t)}); } }; @@ -299,7 +299,7 @@ Expr LiftTensor(const std::function& f, types.push_back(field->checked_type_); } auto ret = TupleNode::make(fields); - ret->checked_type_ = TupleTypeNode::make(types); + ret->checked_type_ = TupleType(types); return std::move(ret); } else { LOG(FATAL) << "unsupported input/output type: " << tt; @@ -385,7 +385,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { } Expr BPEmpty() { - Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {}); + Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {}); return RefCreateNode::make(unitF); } @@ -426,7 +426,7 @@ struct ReverseAD : ExprMutator { ll->Push(CallNode::make(RefReadNode::make(dup_bp), {})); return CallNode::make(bpv, {}); }), - TupleTypeNode::make({}), + TupleType::Empty(), {}); ll->Push(RefWriteNode::make(bp, nbp)); return ret; @@ -468,7 +468,7 @@ struct ReverseAD : ExprMutator { } return CallNode::make(bpv, {}); }), - TupleTypeNode::make({}), + TupleType::Empty(), {}); ll->Push(RefWriteNode::make(bp, nbp)); return ret; diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 3ffc4ad959f6..f88a7c91b27b 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -63,7 +63,7 @@ namespace relay { // we assume the data type has no closure - no idea how to look into datatype right now. Type Arrow(const Type& l, const Type& r) { - return FuncTypeNode::make({l}, r, {}, {}); + return FuncType({l}, r, {}, {}); } Type CPSType(const Type& t, const TypeVar& answer); @@ -74,7 +74,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { new_arg_types.push_back(CPSType(t, answer)); } new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer)); - return FuncTypeNode::make(new_arg_types, answer, f->type_params, f->type_constraints); + return FuncType(new_arg_types, answer, f->type_params, f->type_constraints); } Type CPSType(const Type& t, const TypeVar& answer) { @@ -302,7 +302,7 @@ Function ToCPS(const Function& f, } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { - TypeVar answer = TypeVarNode::make("answer", kType); + TypeVar answer = TypeVar("answer", kType); VarMap var; struct Remapper : ExprVisitor, PatternVisitor { Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } @@ -348,7 +348,7 @@ Function UnCPS(const Function& f) { auto new_ret_type = Type(cont_type->arg_types[0]); std::vector new_type_params; for (const auto& tp : f->type_params) { - new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind)); + new_type_params.push_back(TypeVar(tp->name_hint, tp->kind)); } auto answer_type = new_type_params.back(); new_type_params.pop_back(); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index faf42ab9ed4a..a513f3e51a10 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -206,7 +206,7 @@ class TypeInferencer : private ExprFunctor, for (Expr field : op->fields) { types.push_back(GetType(field)); } - return TupleTypeNode::make(types); + return TupleType(types); } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -218,7 +218,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteTypeNode::make(Kind::kType); auto attrs = make_object(); attrs->index = op->index; - solver_.AddConstraint(TypeRelationNode::make( + solver_.AddConstraint(TypeRelation( tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); return rtype; } @@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor, for (size_t i = 0; i < td->type_vars.size(); i++) { unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } - Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); + Type expected = TypeCall(con->constructor->belong_to, unknown_args); Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); @@ -277,7 +277,7 @@ class TypeInferencer : private ExprFunctor, for (size_t i = 0; i < tup->patterns.size(); i++) { unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } - Type expected = TupleTypeNode::make(unknown_args); + Type expected = TupleType(unknown_args); Type unified = Unify(t, expected, GetRef(tup)); auto* tt = unified.as(); @@ -388,7 +388,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteTypeNode::make(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here - solver_.AddConstraint(TypeRelationNode::make( + solver_.AddConstraint(TypeRelation( rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -418,7 +418,7 @@ class TypeInferencer : private ExprFunctor, ret_type = IncompleteTypeNode::make(Kind::kType); } - Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, + Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); inst_ty = Bind(inst_ty, subst_map); @@ -467,7 +467,7 @@ class TypeInferencer : private ExprFunctor, // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(Kind::kType); - Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); + Type func_type = FuncType(arg_types, ret_type, {}, {}); Type unified = this->Unify(ftype, func_type, GetRef(call)); fn_ty_node = unified.as(); } @@ -513,7 +513,7 @@ class TypeInferencer : private ExprFunctor, for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { solver_.AddConstraint( - TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs), + TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), GetRef(call)); } else { solver_.AddConstraint(cs, GetRef(call)); @@ -557,7 +557,7 @@ class TypeInferencer : private ExprFunctor, rtype = this->Unify(f->ret_type, rtype, GetRef(f)); } CHECK(rtype.defined()); - auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + auto ret = FuncType(arg_types, rtype, f->type_params, {}); return solver_.Resolve(ret); } @@ -575,7 +575,7 @@ class TypeInferencer : private ExprFunctor, Type it = IncompleteTypeNode::make(Kind::kType); this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef(op)); this->Unify(GetType(op->value), it, GetRef(op)); - return TupleTypeNode::make({}); + return TupleType::Empty(); } Type VisitExpr_(const ConstructorNode* c) final { @@ -587,7 +587,7 @@ class TypeInferencer : private ExprFunctor, for (const auto & t : td->type_vars) { types.push_back(t); } - return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types), + return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index d0d8b43f4c61..594669343f62 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -286,7 +286,7 @@ class TypeSolver::Unifier : public TypeFunctor { Type field = Unify(tt1->fields[i], tt2->fields[i]); new_fields.push_back(field); } - return TupleTypeNode::make(new_fields); + return TupleType(new_fields); } Type VisitType_(const FuncTypeNode* op, const Type& tn) final { @@ -314,7 +314,7 @@ class TypeSolver::Unifier : public TypeFunctor { subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType)); } - FuncType ft = FuncTypeNode::make(op->arg_types, + FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, op->type_constraints); @@ -339,7 +339,7 @@ class TypeSolver::Unifier : public TypeFunctor { type_constraints.push_back(GetRef(tcn)); } - return FuncTypeNode::make(arg_types, ret_type, ft2->type_params, type_constraints); + return FuncType(arg_types, ret_type, ft2->type_params, type_constraints); } Type VisitType_(const RefTypeNode* op, const Type& tn) final { @@ -361,7 +361,7 @@ class TypeSolver::Unifier : public TypeFunctor { for (size_t i = 0; i < op->args.size(); i++) { args.push_back(Unify(op->args[i], tcn->args[i])); } - return TypeCallNode::make(func, args); + return TypeCall(func, args); } private: diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 070892b1a0e2..7d03d2e0ef87 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -37,7 +37,7 @@ TEST(Relay, SelfReference) { mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); - auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); + auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); }