diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 976af619256a3..faae303d95ddf 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -24,6 +24,7 @@ #ifndef TVM_EXPR_H_ #define TVM_EXPR_H_ +#include #include #include #include @@ -37,58 +38,6 @@ namespace tvm { -/*! - * \brief Base node of all primitive expressions. - * - * A primitive expression deals with low-level - * POD data types and handles without - * doing life-cycle management for objects. - * - * PrimExpr is used in the low-level code - * optimizations and integer analysis. - * - * \sa PrimExpr - */ -class PrimExprNode : public Object { - public: - /*! \brief The data type of the expression. */ - DataType dtype; - - static constexpr const char* _type_key = "PrimExpr"; - TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object); -}; - -/*! - * \brief Container of all primitive expressions. - * \sa PrimExprNode - */ -class PrimExpr : public ObjectRef { - public: - PrimExpr() {} - explicit PrimExpr(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! - * \brief construct from integer. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(int32_t value); // NOLINT(*) - /*! - * \brief construct from float. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! - * \brief construct from string. - * \param str The value to be constructed. - */ - TVM_DLL PrimExpr(std::string str); // NOLINT(*) - - /*! \return the data type of this expression. */ - DataType dtype() const { - return static_cast(get())->dtype; - } - - using ContainerType = PrimExprNode; -}; /*! \brief Base node of all statements. */ class StmtNode : public Object { diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h new file mode 100644 index 0000000000000..c947338c407f7 --- /dev/null +++ b/include/tvm/ir/expr.h @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/expr.h + * \brief Base expr nodes in TVM. + */ +#ifndef TVM_IR_EXPR_H_ +#define TVM_IR_EXPR_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Base type of all the expressions. + * \sa Expr + */ +class BaseExprNode : public Object { + public: + static constexpr const char* _type_key = "Expr"; + TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa BaseExprNode + */ +class BaseExpr : public ObjectRef { + public: + /*! \brief Cosntructor */ + BaseExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit BaseExpr(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief The container type. */ + using ContainerType = BaseExprNode; +}; + +/*! + * \brief Base node of all primitive expressions. + * + * A primitive expression deals with low-level + * POD data types and handles without + * doing life-cycle management for objects. + * + * PrimExpr is used in the low-level code + * optimizations and integer analysis. + * + * \sa PrimExpr + */ +class PrimExprNode : public BaseExprNode { + public: + /*! + * \brief The runtime data type of the primitive expression. + * + * runtime::DataType(dtype) provides coarse grained type information + * during compile time and runtime. It is eagerly built in + * PrimExpr expression construction and can be used for + * quick type checking. + * + * dtype is sufficient to decide the Type of the PrimExpr + * when it corresponds to POD value types such as i32. + * + * When dtype is DataType::Handle(), the expression could corresponds to + * a more fine-grained Type, and we can get the type by running lazy type inference. + */ + DataType dtype; + + static constexpr const char* _type_key = "PrimExpr"; + TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); +}; + +/*! + * \brief Reference to PrimExprNode. + * \sa PrimExprNode + */ +class PrimExpr : public BaseExpr { + public: + /*! \brief Cosntructor */ + PrimExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit PrimExpr(ObjectPtr ptr) : BaseExpr(ptr) {} + /*! + * \brief construct from integer. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(int32_t value); // NOLINT(*) + /*! + * \brief construct from float. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(float value); // NOLINT(*) + /*! + * \brief construct from string. + * \param str The value to be constructed. + */ + TVM_DLL PrimExpr(std::string str); // NOLINT(*) + + /*! \return the data type of this expression. */ + DataType dtype() const { + return static_cast(get())->dtype; + } + /*! \brief The container type. */ + using ContainerType = PrimExprNode; +}; + +/*! + * \brief Base node of all non-primitive expressions. + * + * RelayExpr supports tensor types, functions and ADT as + * first class citizens. The life-cycle of the corresponding + * objects are implicitly managed by the language. + * + * \sa RelayExpr + */ +class RelayExprNode : public BaseExprNode { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + /*! + * \brief Stores the result of type inference(type checking). + * + * \note This can be undefined before type inference. + * This value is discarded during serialization. + */ + mutable Type checked_type_ = Type(nullptr); + /*! + * \return The checked_type + */ + const Type& checked_type() const; + /*! + * \brief Check if the inferred(checked) type of the Expr + * is backed by a TTypeNode and return it. + * + * \note This function will thrown an error if the node type + * of this Expr is not TTypeNode. + * + * \return The corresponding TTypeNode pointer. + * \tparam The specific TypeNode we look for. + */ + template + inline const TTypeNode* type_as() const; + + static constexpr const char* _type_key = "relay.Expr"; + TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); +}; + +/*! + * \brief Managed reference to RelayExprNode. + * \sa RelayExprNode + */ +class RelayExpr : public BaseExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode); +}; + +class GlobalVar; +/*! + * \brief Global variable that leaves in the top-level module. + * + * A GlobalVar only refers to function definitions. + * This is used to enable recursive calls between function. + * + * \sa GlobalVarNode + */ +class GlobalVarNode : public RelayExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.GlobalVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); +}; + +/*! + * \brief Managed reference to GlobalVarNode. + * \sa GlobalVarNode + */ +class GlobalVar : public RelayExpr { + public: + TVM_DLL explicit GlobalVar(std::string name_hint); + + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); +}; + +/*! + * \brief Base node of all functions. + * + * We support several variants of functions throughout the stack. + * All of the functions shares the same type system(via checked_type) + * to support cross variant calls. + * + * \sa BaseFunc + */ +class BaseFuncNode : public RelayExprNode { + public: + static constexpr const char* _type_key = "BaseFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(BaseFuncNode, RelayExprNode); +}; + +/*! + * \brief Managed reference to BaseFuncNode. + * \sa BaseFuncNode + */ +class BaseFunc : public RelayExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); +}; + +// implementataions +inline const Type& RelayExprNode::checked_type() const { + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); + return this->checked_type_; +} + +template +inline const TTypeNode* RelayExprNode::type_as() const { + static_assert(std::is_base_of::value, + "TType must be a special case of type"); + CHECK(checked_type_.defined()) + << "Type inference for this Expr has not completed. Try to call infer_type pass."; + const TTypeNode* node = checked_type_.as(); + CHECK(node != nullptr) + << "Expected type to be " << TTypeNode::_type_key + << ", but get " << checked_type_->GetTypeKey(); + return node; +} + +} // namespace tvm +#endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ffe1ba876ac65..ab2003e9ec5e4 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -28,7 +28,7 @@ * * ## Relation between Type and runtime::DataType * - * Besides Type, we also store a dtype field in some of the low-level IR's Expr. + * Besides Type, we also store a dtype field in the low-level PrimExpr. * runtime::DataType(dtype) provides coarse grained type information * during compile time and runtime. It is eagerly built in * low-level expression construction and can be used for diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index dac39e014cc7c..7c4b289ec85b8 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -143,7 +143,7 @@ class ConstructorNode : public ExprNode { class Constructor : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode); + TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); }; /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ @@ -306,7 +306,7 @@ class MatchNode : public ExprNode { class Match : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode); + TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); }; } // namespace relay diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 47c83696c3e50..ba0afb284490d 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -25,6 +25,7 @@ #define TVM_RELAY_EXPR_H_ #include +#include #include #include #include "./base.h" @@ -33,47 +34,12 @@ namespace tvm { namespace relay { -/*! - * \brief A Relay expression. - */ -class Expr; -/*! - * \brief Base type of the Relay expression hiearchy. - */ -class ExprNode : public RelayNode { - public: - /*! - * \brief Stores the result of type inference(type checking). - * - * \note This can be undefined before type inference. - * This value is discarded during serialization. - */ - mutable Type checked_type_ = Type(nullptr); - /*! - * \return The checked_type - */ - const Type& checked_type() const; - /*! - * \brief Check if the inferred(checked) type of the Expr - * is backed by a TTypeNode and return it. - * - * \note This function will thrown an error if the node type - * of this Expr is not TTypeNode. - * - * \return The corresponding TTypeNode pointer. - * \tparam The specific TypeNode we look for. - */ - template - inline const TTypeNode* type_as() const; - - static constexpr const char* _type_key = "relay.Expr"; - TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode); -}; - -class Expr : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode); -}; +using Expr = tvm::RelayExpr; +using ExprNode = tvm::RelayExprNode; +using BaseFunc = tvm::BaseFunc; +using BaseFuncNode = tvm::BaseFuncNode; +using GlobalVar = tvm::GlobalVar; +using GlobalVarNode = tvm::GlobalVarNode; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. @@ -112,7 +78,7 @@ class ConstantNode : public ExprNode { class Constant : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode); + TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); }; /*! \brief Tuple of multiple Exprs */ @@ -137,7 +103,7 @@ class TupleNode : public ExprNode { class Tuple : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); }; /*! @@ -193,37 +159,7 @@ class VarNode : public ExprNode { class Var : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); -}; - -/*! - * \brief Global variable that leaves in the top-level module. - * This is used to enable recursive calls between function. - * - * \note A GlobalVar may only point to functions. - */ -class GlobalVar; -/*! \brief A GlobalId from the node's current type to target type. */ -class GlobalVarNode : public ExprNode { - public: - /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); - } - - TVM_DLL static GlobalVar make(std::string name_hint); - - static constexpr const char* _type_key = "relay.GlobalVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode); -}; - -class GlobalVar : public Expr { - public: - TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); }; /*! @@ -231,7 +167,7 @@ class GlobalVar : public Expr { */ class Function; /*! \brief Function container */ -class FunctionNode : public ExprNode { +class FunctionNode : public BaseFuncNode { public: /*! \brief Function parameters */ tvm::Array params; @@ -315,9 +251,9 @@ class FunctionNode : public ExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); }; -class Function : public Expr { +class Function : public BaseFunc { public: - TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); + TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); }; @@ -388,7 +324,7 @@ class CallNode : public ExprNode { class Call : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); }; /*! @@ -429,7 +365,7 @@ class LetNode : public ExprNode { class Let : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode); + TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); }; /*! @@ -470,7 +406,7 @@ class IfNode : public ExprNode { class If : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); }; /*! \brief Get index-th field out of a tuple. */ @@ -497,7 +433,7 @@ class TupleGetItemNode : public ExprNode { class TupleGetItem : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); }; /*! \brief Create a new Reference out of initial value. */ @@ -521,7 +457,7 @@ class RefCreateNode : public ExprNode { class RefCreate : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); }; /*! \brief Get value out of Reference. */ @@ -545,7 +481,7 @@ class RefReadNode : public ExprNode { class RefRead : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); }; /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; @@ -571,7 +507,7 @@ class RefWriteNode : public ExprNode { class RefWrite : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); }; /*! @@ -600,32 +536,9 @@ class TempExprNode : public ExprNode { class TempExpr : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode); + TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode); }; -// implementataions -inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) - << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " - << GetRef(this); - return this->checked_type_; -} - -template -inline const TTypeNode* ExprNode::type_as() const { - static_assert(std::is_base_of::value, - "TType must be a special case of type"); - CHECK(checked_type_.defined()) - << "Type inference for this Expr has not completed. Try to call infer_type pass."; - const TTypeNode* node = checked_type_.as(); - CHECK(node != nullptr) - << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->GetTypeKey(); - return node; -} - /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ std::string PrettyPrint(const ObjectRef& node); diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index d7b3b394c5cdd..8292344b3b850 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -25,7 +25,7 @@ #define TVM_RELAY_FEATURE_H_ #include -#include +#include #include namespace tvm { @@ -132,7 +132,6 @@ class FeatureSet { explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } }; -class Expr; /*! * \brief Calculate the feature of the program. * @@ -140,7 +139,7 @@ class Expr; * * \return The FeatureSet. */ -FeatureSet DetectFeature(const Expr& expr); +FeatureSet DetectFeature(const RelayExpr& expr); struct Module; /*! diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index b4495191dd240..6bd0a359cc69e 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -140,7 +140,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! \brief constructor from node pointer */ - explicit Op(ObjectPtr n) : Expr(n) {} + explicit Op(ObjectPtr n) : RelayExpr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/src/ir/expr.cc b/src/ir/expr.cc new file mode 100644 index 0000000000000..f698a5d1802e0 --- /dev/null +++ b/src/ir/expr.cc @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/expr.cc + * \brief The expression AST nodes for the common IR infra. + */ +#include +#include + +namespace tvm { + +GlobalVar::GlobalVar(std::string name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalVarNode); + +TVM_REGISTER_GLOBAL("relay._make.GlobalVar") +.set_body_typed([](std::string name){ + return GlobalVar(name); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); + +} // namespace tvm diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 62de1c36fc45b..e95e03bb2d50f 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -628,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode { auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); const tvm::ir::StringImmNode* symbol_name = ext_symbol.as(); CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVarNode::make(symbol_name->value); + auto gv = GlobalVar(symbol_name->value); ext_mods[code_gen->value]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index b7ecadcc84d9d..601af9e55f9fb 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -101,7 +101,7 @@ class LambdaLifter : public ExprMutator { } auto name = GenerateName(func); - auto global = GlobalVarNode::make(name); + auto global = GlobalVar(name); auto free_vars = FreeVars(func); auto free_type_vars = FreeTypeVars(func, module_); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 4bac1fd1f6266..82b3513979dcb 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_GLOBAL("relay._base.set_span") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { - CHECK(rn); + rn->span = sp; + } else if (auto* rn = node_ref.as()) { rn->span = sp; } else if (auto* rn = node_ref.as()) { rn->span = sp; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f6ebadf477eb4..239a33ea642c1 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/tvm/ir/expr.cc + * \file src/tvm/relay/ir/expr.cc * \brief The expression AST nodes of Relay. */ #include @@ -109,24 +109,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -GlobalVar GlobalVarNode::make(std::string name_hint) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name_hint); - return GlobalVar(n); -} - -TVM_REGISTER_NODE_TYPE(GlobalVarNode); - -TVM_REGISTER_GLOBAL("relay._make.GlobalVar") -.set_body_typed(GlobalVarNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); - - Function FunctionNode::make(tvm::Array params, Expr body, Type ret_type, diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index fdaa607e380f6..bf1ebf3bdd778 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -279,7 +279,7 @@ Module ModuleNode::FromExpr( } else { func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); } - auto main_gv = GlobalVarNode::make("main"); + auto main_gv = GlobalVar("main"); mod->Add(main_gv, func); return mod; } diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index d36733df341b8..98ad4338789c6 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -210,7 +210,7 @@ class ConstantFolder : public ExprMutator { {}, module_->type_definitions, module_->Imports()); - auto global = GlobalVarNode::make("main"); + auto global = GlobalVar("main"); mod->Add(global, func); auto seq = transform::Sequential(passes); mod = seq(mod); diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 3ca7a081dc279..9e2516b051679 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -155,7 +155,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); if (cm->count(gv) == 0) { - auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(gv->name_hint + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index c62520a53e848..ceed96471df76 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -662,7 +662,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); auto module = ModuleNode::make({}, {}); - auto dummy_fn_name = GlobalVarNode::make("test"); + auto dummy_fn_name = GlobalVar("test"); module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); auto solver = std::make_shared(dummy_fn_name, module, err_reporter);