From 9b4eb10cbfbf91ea9369067ccf6c7841fe91588e Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Mon, 23 Mar 2020 11:15:07 -0700
Subject: [PATCH] [Refactor] Relay Node::make to constructor (#5128)

* relay Node::make to constructor

* patternwildcard

* Address comments
---
 include/tvm/relay/adt.h | 85 ++++++++-----
 include/tvm/relay/base.h | 6 ++
 include/tvm/relay/expr.h | 100 ++++++++++++++----
 include/tvm/tir/data_layout.h | 9 +-
 src/relay/analysis/match_exhaustion.cc | 14 +--
 src/relay/analysis/type_solver.cc | 4 +-
 src/relay/analysis/util.cc | 2 +-
 src/relay/backend/build_module.cc | 2 +-
 src/relay/backend/compile_engine.cc | 8 +-
 src/relay/backend/compile_engine.h | 16 +--
 src/relay/backend/interpreter.cc | 4 +-
 src/relay/backend/utils.h | 2 +-
 src/relay/backend/vm/compiler.cc | 6 +-
 src/relay/backend/vm/inline_primitives.cc | 4 +-
 src/relay/backend/vm/lambda_lift.cc | 8 +-
 src/relay/ir/adt.cc | 50 +++++----
 src/relay/ir/base.cc | 6 ++
 src/relay/ir/expr.cc | 88 ++++++++-------
 src/relay/ir/expr_functor.cc | 24 ++---
 src/relay/ir/pattern_functor.cc | 8 +-
 src/relay/ir/transform.cc | 24 +++--
 src/relay/op/algorithm/argsort.cc | 2 +-
 src/relay/op/algorithm/topk.cc | 2 +-
 src/relay/op/annotation/annotation.cc | 12 +--
 src/relay/op/debug.cc | 2 +-
 src/relay/op/device_copy.cc | 2 +-
 src/relay/op/image/dilation2d.cc | 8 +-
 src/relay/op/image/resize.cc | 8 +-
 src/relay/op/memory/memory.cc | 8 +-
 src/relay/op/nn/bitserial.cc | 8 +-
 src/relay/op/nn/convolution.cc | 36 +++----
 src/relay/op/nn/convolution.h | 18 ++--
 src/relay/op/nn/nn.cc | 44 ++++----
 src/relay/op/nn/pad.cc | 4 +-
 src/relay/op/nn/pooling.cc | 34 +++---
 src/relay/op/nn/sparse.cc | 4 +-
 src/relay/op/nn/upsampling.cc | 8 +-
 src/relay/op/op_common.h | 6 +-
 src/relay/op/tensor/reduce.cc | 4 +-
 src/relay/op/tensor/transform.cc | 70 ++++++------
 src/relay/op/tensor/unary.cc | 6 +-
 src/relay/op/vision/multibox_op.cc | 4 +-
 src/relay/op/vision/nms.cc | 4 +-
 src/relay/op/vision/rcnn_op.cc | 6 +-
 src/relay/op/vision/yolo.cc | 2 +-
 src/relay/qnn/op/concatenate.cc | 4 +-
 src/relay/qnn/op/convolution.cc | 2 +-
 src/relay/qnn/op/dense.cc | 2 +-
 src/relay/qnn/op/dequantize.cc | 2 +-
 src/relay/qnn/op/op_common.h | 8 +-
 src/relay/qnn/op/quantize.cc | 2 +-
 src/relay/qnn/op/requantize.cc | 2 +-
 src/relay/quantize/annotate.cc | 24 +++--
 src/relay/quantize/calibrate.cc | 4 +-
 src/relay/quantize/partition.cc | 20 ++--
 src/relay/quantize/quantize.cc | 2 +-
 src/relay/quantize/realize.cc | 47 ++++----
 src/relay/transforms/alter_op_layout.cc | 2 +-
 src/relay/transforms/annotate_target.cc | 2 +-
 src/relay/transforms/canonicalize_cast.cc | 4 +-
 .../transforms/combine_parallel_conv2d.cc | 12 +--
 .../transforms/combine_parallel_op_batch.cc | 10 +-
 src/relay/transforms/convert_layout.cc | 2 +-
 src/relay/transforms/de_duplicate.cc | 6 +-
 src/relay/transforms/dead_code.cc | 2 +-
 src/relay/transforms/device_annotation.cc | 16 +--
 src/relay/transforms/eta_expand.cc | 10 +-
 src/relay/transforms/fold_constant.cc | 10 +-
 src/relay/transforms/fold_scale_axis.cc | 42 ++++----
 src/relay/transforms/forward_rewrite.cc | 6 +-
 src/relay/transforms/fuse_ops.cc | 10 +-
 src/relay/transforms/gradient.cc | 62 +++++------
 src/relay/transforms/inline.cc | 2 +-
 src/relay/transforms/let_list.h | 14 +--
 src/relay/transforms/merge_composite.cc | 8 +-
 src/relay/transforms/partial_eval.cc | 38 +++----
 src/relay/transforms/partition_graph.cc | 18 ++--
 src/relay/transforms/pattern_util.h | 88 +++++++--------
 src/relay/transforms/to_a_normal_form.cc | 20 ++--
 src/relay/transforms/to_cps.cc | 42 ++++----
 src/relay/transforms/to_graph_normal_form.cc | 2 +-
 src/relay/transforms/transform_layout.h | 14 +--
 src/tir/ir/data_layout.cc | 22 ++--
 tests/cpp/relay_build_module_test.cc | 10 +-
 tests/cpp/relay_pass_type_infer_test.cc | 6 +-
 tests/cpp/relay_transform_sequential.cc | 30 +++---
 topi/include/topi/transform.h | 2 +-
 87 files changed, 782 insertions(+), 621 deletions(-)

diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h
index 6f72072e66c1..8189b210e6ba 100644
--- a/include/tvm/relay/adt.h
+++ b/include/tvm/relay/adt.h
@@ -26,11 +26,12 @@
 #include <tvm/ir/adt.h>
 #include <tvm/ir/attrs.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
 #include <functional>
 #include <string>
-#include "./base.h"
-#include "./type.h"
-#include "./expr.h"
+#include <utility>
 
 namespace tvm {
 namespace relay {
@@ -69,10 +70,6 @@ class PatternWildcard;
 /*! \brief PatternWildcard container node */
 class PatternWildcardNode : public PatternNode {
  public:
-  PatternWildcardNode() {}
-
-  TVM_DLL static PatternWildcard make();
-
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("span", &span);
   }
@@ -83,7 +80,29 @@ class PatternWildcardNode : public PatternNode {
 
 class PatternWildcard : public Pattern {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
+  /* \brief Overload the default constructors. */
+  TVM_DLL PatternWildcard();
+  explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}
+  /* \brief Copy constructor. */
+  PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
+  /* \brief Move constructor. */
+  PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
+  /* \brief Copy assignment. */
+  PatternWildcard& operator=(const PatternWildcard& other) {
+    (*this).data_ = other.data_;
+    return *this;
+  }
+  /* \brief Move assignment. */
+  PatternWildcard& operator=(PatternWildcard&& other) {
+    (*this).data_ = std::move(other.data_);
+    return *this;
+  }
+
+  const PatternWildcardNode* operator->() const {
+    return static_cast<const PatternWildcardNode*>(get());
+  }
+
+  using ContainerType = PatternWildcardNode;
 };
 
 /*! \brief A var pattern. Accept all input and bind to a var. */
@@ -91,13 +110,9 @@ class PatternVar;
 /*! \brief PatternVar container node */
 class PatternVarNode : public PatternNode {
  public:
-  PatternVarNode() {}
-
   /*! \brief Variable that stores the matched value. */
   tvm::relay::Var var;
 
-  TVM_DLL static PatternVar make(tvm::relay::Var var);
-
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("span", &span);
   }
@@ -109,6 +124,12 @@ class PatternVarNode : public PatternNode {
 
 class PatternVar : public Pattern {
  public:
+  /*!
+   * \brief Constructor
+   * \param var The var to construct a pattern
+   */
+  TVM_DLL explicit PatternVar(tvm::relay::Var var);
+
   TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
 };
 
@@ -122,10 +143,6 @@ class PatternConstructorNode : public PatternNode {
   /*! Sub-patterns to match against each input to the constructor. */
   tvm::Array<Pattern> patterns;
 
-  PatternConstructorNode() {}
-
-  TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
-
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("constructor", &constructor);
     v->Visit("patterns", &patterns);
@@ -138,6 +155,13 @@ class PatternConstructorNode : public PatternNode {
 
 class PatternConstructor : public Pattern {
  public:
+  /*!
+ * \brief Constructor + * \param constructor The constructor of a pattern + * \param patterns The sub-patterns for matching + */ + TVM_DLL PatternConstructor(Constructor constructor, tvm::Array patterns); + TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode); }; @@ -149,10 +173,6 @@ class PatternTupleNode : public PatternNode { /*! Sub-patterns to match against each value of the tuple. */ tvm::Array patterns; - PatternTupleNode() {} - - TVM_DLL static PatternTuple make(tvm::Array var); - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); v->Visit("span", &span); @@ -164,6 +184,12 @@ class PatternTupleNode : public PatternNode { class PatternTuple : public Pattern { public: + /*! + * \brief Constructor + * \param patterns The sub-patterns to match against each value of the tuple + */ + TVM_DLL explicit PatternTuple(tvm::Array patterns); + TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); }; @@ -182,14 +208,19 @@ class ClauseNode : public Object { v->Visit("rhs", &rhs); } - TVM_DLL static Clause make(Pattern lhs, Expr rhs); - static constexpr const char* _type_key = "relay.Clause"; TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); }; class Clause : public ObjectRef { public: + /*! + * \brief Constructor + * \param lhs The pattern matched by the clause. + * \param rhs The resulting value + */ + TVM_DLL explicit Clause(Pattern lhs, Expr rhs); + TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); }; @@ -217,14 +248,20 @@ class MatchNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Match make(Expr data, tvm::Array pattern, bool complete = true); - static constexpr const char* _type_key = "relay.Match"; TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); }; class Match : public Expr { public: + /*! + * \brief Constructor + * \param data the input being deconstructed. + * \param clauses The clauses for matching. + * \param complete Indicate if this match is complete. + */ + TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true); + TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); }; diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index e00329d4d3ed..1d0120675e99 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -103,6 +103,12 @@ class IdNode : public Object { class Id : public ObjectRef { public: + /*! + * \brief The constructor + * \param name_hint The name of the variable. + */ + TVM_DLL explicit Id(std::string name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 49356ac8a955..3acb5ddae778 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -72,14 +72,18 @@ class ConstantNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Constant make(runtime::NDArray data); - static constexpr const char* _type_key = "relay.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); }; class Constant : public Expr { public: + /*! + * \brief The constructor + * \param data The data of the constant tensor. 
+ */ + TVM_DLL explicit Constant(runtime::NDArray data); + TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); }; @@ -97,14 +101,18 @@ class TupleNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Tuple make(tvm::Array fields); - static constexpr const char* _type_key = "relay.Tuple"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); }; class Tuple : public Expr { public: + /*! + * \brief The constructor + * \param fields The fields of a tuple. + */ + TVM_DLL explicit Tuple(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); }; @@ -161,6 +169,21 @@ class VarNode : public ExprNode { class Var : public Expr { public: + /*! + * \brief The constructor + * \param name_hint The name hint of a variable. + * \param type_annotation The type annotation of a variable. + */ + TVM_DLL Var(std::string name_hint, Type type_annotation) : + Var(Id(name_hint), type_annotation) {} + + /*! + * \brief The constructor + * \param vid The unique id of a variable. + * \param type_annotation The type annotation of a variable. + */ + TVM_DLL Var(Id vid, Type type_annotation); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); }; @@ -215,17 +238,24 @@ class CallNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Call make(Expr op, - Array args, - Attrs attrs = Attrs(), - Array type_args = Array()); - static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); }; class Call : public Expr { public: + /*! + * \brief The constructor + * \param op The operator will be invoked. + * \param args The arguments of the call. + * \param attrs The attributes of the call node. + * \param type_args The type arguments passed to a polymorphic function. + */ + TVM_DLL Call(Expr op, + Array args, + Attrs attrs = Attrs(), + Array type_args = Array()); + TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); }; @@ -259,14 +289,20 @@ class LetNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static Let make(Var var, Expr value, Expr body); - static constexpr const char* _type_key = "relay.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); }; class Let : public Expr { public: + /*! + * \brief The constructor + * \param var The variable that is bound to. + * \param value The value used to bind to the variable. + * \param body The body of the let binding. + */ + TVM_DLL Let(Var var, Expr value, Expr body); + TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); }; @@ -300,14 +336,20 @@ class IfNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); - static constexpr const char* _type_key = "relay.If"; TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); }; class If : public Expr { public: + /*! + * \brief The constructor + * \param cond The condition of a if node. + * \param true_branch The fall through branch + * \param false_branch The branch for execution when condition is false. + */ + TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch); + TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); }; @@ -327,14 +369,19 @@ class TupleGetItemNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static TupleGetItem make(Expr tuple, int index); - static constexpr const char* _type_key = "relay.TupleGetItem"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; class TupleGetItem : public Expr { public: + /*! 
+ * \brief The constructor + * \param tuple The tuple to get an element from. + * \param index The index for extracting a value in the tuple. + */ + TVM_DLL TupleGetItem(Expr tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); }; @@ -351,14 +398,18 @@ class RefCreateNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static RefCreate make(Expr value); - static constexpr const char* _type_key = "relay.RefCreate"; TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); }; class RefCreate : public Expr { public: + /*! + * \brief The constructor + * \param value The initial value of the reference. + */ + TVM_DLL explicit RefCreate(Expr value); + TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); }; @@ -375,14 +426,18 @@ class RefReadNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static RefRead make(Expr ref); - static constexpr const char* _type_key = "relay.RefRead"; TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); }; class RefRead : public Expr { public: + /*! + * \brief The constructor + * \param ref The reference where to read data. + */ + TVM_DLL explicit RefRead(Expr ref); + TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); }; /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ @@ -409,6 +464,13 @@ class RefWriteNode : public ExprNode { class RefWrite : public Expr { public: + /*! + * \brief The constructor + * \param ref The reference where data is write to. + * \param value The value to write. + */ + TVM_DLL RefWrite(Expr ref, Expr value); + TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); }; diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 52870c669e9d..434337057167 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -335,9 +335,6 @@ class BijectiveLayoutNode : public Object { static constexpr const char* _type_key = "BijectiveLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); - - TVM_DLL static BijectiveLayout make(const Layout& src_layout, - const Layout& dst_layout); }; /*! \brief Bijective function mapping for data layout transformation. @@ -349,6 +346,12 @@ class BijectiveLayout : public ObjectRef { public: BijectiveLayout() = default; explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief The constructor + * \param src_layout The source layout + * \param dst_layout The destination layout + */ + TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout); // Given the source shape, infer the destination shape. 
TVM_DLL Array ForwardShape(const Array& shape) const; diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 919065469a4d..eeb7fce18c52 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -191,9 +191,9 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, for (auto constructor : td->constructors) { Array args; for (auto inp : constructor->inputs) { - args.push_back(PatternWildcardNode::make()); + args.push_back(PatternWildcard()); } - ret.push_back(PatternConstructorNode::make(constructor, args)); + ret.push_back(PatternConstructor(constructor, args)); } return ret; } @@ -212,7 +212,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, auto all_subfields = CartesianProduct(values_by_field); Array ret; for (auto subfields : all_subfields) { - ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + ret.push_back(PatternConstructor(ctor_cand->constructor, subfields)); } return ret; } @@ -226,9 +226,9 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, if (cand.as()) { Array args; for (auto inp : clause_tuple->patterns) { - args.push_back(PatternWildcardNode::make()); + args.push_back(PatternWildcard()); } - return {PatternTupleNode::make(args)}; + return {PatternTuple(args)}; } auto tuple_cand = Downcast(cand); @@ -245,7 +245,7 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, auto all_subfields = CartesianProduct(values_by_field); Array ret; for (auto subfields : all_subfields) { - ret.push_back(PatternTupleNode::make(subfields)); + ret.push_back(PatternTuple(subfields)); } return ret; } @@ -272,7 +272,7 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { * return failed_candidates */ std::stack candidates; - candidates.push(PatternWildcardNode::make()); + candidates.push(PatternWildcard()); CandidateChecker checker; Array failures; diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index a6ac9ce9a7ec..c39df9d50f58 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -666,7 +666,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") ErrorReporter *err_reporter = new ErrorReporter(); auto module = IRModule({}, {}); auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, TupleNode::make({}), Type(), {}, {})); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); auto solver = std::make_shared(dummy_fn_name, module, err_reporter); auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { @@ -689,7 +689,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") }); } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { - Expr e = VarNode::make("dummy_var", + Expr e = Var("dummy_var", IncompleteType(Kind::kType)); return solver->AddConstraint(c, e); }); diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 6a151d7d21f1..6132532e00b6 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -433,7 +433,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); - return ClauseNode::make(pat, VisitExpr(c->rhs)); + return Clause(pat, VisitExpr(c->rhs)); } private: diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index d42cc27f77d0..4073271073ea 100644 --- 
a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -210,7 +210,7 @@ class RelayBuildModule : public runtime::ModuleNode { Map GetParams() { Map ret; for (const auto& kv : ret_.params) { - ret.Set(kv.first, ConstantNode::make(kv.second)); + ret.Set(kv.first, Constant(kv.second)); } return ret; } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 1237c56163f9..410a6df504c5 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -60,11 +60,11 @@ LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation im data_ = std::move(n); } -CCacheKey CCacheKeyNode::make(Function source_func, Target target) { +CCacheKey::CCacheKey(Function source_func, Target target) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); - return CCacheKey(n); + data_ = std::move(n); } struct IsDynamicVisitor : public TypeVisitor { @@ -819,7 +819,9 @@ TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") -.set_body_typed(CCacheKeyNode::make); +.set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") .set_body_typed([]() { diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 2dbacf645482..098211e7ea86 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -124,14 +124,6 @@ class CCacheKeyNode : public Object { * \return The result of equality check. */ inline bool Equal(const CCacheKeyNode* other) const; - /*! - * \brief create a cache key. - * \param source_func The source function. - * \param target The target device. - * \return the created key. - */ - TVM_DLL static CCacheKey make(Function source_func, - Target target); static constexpr const char* _type_key = "relay.CCacheKey"; TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); @@ -148,6 +140,14 @@ class CCacheKey : public ObjectRef { public: CCacheKey() {} explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief The constructor + * \param source_func The source function. + * \param target The target device. 
+ */ + TVM_DLL CCacheKey(Function source_func, Target target); + const CCacheKeyNode* operator->() const { return static_cast(get()); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index cac948836507..631f2d433be5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -309,7 +309,7 @@ class Interpreter : Array ComputeDynamicShape(const Function& func, const Array& args) { - auto key = CCacheKeyNode::make(func, Target::Create("llvm")); + CCacheKey key(func, Target::Create("llvm")); auto cfunc = engine_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); @@ -520,7 +520,7 @@ class Interpreter : out_shapes = ComputeDynamicShape(func, args); } - PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_)); + PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_)); TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { CHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index cccd4badbb3a..7171589f79d3 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -113,7 +113,7 @@ BindParamsByName(relay::Function func, if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - bind_dict[arg] = ConstantNode::make(kv.second); + bind_dict[arg] = Constant(kv.second); } Expr bound_expr = relay::Bind(func, bind_dict); Function ret = Downcast(bound_expr); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 2fc6567348d8..4d15c76ccb75 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -404,7 +404,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - auto key = CCacheKeyNode::make(func, target_host_); + CCacheKey key(func, target_host_); auto cfunc = engine_->LowerShapeFunc(key); int op_index = -1; if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) { @@ -485,7 +485,7 @@ class VMFunctionCompiler : ExprFunctor { } } - auto key = CCacheKeyNode::make(func, target); + CCacheKey key(func, target); auto cfunc = engine_->Lower(key); auto op_index = -1; @@ -780,7 +780,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map ret; for (const auto& kv : params_) { - ret.Set(kv.first, ConstantNode::make(kv.second)); + ret.Set(kv.first, Constant(kv.second)); } *rv = ret; }); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 8327a6bb2168..74b2a47634c8 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -92,7 +92,7 @@ struct PrimitiveInliner : ExprMutator { auto new_arg = VisitExpr(arg); call_args.push_back(new_arg); } - return CallNode::make(GetRef(func), call_args, call->attrs, call->type_args); + return Call(GetRef(func), call_args, call->attrs, call->type_args); } } @@ -102,7 +102,7 @@ struct PrimitiveInliner : ExprMutator { auto new_arg = VisitExpr(arg); call_args.push_back(new_arg); } - return CallNode::make(GetRef(global), call_args, call->attrs, call->type_args); + return Call(GetRef(global), call_args, call->attrs, call->type_args); } return ExprMutator::VisitExpr_(call); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index fd8c35152ff1..7e7622ca96cb 
100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -73,7 +73,7 @@ class LambdaLifter : public ExprMutator { letrec_.pop_back(); } auto body = VisitExpr(let_node->body); - return LetNode::make(let_node->var, value, body); + return Let(let_node->var, value, body); } Expr VisitExpr_(const CallNode* call_node) final { @@ -83,7 +83,7 @@ class LambdaLifter : public ExprMutator { if (!letrec_.empty() && var == letrec_.back()) { auto it = lambda_map_.find(var); CHECK(it != lambda_map_.end()); - return CallNode::make(it->second, call->args, call_node->attrs, + return Call(it->second, call->args, call_node->attrs, call_node->type_args); } } @@ -118,7 +118,7 @@ class LambdaLifter : public ExprMutator { for (auto fv : captured_vars) { fvs.push_back(fv); } - lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs)); + lambda_map_.emplace(letrec_.back(), Call(global, fvs)); } else { lambda_map_.emplace(letrec_.back(), global); } @@ -178,7 +178,7 @@ class LambdaLifter : public ExprMutator { for (auto fv : captured_vars) { fvs.push_back(fv); } - return CallNode::make(global, fvs); + return Call(global, fvs); } } diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index ff24825f14e8..11c2cbb772fc 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -27,31 +27,35 @@ namespace tvm { namespace relay { -PatternWildcard PatternWildcardNode::make() { +PatternWildcard::PatternWildcard() { ObjectPtr n = make_object(); - return PatternWildcard(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") -.set_body_typed(PatternWildcardNode::make); +.set_body_typed([]() { + return PatternWildcard(); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { p->stream << "PatternWildcardNode()"; }); -PatternVar PatternVarNode::make(tvm::relay::Var var) { +PatternVar::PatternVar(tvm::relay::Var var) { ObjectPtr n = make_object(); n->var = std::move(var); - return PatternVar(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_GLOBAL("relay.ir.PatternVar") -.set_body_typed(PatternVarNode::make); +.set_body_typed([](tvm::relay::Var var) { + return PatternVar(var); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -59,18 +63,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "PatternVarNode(" << node->var << ")"; }); -PatternConstructor PatternConstructorNode::make(Constructor constructor, - tvm::Array patterns) { +PatternConstructor::PatternConstructor(Constructor constructor, + tvm::Array patterns) { ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); - return PatternConstructor(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") -.set_body_typed(PatternConstructorNode::make); +.set_body_typed([](Constructor constructor, tvm::Array patterns) { + return PatternConstructor(constructor, patterns); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -79,16 +85,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->patterns << ")"; }); -PatternTuple PatternTupleNode::make(tvm::Array patterns) { +PatternTuple::PatternTuple(tvm::Array patterns) { ObjectPtr n = make_object(); n->patterns = std::move(patterns); - return PatternTuple(n); + data_ = std::move(n); } 
TVM_REGISTER_NODE_TYPE(PatternTupleNode); TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") -.set_body_typed(PatternTupleNode::make); +.set_body_typed([](tvm::Array patterns) { + return PatternTuple(patterns); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -96,17 +104,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "PatternTupleNode(" << node->patterns << ")"; }); -Clause ClauseNode::make(Pattern lhs, Expr rhs) { +Clause::Clause(Pattern lhs, Expr rhs) { ObjectPtr n = make_object(); n->lhs = std::move(lhs); n->rhs = std::move(rhs); - return Clause(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_GLOBAL("relay.ir.Clause") -.set_body_typed(ClauseNode::make); +.set_body_typed([](Pattern lhs, Expr rhs) { + return Clause(lhs, rhs); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -115,18 +125,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->rhs << ")"; }); -Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { +Match::Match(Expr data, tvm::Array clauses, bool complete) { ObjectPtr n = make_object(); n->data = std::move(data); n->clauses = std::move(clauses); n->complete = complete; - return Match(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay.ir.Match") -.set_body_typed(MatchNode::make); +.set_body_typed([](Expr data, tvm::Array clauses, bool complete) { + return Match(data, clauses, complete); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 22423b8dfe5f..76a3f9d4446e 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,6 +33,12 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); +Id::Id(std::string name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + TVM_REGISTER_GLOBAL("ir.NodeSetSpan") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 5da5be3c43a7..169db62eee26 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,16 +30,18 @@ namespace relay { using tvm::ReprPrinter; using namespace tvm::runtime; -Constant ConstantNode::make(runtime::NDArray data) { +Constant::Constant(runtime::NDArray data) { ObjectPtr n = make_object(); n->data = std::move(data); - return Constant(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_GLOBAL("relay.ir.Constant") -.set_body_typed(ConstantNode::make); +.set_body_typed([](runtime::NDArray data) { + return Constant(data); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -63,16 +65,18 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } -Tuple TupleNode::make(tvm::Array fields) { +Tuple::Tuple(tvm::Array fields) { ObjectPtr n = make_object(); n->fields = std::move(fields); - return Tuple(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple") -.set_body_typed(TupleNode::make); +.set_body_typed([](tvm::Array fields) { + return Tuple(fields); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -81,23 +85,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); -Var VarNode::make(Id vid, Type type_annotation) { 
+Var::Var(Id vid, Type type_annotation) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); - return Var(n); -} - -Var VarNode::make(std::string name_hint, Type type_annotation) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name_hint); - return VarNode::make(Id(n), type_annotation); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay.ir.Var") -.set_body_typed(static_cast(VarNode::make)); +.set_body_typed([](std::string str, Type type_annotation) { + return Var(str, type_annotation); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -110,21 +110,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); - -Call CallNode::make(Expr op, Array args, Attrs attrs, - Array type_args) { +Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); - return Call(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") -.set_body_typed(CallNode::make); +.set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { + return Call(op, args, attrs, type_args); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -133,18 +133,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << node->attrs << ", " << node->type_args << ")"; }); -Let LetNode::make(Var var, Expr value, Expr body) { +Let::Let(Var var, Expr value, Expr body) { ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); - return Let(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay.ir.Let") -.set_body_typed(LetNode::make); +.set_body_typed([](Var var, Expr value, Expr body) { + return Let(var, value, body); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -153,18 +155,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->body << ")"; }); -If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { +If::If(Expr cond, Expr true_branch, Expr false_branch) { ObjectPtr n = make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); - return If(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") -.set_body_typed(IfNode::make); +.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { + return If(cond, true_branch, false_branch); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -173,17 +177,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->false_branch << ")"; }); -TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { +TupleGetItem::TupleGetItem(Expr tuple, int index) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; - return TupleGetItem(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") -.set_body_typed(TupleGetItemNode::make); +.set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -191,16 +197,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, 
vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); -RefCreate RefCreateNode::make(Expr value) { +RefCreate::RefCreate(Expr value) { ObjectPtr n = make_object(); n->value = std::move(value); - return RefCreate(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay.ir.RefCreate") -.set_body_typed(RefCreateNode::make); +.set_body_typed([](Expr value) { + return RefCreate(value); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -208,16 +216,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefCreateNode(" << node->value << ")"; }); -RefRead RefReadNode::make(Expr ref) { +RefRead::RefRead(Expr ref) { ObjectPtr n = make_object(); n->ref = std::move(ref); - return RefRead(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay.ir.RefRead") -.set_body_typed(RefReadNode::make); +.set_body_typed([](Expr ref) { + return RefRead(ref); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -225,17 +235,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefReadNode(" << node->ref << ")"; }); -RefWrite RefWriteNode::make(Expr ref, Expr value) { +RefWrite::RefWrite(Expr ref, Expr value) { ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); - return RefWrite(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay.ir.RefWrite") -.set_body_typed(RefWriteNode::make); +.set_body_typed([](Expr ref, Expr value) { + return RefWrite(ref, value); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 4b0239b8da21..11e85d583112 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -47,7 +47,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { if (op->type_annotation.defined()) { auto type = this->VisitType(op->type_annotation); if (!op->type_annotation.same_as(type)) { - return VarNode::make(op->vid, type); + return Var(op->vid, type); } } // default case return self. 
@@ -78,7 +78,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { if (all_fields_unchanged) { return GetRef(op); } else { - return TupleNode::make(fields); + return Tuple(fields); } } @@ -134,7 +134,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { if (unchanged) { return GetRef(call_node); } else { - return CallNode::make(new_op, call_args, call_node->attrs, ty_args); + return Call(new_op, call_args, call_node->attrs, ty_args); } } @@ -148,7 +148,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(var, value, body); + return Let(var, value, body); } } @@ -161,7 +161,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { op->false_branch.same_as(false_b)) { return GetRef(op);; } else { - return IfNode::make(guard, true_b, false_b); + return If(guard, true_b, false_b); } } @@ -170,7 +170,7 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { if (g->tuple == t) { return GetRef(g); } else { - return TupleGetItemNode::make(t, g->index); + return TupleGetItem(t, g->index); } } @@ -179,7 +179,7 @@ Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return RefCreateNode::make(value); + return RefCreate(value); } } @@ -188,7 +188,7 @@ Expr ExprMutator::VisitExpr_(const RefReadNode* op) { if (ref.same_as(op->ref)) { return GetRef(op); } else { - return RefReadNode::make(ref); + return RefRead(ref); } } @@ -198,7 +198,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { if (ref.same_as(op->ref) && value.same_as(op->value)) { return GetRef(op); } else { - return RefWriteNode::make(ref, value); + return RefWrite(ref, value); } } @@ -211,12 +211,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { for (const Clause& p : m->clauses) { clauses.push_back(VisitClause(p)); } - return MatchNode::make(VisitExpr(m->data), clauses, m->complete); + return Match(VisitExpr(m->data), clauses, m->complete); } Clause ExprMutator::VisitClause(const Clause& c) { Pattern p = VisitPattern(c->lhs); - return ClauseNode::make(p, VisitExpr(c->rhs)); + return Clause(p, VisitExpr(c->rhs)); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -391,7 +391,7 @@ class ExprBinder : public ExprMutator, PatternMutator { Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); - return ClauseNode::make(pat, VisitExpr(c->rhs)); + return Clause(pat, VisitExpr(c->rhs)); } Var VisitVar(const Var& v) final { diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index cfe8f345f5f8..6795884ef438 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -36,7 +36,7 @@ Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { } Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { - return PatternVarNode::make(VisitVar(op->var)); + return PatternVar(VisitVar(op->var)); } Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { @@ -44,7 +44,7 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { for (const auto& p : op->patterns) { pat.push_back(VisitPattern(p)); } - return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); + return PatternConstructor(VisitConstructor(op->constructor), pat); } Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { @@ -52,7 +52,7 @@ Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { for (const auto& p : op->patterns) { 
pat.push_back(VisitPattern(p)); } - return PatternTupleNode::make(pat); + return PatternTuple(pat); } Type PatternMutator::VisitType(const Type& t) { @@ -62,7 +62,7 @@ Type PatternMutator::VisitType(const Type& t) { Var PatternMutator::VisitVar(const Var& v) { if (var_map_.count(v) == 0) { var_map_.insert(std::pair(v, - VarNode::make(v->name_hint(), + Var(v->name_hint(), VisitType(v->type_annotation)))); } return var_map_.at(v); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 59c1750ae677..a4bab36d3fe5 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -75,10 +75,6 @@ class FunctionPassNode : public PassNode { */ PassInfo Info() const override { return pass_info; } - TVM_DLL static FunctionPass make( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info); - static constexpr const char* _type_key = "relay.FunctionPass"; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); @@ -95,16 +91,25 @@ class FunctionPassNode : public PassNode { class FunctionPass : public Pass { public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); }; -FunctionPass FunctionPassNode::make( +FunctionPass::FunctionPass( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); - return FunctionPass(n); + data_ = std::move(n); } // Perform Module -> Module optimizations at the Function level. @@ -149,13 +154,16 @@ Pass CreateFunctionPass( const std::string& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); - return FunctionPassNode::make(pass_func, pass_info); + return FunctionPass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") -.set_body_typed(FunctionPassNode::make); +.set_body_typed([](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + return FunctionPass(pass_func, pass_info); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 13d89a7e1af5..5b03ceec6ccf 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -56,7 +56,7 @@ Expr MakeArgsort(Expr data, attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("argsort"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 0ff30bbd3933..225575c69b00 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -79,7 +79,7 @@ Expr MakeTopK(Expr data, attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("topk"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index c79ebc143912..dd1bcdc1b9eb 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -44,7 +44,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") auto attrs = make_object(); attrs->device_type = device_type; 
static const Op& op = Op::Get("on_device"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("on_device") @@ -59,7 +59,7 @@ RELAY_REGISTER_OP("on_device") Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); - return CallNode::make(op, {data}, Attrs{}, {}); + return Call(op, {data}, Attrs{}, {}); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") @@ -90,7 +90,7 @@ Expr CastHint(Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("annotation.cast_hint"); - return CallNode::make(op, {data}, Attrs{attrs}, {}); + return Call(op, {data}, Attrs{attrs}, {}); } RELAY_REGISTER_OP("annotation.cast_hint") @@ -147,7 +147,7 @@ Mark the end of bitpacking. TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") .set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); - return CallNode::make(op, {data}, Attrs{}, {}); + return Call(op, {data}, Attrs{}, {}); }); RELAY_REGISTER_OP("annotation.checkpoint") @@ -195,7 +195,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin") auto attrs = make_object(); attrs->compiler = compiler; static const Op& op = Op::Get("annotation.compiler_begin"); - return CallNode::make(op, {expr}, Attrs(attrs), {}); + return Call(op, {expr}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("annotation.compiler_end") @@ -220,7 +220,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end") auto attrs = make_object(); attrs->compiler = compiler; static const Op& op = Op::Get("annotation.compiler_end"); - return CallNode::make(op, {expr}, Attrs(attrs), {}); + return Call(op, {expr}, Attrs(attrs), {}); }); } // namespace relay diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index a0f7fbf4cfeb..8e8586f9d213 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -61,7 +61,7 @@ Expr MakeDebug(Expr expr, std::string name) { dattrs->debug_func = EnvFunc(); } static const Op& op = Op::Get("debug"); - return CallNode::make(op, {expr}, Attrs(dattrs), {}); + return Call(op, {expr}, Attrs(dattrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.debug") diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 7e0530c29476..4aae549f217b 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.device_copy") attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("device_copy") diff --git a/src/relay/op/image/dilation2d.cc b/src/relay/op/image/dilation2d.cc index 55c49d7a081a..7146f3736dd6 100644 --- a/src/relay/op/image/dilation2d.cc +++ b/src/relay/op/image/dilation2d.cc @@ -62,7 +62,7 @@ Expr MakeDilation2D(Expr data, attrs->kernel_layout = std::move(kernel_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("image.dilation2d"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } template @@ -80,18 +80,18 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + const auto trans_in_layout = 
tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) << "Dilation2D only support input layouts that are convertible from NCHW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) << "Dilation2D only support kernel layouts that are convertible from OIHW." << " But got " << kernel_layout; Layout out_layout(param->data_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) << "Dilation2D only support output layouts that are convertible from NCHW." << " But got " << out_layout; diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 99d97b27503c..c8f976256600 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -44,7 +44,7 @@ bool ResizeRel(const Array& types, const ResizeAttrs* param = attrs.as(); CHECK(param != nullptr); const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) << "Resize only support input layouts that are convertible from NCHW." << " But got " << in_layout; @@ -80,7 +80,7 @@ Expr MakeResize(Expr data, attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } @@ -135,7 +135,7 @@ bool CropAndResizeRel(const Array& types, // 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] static const Layout kNCHW("NCHW"); const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(0, box_indices->shape[0]); oshape.Set(2, crop_size[0]); @@ -163,7 +163,7 @@ Expr MakeCropAndResize(Expr data, attrs->extrapolation_value = std::move(extrapolation_value); attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.crop_and_resize"); - return CallNode::make(op, {data, boxes, box_indices}, Attrs(attrs), {}); + return Call(op, {data, boxes, box_indices}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize") diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index dcb50e680505..c9ab067da594 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -46,7 +46,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage") auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("memory.alloc_storage"); - return CallNode::make(op, {size, alignment}, Attrs(attrs), {}); + return Call(op, {size, alignment}, Attrs(attrs), {}); }); bool AllocStorageRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -98,7 +98,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") attrs->const_shape = Downcast(shape); } static const Op& op = Op::Get("memory.alloc_tensor"); - return CallNode::make(op, {storage, shape}, Attrs(attrs), {}); + return Call(op, {storage, shape}, Attrs(attrs), {}); }); std::vector FromConstShape(Constant konst) { @@ -211,7 +211,7 @@ bool InvokeTVMOPRel(const Array& 
types, int num_inputs, const Attrs& attrs TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op") .set_body_typed( [](Expr func, Expr inputs, Expr outputs) { - return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); + return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); }); RELAY_REGISTER_OP("memory.invoke_tvm_op") @@ -262,7 +262,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func") static const Op& op = Op::Get("memory.shape_func"); auto attrs = make_object(); attrs->is_input = is_input; - return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {}); + return Call(op, {func, inputs, outputs}, Attrs(attrs), {}); }); static void FlattenTypeAux(const Type& type, std::vector* out) { diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index 9457b4b5d171..d2174579cc31 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -94,7 +94,7 @@ Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack attrs->pack_type = pack_type; attrs->name = name; static const Op& op = Op::Get("nn.bitpack"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack); @@ -130,7 +130,7 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr static const Layout kNCHW("NCHW"); const Layout in_layout(param->data_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); CHECK(param->channels.defined()); CHECK(param->kernel_size.defined()); @@ -167,7 +167,7 @@ Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Arrayout_dtype = std::move(out_dtype); attrs->unipolar = unipolar; static const Op& op = Op::Get("nn.bitserial_conv2d"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D); @@ -235,7 +235,7 @@ Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int attrs->out_dtype = out_dtype; attrs->unipolar = unipolar; static const Op& op = Op::Get("nn.bitserial_dense"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense); diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 0d10253d80eb..547d5a6ff692 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -60,7 +60,7 @@ Expr MakeConv(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get(op_name); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -218,18 +218,18 @@ bool Conv2DTransposeRel(const Array& types, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCHW." 
<< " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIHW." << " But got "<< kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCHW." << " But got " << out_layout; @@ -324,7 +324,7 @@ Expr MakeConv2DTranspose(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.conv2d_transpose"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -383,18 +383,18 @@ bool Conv1DTransposeRel(const Array& types, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIW." << " But got "<< kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCW." << " But got " << out_layout; @@ -483,7 +483,7 @@ Expr MakeConv1DTranspose(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.conv1d_transpose"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -538,18 +538,18 @@ bool Conv2DWinogradRel(const Array& types, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCHW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIHW." << " But got "<< kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCHW." 
<< " But got " << out_layout; @@ -632,7 +632,7 @@ Expr MakeConv2DWinograd(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -695,7 +695,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight, auto attrs = make_object(); attrs->tile_size = tile_size; static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform"); - return CallNode::make(op, {weight}, Attrs(attrs), {}); + return Call(op, {weight}, Attrs(attrs), {}); } @@ -759,7 +759,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, attrs->convolution_algorithm = convolution_algorithm; attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform"); - return CallNode::make(op, {weight}, Attrs(attrs), {}); + return Call(op, {weight}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") @@ -805,7 +805,7 @@ Expr MakeConv2DNCHWc(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc"); - return CallNode::make(op, {data, kernel}, Attrs(attrs), {}); + return Call(op, {data, kernel}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") @@ -855,7 +855,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc"); - return CallNode::make(op, {data, kernel}, Attrs(attrs), {}); + return Call(op, {data, kernel}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") @@ -1017,7 +1017,7 @@ Expr MakeDeformableConv2D(Expr data, attrs->out_layout = out_layout; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.deformable_conv2d"); - return CallNode::make(op, {data, offset, weight}, Attrs{attrs}, {}); + return Call(op, {data, offset, weight}, Attrs{attrs}, {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index e3b4a8510600..6a69178f49b1 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -48,18 +48,18 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIW." << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCW." << " But got " << out_layout; @@ -136,18 +136,18 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCHW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIHW." << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCHW." << " But got " << out_layout; @@ -255,18 +255,18 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCDHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(trans_in_layout.defined()) << "Conv only support input layouts that are convertible from NCDHW." << " But got " << in_layout; - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIDHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); CHECK(trans_kernel_layout.defined()) << "Conv only support kernel layouts that are convertible from OIDHW." << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCDHW); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); CHECK(trans_out_layout.defined()) << "Conv only support output layouts that are convertible from NCDHW." 
<< " But got " << out_layout; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 5203ffc39217..4934e0666315 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -75,7 +75,7 @@ Expr MakeBiasAdd(Expr data, auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); - return CallNode::make(op, {data, bias}, Attrs(attrs), {}); + return Call(op, {data, bias}, Attrs(attrs), {}); } @@ -108,7 +108,7 @@ Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.fifo_buffer"); - return CallNode::make(op, {input, buffer}, Attrs(attrs), {}); + return Call(op, {input, buffer}, Attrs(attrs), {}); } bool FIFOBufferRel(const Array& types, @@ -180,7 +180,7 @@ Expr MakeDense(Expr data, attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -212,7 +212,7 @@ Expr MakeLeakyRelu(Expr data, auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } @@ -291,7 +291,7 @@ Expr MakePRelu(Expr data, auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); - return CallNode::make(op, {data, alpha}, Attrs(attrs), {}); + return Call(op, {data, alpha}, Attrs(attrs), {}); } @@ -329,7 +329,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); }); @@ -363,7 +363,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("nn.log_softmax") @@ -422,7 +422,7 @@ bool BatchFlattenRel(const Array& types, Expr MakeBatchFlatten(Expr data) { static const Op& op = Op::Get("nn.batch_flatten"); - return CallNode::make(op, {data}, Attrs(), {}); + return Call(op, {data}, Attrs(), {}); } @@ -468,7 +468,7 @@ Example:: TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") .set_body_typed([](Expr data) { static const Op& op = Op::Get("nn.relu"); - return CallNode::make(op, {data}, Attrs(), {}); + return Call(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("nn.relu") @@ -506,7 +506,7 @@ Expr MakeLRN(Expr data, attrs->beta = beta; attrs->bias = bias; static const Op& op = Op::Get("nn.lrn"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn") @@ -544,7 +544,7 @@ Expr MakeL2Normalize(Expr data, attrs->eps = eps; attrs->axis = std::move(axis); static const Op& op = Op::Get("nn.l2_normalize"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize") @@ -589,7 +589,7 @@ Expr MakeDropout(Expr data, double rate) { auto attrs = make_object(); attrs->rate = rate; static const Op& op = Op::Get("nn.dropout"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout") @@ -687,7 +687,7 @@ Expr MakeBatchNorm(Expr 
data, Expr gamma, Expr beta, Expr moving_mean, Expr movi attrs->center = center; attrs->scale = scale; static const Op& op = Op::Get("nn.batch_norm"); - return CallNode::make(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); + return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm") @@ -770,7 +770,7 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon attrs->center = center; attrs->scale = scale; static const Op& op = Op::Get("nn.instance_norm"); - return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); + return Call(op, {data, gamma, beta}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") @@ -840,7 +840,7 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, attrs->center = center; attrs->scale = scale; static const Op& op = Op::Get("nn.layer_norm"); - return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); + return Call(op, {data, gamma, beta}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") @@ -891,7 +891,7 @@ bool BatchMatmulRel(const Array& types, Expr MakeBatchMatmul(Expr x, Expr y) { static const Op& op = Op::Get("nn.batch_matmul"); - return CallNode::make(op, {x, y}, Attrs(), {}); + return Call(op, {x, y}, Attrs(), {}); } @@ -948,7 +948,7 @@ bool CrossEntropyRel(const Array& types, // Positional relay function to create cross_entropy operator used by frontend FFI. Expr MakeCrossEntropy(Expr predictions, Expr targets) { static const Op& op = Op::Get("nn.cross_entropy"); - return CallNode::make(op, {predictions, targets}, Attrs(), {}); + return Call(op, {predictions, targets}, Attrs(), {}); } @@ -971,7 +971,7 @@ Do log on the data - do not accept logits. // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { static const Op& op = Op::Get("nn.cross_entropy_with_logits"); - return CallNode::make(op, {predictions, targets}, Attrs(), {}); + return Call(op, {predictions, targets}, Attrs(), {}); } @@ -1005,7 +1005,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr CHECK(param != nullptr); const int block_size = param->block_size; const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) << "DepthToSpace only support input layouts that are convertible from NCHW." << " But got " << in_layout; @@ -1030,7 +1030,7 @@ Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string attrs->layout = std::move(layout); attrs->mode = std::move(mode); static const Op& op = Op::Get("nn.depth_to_space"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace); @@ -1063,7 +1063,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr CHECK(param != nullptr); const int block_size = param->block_size; const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) << "SpaceToDepth only support input layouts that are convertible from NCHW." 
<< " But got " << in_layout; @@ -1087,7 +1087,7 @@ Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) { attrs->block_size = block_size; attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.space_to_depth"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 3506f42f7675..abff06ef9d88 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -196,7 +196,7 @@ Expr MakePad(Expr data, attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); static const Op& op = Op::Get("nn.pad"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.pad") @@ -270,7 +270,7 @@ Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mo attrs->mode = mode; attrs->pad_width = std::move(pad_width); static const Op& op = Op::Get("nn.mirror_pad"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad") diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 00ec55af87df..6a2a59b91be0 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -70,7 +70,7 @@ Expr MakeMaxPool(Expr data, attrs->layout = std::move(layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get(op_name); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } template @@ -90,7 +90,7 @@ Expr MakeAvgPool(Expr data, attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get(op_name); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } template @@ -175,7 +175,7 @@ Array Pool2DCompute(const Attrs& attrs, auto ceil_mode = param->ceil_mode; Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "max_pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) << "max_pool2d does not support input split on height"; @@ -336,7 +336,7 @@ Array GlobalPool2DCompute(const Attrs& attrs, const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) << "global_avg_pool2d does not support input split on height"; @@ -355,7 +355,7 @@ Expr MakeGlobalAvgPool2D(Expr data, auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } @@ -387,7 +387,7 @@ Expr MakeGlobalMaxPool2D(Expr data, auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d") @@ -469,7 +469,7 @@ Array AdaptivePool2DCompute(const Attrs& attrs, const auto* param 
= attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) << "Adaptive pool2d does not support input split on height"; @@ -507,7 +507,7 @@ Expr MakeAdaptiveAvgPool2D(Expr data, attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.adaptive_avg_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d") @@ -545,7 +545,7 @@ Expr MakeAdaptiveMaxPool2D(Expr data, attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.adaptive_max_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d") @@ -637,7 +637,7 @@ Array AdaptivePool3DCompute(const Attrs& attrs, const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) << "Adaptive pool3d does not support input split on depth"; @@ -683,7 +683,7 @@ Expr MakeAdaptiveMaxPool3D(Expr data, attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.adaptive_max_pool3d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d") @@ -721,7 +721,7 @@ Expr MakeAdaptiveAvgPool3D(Expr data, attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.adaptive_avg_pool3d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d") @@ -776,7 +776,7 @@ Array Pool2DGradCompute(const Attrs& attrs, auto ceil_mode = param->ceil_mode; Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) << "pool2d_grad currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) << "pool2d_grad does not support input split on height"; @@ -820,7 +820,7 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, attrs->layout = std::move(layout); attrs->ceil_mode = ceil_mode; static const Op& op = Op::Get("nn.max_pool2d_grad"); - return CallNode::make(op, {out_grad, data}, Attrs(attrs), {}); + return Call(op, {out_grad, data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); @@ -869,7 +869,7 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, attrs->ceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get("nn.avg_pool2d_grad"); - return CallNode::make(op, {out_grad, data}, Attrs(attrs), {}); + return Call(op, {out_grad, data}, Attrs(attrs), {}); } 
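All of the hunks above apply one idiom: the static node factories (CallNode::make, TupleNode::make, BijectiveLayoutNode::make, ...) give way to constructors on the corresponding reference classes (Call, Tuple, tir::BijectiveLayout), which allocate the node and hand it to the reference's data_ field, as the QAnnotateExpr and QRealizeIntExpr constructors further below spell out. The short standalone C++ sketch that follows mimics that idiom; DemoNode and Demo are illustrative stand-ins written for this sketch, not TVM classes, and std::shared_ptr stands in for TVM's ObjectPtr.

// Minimal, self-contained sketch of the "Node::make -> reference constructor" refactor.
// DemoNode/Demo are made-up stand-ins for a TVM Node/ObjectRef pair.
#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct DemoNode {                    // heap-allocated payload (the "Node")
  std::string op;
  int num_inputs{0};
};

class Demo {                         // the handle (the "ObjectRef")
 public:
  // After the refactor, construction lives in the reference's constructor,
  // which fills the node and stores it (mirrors `data_ = std::move(rnode);`).
  Demo(std::string op, int num_inputs) {
    auto n = std::make_shared<DemoNode>();
    n->op = std::move(op);
    n->num_inputs = num_inputs;
    data_ = std::move(n);
  }
  const DemoNode* operator->() const { return data_.get(); }

 private:
  std::shared_ptr<DemoNode> data_;   // plays the role of ObjectRef::data_
};

int main() {
  // Old style would have been a static factory, e.g. DemoNode::make("nn.conv2d", 2);
  // the new style is a plain constructor call at every call site:
  Demo call("nn.conv2d", 2);
  std::cout << call->op << " expects " << call->num_inputs << " inputs\n";
  return 0;
}

The call sites in the remaining hunks change in exactly this way: only the spelling of the construction moves from FooNode::make(args) to Foo(args); arguments, attrs, and type_args are untouched.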
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); @@ -975,7 +975,7 @@ Array Pool1DCompute(const Attrs& attrs, auto ceil_mode = param->ceil_mode; Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCW).defined()) << "max_pool1d currently only supports layouts that are convertible from NCW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool1d does not support input split on width"; @@ -1166,7 +1166,7 @@ Array Pool3DCompute(const Attrs& attrs, auto ceil_mode = param->ceil_mode; Layout layout(param->layout); - CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined()) + CHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) << "max_pool3d currently only supports layouts that are convertible from NCDHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) << "max_pool3d does not support input split on depth"; diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index e7db1255c6f0..c761c3f8466e 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -67,7 +67,7 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) { auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_dense"); - return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); + return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") @@ -116,7 +116,7 @@ bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& a Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) { auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_transpose"); - return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); + return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose") diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 2c1f45dc317a..63bd42d8f508 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -76,7 +76,7 @@ bool UpSamplingRel(const Array& types, CHECK(param != nullptr); const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) << "UpSampling only support input layouts that are convertible from NCHW." << " But got " << in_layout; @@ -108,7 +108,7 @@ Expr MakeUpSampling(Expr data, attrs->scale_w = scale_w; attrs->align_corners = align_corners; static const Op& op = Op::Get("nn.upsampling"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling") @@ -155,7 +155,7 @@ bool UpSampling3DRel(const Array& types, CHECK(param != nullptr); const Layout in_layout(param->layout); - auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCDHW); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(layout_converter.defined()) << "UpSampling3D only support input layouts that are convertible from NCDHW." 
<< " But got " << in_layout; @@ -189,7 +189,7 @@ Expr MakeUpSampling3D(Expr data, attrs->scale_w = scale_w; attrs->coordinate_transformation_mode = coordinate_transformation_mode; static const Op& op = Op::Get("nn.upsampling3d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d") diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 25293adea67b..2d89d778e62c 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -51,7 +51,7 @@ namespace relay { TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr data) { \ static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(), {}); \ + return Call(op, {data}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(1) \ @@ -77,7 +77,7 @@ namespace relay { TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(2) \ @@ -94,7 +94,7 @@ namespace relay { TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ .set_body_typed([](Expr lhs, Expr rhs) { \ static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(2) \ diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 74bc84e9153a..3f220fb64ad5 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -317,7 +317,7 @@ bool ReduceRel(const Array& types, attrs->keepdims = keepdims; \ attrs->exclude = exclude; \ static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(attrs), {}); \ + return Call(op, {data}, Attrs(attrs), {}); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(1) \ @@ -624,7 +624,7 @@ Expr MakeVariance(Expr data, attrs->keepdims = keepdims; attrs->exclude = exclude; static const Op& op = Op::Get("variance"); - return CallNode::make(op, {data, mean}, Attrs(attrs), {}); + return Call(op, {data, mean}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make._variance") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 942ba7e6e41c..3d03b4af8720 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -79,7 +79,7 @@ Expr MakeCast(Expr data, auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.ir.cast") @@ -134,7 +134,7 @@ Array CastLikeCompute(const Attrs& attrs, Expr MakeCastLike(Expr data, Expr dtype_like) { static const Op& op = Op::Get("cast_like"); - return CallNode::make(op, {data, dtype_like}, Attrs(), {}); + return Call(op, {data, dtype_like}, Attrs(), {}); } @@ -167,7 +167,7 @@ Expr MakeReinterpret(Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("reinterpret"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) { @@ -244,7 +244,7 @@ Expr MakeExpandDims(Expr data, attrs->axis = axis; attrs->num_newaxis = num_newaxis; static const Op& op = 
Op::Get("expand_dims"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.expand_dims") @@ -280,7 +280,7 @@ Expr MakeConcatenate(Expr data, auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.concatenate") @@ -374,7 +374,7 @@ Expr MakeStack(Expr data, auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.stack") @@ -465,7 +465,7 @@ Expr MakeTranspose(Expr data, auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.transpose") @@ -656,7 +656,7 @@ Expr MakeReshape(Expr data, attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.reshape") @@ -763,7 +763,7 @@ bool ReshapeLikeRel(const Array& types, Expr MakeReshapeLike(Expr data, Expr shape_like) { static const Op& op = Op::Get("reshape_like"); - return CallNode::make(op, {data, shape_like}, Attrs(), {}); + return Call(op, {data, shape_like}, Attrs(), {}); } @@ -806,7 +806,7 @@ bool ArgWhereRel(const Array& types, TVM_REGISTER_GLOBAL("relay.op._make.argwhere") .set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); - return CallNode::make(op, {data}, Attrs(), {}); + return Call(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("argwhere") @@ -886,7 +886,7 @@ Expr MakeTake(Expr data, attrs->axis = std::move(axis); attrs->mode = std::move(mode); static const Op& op = Op::Get("take"); - return CallNode::make(op, {data, indices}, Attrs(attrs), {}); + return Call(op, {data, indices}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.take") @@ -966,7 +966,7 @@ Expr MakeFull(Expr fill_value, attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); - return CallNode::make(op, {fill_value}, Attrs(attrs), {}); + return Call(op, {fill_value}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.full") @@ -1001,7 +1001,7 @@ Expr MakeZeros(Array shape, attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); - return CallNode::make(op, {}, Attrs(attrs), {}); + return Call(op, {}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.zeros") @@ -1022,7 +1022,7 @@ Expr MakeOnes(Array shape, attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("ones"); - return CallNode::make(op, {}, Attrs(attrs), {}); + return Call(op, {}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.ones") @@ -1068,7 +1068,7 @@ Array FullLikeCompute(const Attrs& attrs, Expr MakeFullLike(Expr data, Expr fill_value) { static const Op& op = Op::Get("full_like"); - return CallNode::make(op, {data, fill_value}, Attrs(), {}); + return Call(op, {data, fill_value}, Attrs(), {}); } TVM_REGISTER_GLOBAL("relay.op._make.full_like") @@ -1191,7 +1191,7 @@ Expr MakeArange(Expr start, attrs->step = step; attrs->dtype = dtype; 
static const Op& op = Op::Get("arange"); - return CallNode::make(op, {start, stop, step}, Attrs(attrs), {}); + return Call(op, {start, stop, step}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.arange") @@ -1279,7 +1279,7 @@ Expr MakeRepeat(Expr data, attrs->repeats = repeats; attrs->axis = axis; static const Op& op = Op::Get("repeat"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.repeat") @@ -1387,7 +1387,7 @@ Expr MakeTile(Expr data, auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.tile") @@ -1447,7 +1447,7 @@ Expr MakeReverse(Expr data, auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.reverse") @@ -1504,7 +1504,7 @@ bool WhereRel(const Array& types, // Positional relay function to create where operator. Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); - return CallNode::make(op, {condition, x, y}); + return Call(op, {condition, x, y}); } Array WhereCompute(const Attrs& attrs, @@ -1563,7 +1563,7 @@ Expr MakeSqueeze(Expr data, auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.squeeze") @@ -1660,7 +1660,7 @@ bool CollapseSumLikeRel(const Array& types, Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { static const Op& op = Op::Get("collapse_sum_like"); - return CallNode::make(op, {data, collapse_type}, Attrs(), {}); + return Call(op, {data, collapse_type}, Attrs(), {}); } Array CollapseSumLikeCompute(const Attrs& attrs, @@ -1704,7 +1704,7 @@ Expr MakeBroadCastTo(Expr data, Array shape) { static const Op& op = Op::Get("broadcast_to"); auto attrs = make_object(); attrs->shape = std::move(shape); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } Array BroadCastToCompute(const Attrs& attrs, @@ -1741,7 +1741,7 @@ bool BroadCastToLikeRel(const Array& types, Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { static const Op& op = Op::Get("broadcast_to_like"); - return CallNode::make(op, {data, broadcast_type}, Attrs(), {}); + return Call(op, {data, broadcast_type}, Attrs(), {}); } Array BroadCastToLikeCompute(const Attrs& attrs, @@ -1954,7 +1954,7 @@ Expr MakeStridedSlice(Expr data, attrs->end = std::move(end); attrs->strides = std::move(strides); static const Op& op = Op::Get("strided_slice"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } Array StridedSliceCompute(const Attrs& attrs, @@ -2021,7 +2021,7 @@ Expr MakeStridedSet(Expr data, Expr end, Expr strides) { static const Op& op = Op::Get("strided_set"); - return CallNode::make(op, {data, v, begin, end, strides}, {}); + return Call(op, {data, v, begin, end, strides}, {}); } TVM_REGISTER_GLOBAL("relay.op._make.strided_set") @@ -2136,7 +2136,7 @@ Expr MakeSplit(Expr data, attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); static const Op& op = Op::Get("split"); - return CallNode::make(op, {data}, 
Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.split") @@ -2238,7 +2238,7 @@ Expr MakeSliceLike(Expr data, auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); - return CallNode::make(op, {data, shape_like}, Attrs(attrs), {}); + return Call(op, {data, shape_like}, Attrs(attrs), {}); } Array SliceLikeCompute(const Attrs& attrs, @@ -2330,7 +2330,7 @@ bool LayoutTransformRel(const Array& types, CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout"; - auto layout_converter = BijectiveLayoutNode::make(src_layout, dst_layout); + auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); CHECK(layout_converter.defined()) << "cannot convert from " << params->src_layout << " to " << params->dst_layout; @@ -2346,7 +2346,7 @@ Expr MakeLayoutTransform(Expr data, attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); static const Op& op = Op::Get("layout_transform"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.layout_transform") @@ -2374,7 +2374,7 @@ Expr MakeReverseReshape(Expr data, attrs->newshape = std::move(newshape); attrs->reverse = true; static const Op& op = Op::Get("_contrib_reverse_reshape"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape") @@ -2447,7 +2447,7 @@ Array GatherNDCompute(const Attrs& attrs, Expr MakeGatherND(Expr data, Expr indices) { static const Op& op = Op::Get("gather_nd"); - return CallNode::make(op, {data, indices}, {}); + return Call(op, {data, indices}, {}); } TVM_REGISTER_GLOBAL("relay.op._make.gather_nd") @@ -2508,7 +2508,7 @@ Expr MakeSequenceMask(Expr data, attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); static const Op& op = Op::Get("sequence_mask"); - return CallNode::make(op, {data, valid_length}, Attrs(attrs), {}); + return Call(op, {data, valid_length}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask") @@ -2629,7 +2629,7 @@ Expr MakeOneHot(Expr indices, attrs->axis = axis; attrs->dtype = dtype; static const Op& op = Op::Get("one_hot"); - return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {}); + return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op._make.one_hot") @@ -2710,7 +2710,7 @@ Array UnRavelIndexCompute(const Attrs& attrs, Expr MakeUnRavelIndex(Expr data, Expr shape) { static const Op& op = Op::Get("unravel_index"); - return CallNode::make(op, {data, shape}, Attrs(), {}); + return Call(op, {data, shape}, Attrs(), {}); } TVM_REGISTER_GLOBAL("relay.op._make.unravel_index") diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index beb6cd742c9c..3da77e994dee 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -184,7 +184,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.clip") attrs->a_min = a_min; attrs->a_max = a_max; static const Op& op = Op::Get("clip"); - return CallNode::make(op, {a}, Attrs(attrs), {}); + return Call(op, {a}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("clip") @@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.shape_of") auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, 
{data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("shape_of") @@ -397,7 +397,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size") auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("ndarray_size"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("ndarray_size") diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index eb5012fbfb35..cafe9b6dd0c3 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -68,7 +68,7 @@ Expr MakeMultiBoxPrior(Expr data, attrs->offsets = std::move(offsets); attrs->clip = clip; static const Op& op = Op::Get("vision.multibox_prior"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } @@ -141,7 +141,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, attrs->threshold = std::move(threshold); attrs->variances = std::move(variances); static const Op& op = Op::Get("vision.multibox_transform_loc"); - return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {}); + return Call(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc") diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index bec0c1d8d45a..25743f98bc0b 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -57,7 +57,7 @@ Expr MakeGetValidCounts(Expr data, attrs->id_index = id_index; attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } @@ -125,7 +125,7 @@ Expr MakeNMS(Expr data, attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return CallNode::make(op, {data, valid_count}, Attrs(attrs), {}); + return Call(op, {data, valid_count}, Attrs(attrs), {}); } diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 65efd0495656..6b221a279bac 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -57,7 +57,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa attrs->sample_ratio = sample_ratio; attrs->layout = layout; static const Op& op = Op::Get("vision.roi_align"); - return CallNode::make(op, {data, rois}, Attrs(attrs), {}); + return Call(op, {data, rois}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align") @@ -107,7 +107,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat attrs->spatial_scale = spatial_scale; attrs->layout = layout; static const Op& op = Op::Get("vision.roi_pool"); - return CallNode::make(op, {data, rois}, Attrs(attrs), {}); + return Call(op, {data, rois}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool") @@ -173,7 +173,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array attrs->rpn_min_size = rpn_min_size; attrs->iou_loss = iou_loss; static const Op& op = Op::Get("vision.proposal"); - return CallNode::make(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); + return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal") diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 7d152718f3a0..58596778de1d 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ 
-65,7 +65,7 @@ Expr MakeYoloReorg(Expr data, auto attrs = make_object(); attrs->stride = stride; static const Op& op = Op::Get("vision.yolo_reorg"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 110ee6fb901a..650dcb962d44 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -113,7 +113,7 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("qnn.concatenate"); - return CallNode::make(op, + return Call(op, {data, input_scales, input_zero_points, output_scale, output_zero_point}, Attrs(attrs), {}); } @@ -184,7 +184,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, } idx++; } - return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis); + return MakeConcatenate(Tuple(requantized_exprs), concatenate_attrs->axis); } RELAY_REGISTER_OP("qnn.concatenate") diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 791b035aa06c..de0aae3195f8 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -673,7 +673,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.conv2d"); - return CallNode::make( + return Call( op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); } diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 91f607279f17..7b9733c36586 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -72,7 +72,7 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern attrs->units = std::move(units); attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dense"); - return CallNode::make( + return Call( op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, Attrs(attrs), {}); } diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 087ceac4f300..69389a7317aa 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -62,7 +62,7 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) { // A more detailed explanation can be found here - // https://github.com/google/gemmlowp/blob/master/doc/quantization.md static const Op& op = Op::Get("qnn.dequantize"); - return CallNode::make(op, {data, input_scale, input_zero_point}, Attrs(), {}); + return Call(op, {data, input_scale, input_zero_point}, Attrs(), {}); } Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 444ca60327a1..73be4de0eafd 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -68,10 +68,10 @@ static inline bool QnnBroadcastRel(const Array& types, int num_inputs, con .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ static const Op& op = Op::Get("qnn." 
OpName); \ - return CallNode::make(op, {lhs, rhs, \ - lhs_scale, lhs_zero_point, \ - rhs_scale, rhs_zero_point, \ - output_scale, output_zero_point}, Attrs(), {}); \ + return Call(op, {lhs, rhs, \ + lhs_scale, lhs_zero_point, \ + rhs_scale, rhs_zero_point, \ + output_scale, output_zero_point}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP("qnn." OpName) \ .set_num_inputs(8) \ diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index d3f8ee5e5cee..43ba4b6b1ba4 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -77,7 +77,7 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis // A more detailed explanation can be found here - // https://github.com/google/gemmlowp/blob/master/doc/quantization.md static const Op& op = Op::Get("qnn.quantize"); - return CallNode::make(op, {data, output_scale, output_zero_point}, Attrs(attrs), {}); + return Call(op, {data, output_scale, output_zero_point}, Attrs(attrs), {}); } Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index f3351a248109..4ceb3597fc91 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -299,7 +299,7 @@ Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr out attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.requantize"); - return CallNode::make(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, + return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs), {}); } diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 84d6a0d24257..4492ed5bebca 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -45,8 +45,6 @@ class QAnnotateExprNode : public TempExprNode { v->Visit("kind", &kind); } - TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind); - Expr Realize() const final; static constexpr const char* _type_key = "relay.QAnnotateExpr"; @@ -55,6 +53,13 @@ class QAnnotateExprNode : public TempExprNode { class QAnnotateExpr : public TempExpr { public: + /*! + * \brief The constructor + * \param expr The original relay expression. + * \param kind The annotation kind. 
+ */ + TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind); + TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); }; @@ -63,18 +68,17 @@ Expr QAnnotateExprNode::Realize() const { return expr; } -QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { +QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { auto rnode = make_object(); - rnode->expr = expr; + rnode->expr = std::move(expr); rnode->kind = kind; - return QAnnotateExpr(rnode); + data_ = std::move(rnode); } TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = QAnnotateExprNode::make(args[0], - static_cast(args[1].operator int())); - }); +.set_body_typed([](Expr expr, int kind) { + return QAnnotateExpr(expr, static_cast(kind)); +}); Pass QuantizeAnnotate() { @@ -87,7 +91,7 @@ Pass QuantizeAnnotate() { const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); + return static_cast(QAnnotateExpr(ret, kQInput)); } return e; }; diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index faa5f3097c5d..7b1e909501b5 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -150,7 +150,7 @@ class StatsCollector : private ExprMutator { auto new_e = this->Mutate(expr); const FunctionNode* func = new_e.as(); CHECK(func) << "Input shoule be Function"; - Expr new_body = TupleNode::make(std::move(profile_data_)); + Expr new_body = Tuple(std::move(profile_data_)); return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, func->attrs); } @@ -173,7 +173,7 @@ class StatsCollector : private ExprMutator { new_attrs->kind = QAnnotateKind::kQIdentity; new_attrs->sign = attrs->sign; new_attrs->rounding = attrs->rounding; - Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {}); + Expr identity_quantize = Call(new_call->op, new_args, Attrs{new_attrs}, {}); // add non-const expressions to profile data if (attrs->kind != QAnnotateKind::kQWeight) { diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index 121e4316867b..39de0bc49d4c 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -45,8 +45,6 @@ class QPartitionExprNode : public TempExprNode { v->Visit("expr", &expr); } - TVM_DLL static QPartitionExpr make(Expr expr); - Expr Realize() const final; static constexpr const char* _type_key = "relay.QPartitionExpr"; @@ -55,6 +53,12 @@ class QPartitionExprNode : public TempExprNode { class QPartitionExpr : public TempExpr { public: + /*! + * \brief The constructor + * \param expr The original relay expression. 
+ */ + TVM_DLL explicit QPartitionExpr(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); }; @@ -66,16 +70,16 @@ Expr QPartitionExprNode::Realize() const { return StopFusion(ret); } -QPartitionExpr QPartitionExprNode::make(Expr expr) { +QPartitionExpr::QPartitionExpr(Expr expr) { auto rnode = make_object(); - rnode->expr = expr; - return QPartitionExpr(rnode); + rnode->expr = std::move(expr); + data_ = std::move(rnode); } TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = QPartitionExprNode::make(args[0]); - }); +.set_body_typed([](Expr expr) { + return QPartitionExpr(expr); +}); Pass QuantizePartition() { runtime::TypedPackedFunc pass_func = diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index b3a8733c45e1..631d8c0fdf58 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") attrs->sign = sign; attrs->rounding = rounding; static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); - return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); + return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); }); diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index e80b17c65331..8e04a99e2813 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -67,14 +67,14 @@ class QRealizeIntExprNode : public QRealizeExprNode { Expr Realize() const final; - TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); - static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; class QRealizeIntExpr : public QRealizeExpr { public: + TVM_DLL QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); }; @@ -87,18 +87,17 @@ Expr QRealizeIntExprNode::Realize() const { return data; } -QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { +QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) { ObjectPtr n = make_object(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); n->dtype = std::move(dtype); - return QRealizeIntExpr(n); + data_ = std::move(n); } inline Expr ForwardOp(const Call& ref_call, const Array& args) { - return CallNode::make(ref_call->op, - args, ref_call->attrs, ref_call->type_args); + return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args); } @@ -150,7 +149,7 @@ Expr QuantizeRealize(const Call& ref_call, if (idom_scale_imm == odom_scale_imm) { // same domain scale, only clip data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + return QRealizeIntExpr(data, dom_scale, n->dtype); } float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); @@ -170,14 +169,14 @@ Expr QuantizeRealize(const Call& ref_call, static_cast(shift_nbit))); } data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, ref_call->type_as()->shape, cfg->rounding); data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype); - return 
QRealizeIntExprNode::make(data, dom_scale, n->dtype); + return QRealizeIntExpr(data, dom_scale, n->dtype); } } @@ -186,7 +185,7 @@ Expr QuantizeRealize(const Call& ref_call, Expr data = new_args[0]; Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm)); Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, DataType::Float(32)); + return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32)); } Expr FoldConstantOpt(const Expr& expr) { @@ -225,11 +224,11 @@ Expr Conv2dRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = CallNode::make(ref_call->op, + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); + return QRealizeIntExpr(ret, dom_scale, out_dtype); } RELAY_REGISTER_OP("nn.conv2d") @@ -259,11 +258,11 @@ Expr DenseRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = CallNode::make(ref_call->op, + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); + return QRealizeIntExpr(ret, dom_scale, out_dtype); } RELAY_REGISTER_OP("nn.dense") @@ -293,7 +292,7 @@ Expr MulRealize(const Call& ref_call, Expr ret = ForwardOp(ref_call, {ldata, rdata}); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); + return QRealizeIntExpr(ret, dom_scale, dtype); } CHECK(!new_args[0]->IsInstance() && !new_args[1]->IsInstance()); return Expr(nullptr); @@ -377,7 +376,7 @@ Expr AddRealize(const Call& ref_call, Expr dom_scale; Array ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale); Expr ret = ForwardOp(ref_call, ret_args); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); + return QRealizeIntExpr(ret, dom_scale, dtype); } CHECK(!new_args[0]->IsInstance() && !new_args[1]->IsInstance()); @@ -398,9 +397,9 @@ Expr ClipRealize(const Call& ref_call, attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; - Expr ret = CallNode::make(ref_call->op, + Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); - return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + return QRealizeIntExpr(ret, n->dom_scale, n->dtype); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); @@ -427,8 +426,8 @@ Expr ConcatenateRealize(const Call& ref_call, DataType dtype; Expr dom_scale; Array ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale); - Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)}); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); + Expr ret = ForwardOp(ref_call, {Tuple(ret_args)}); + return QRealizeIntExpr(ret, dom_scale, dtype); } else { for (auto arg : new_args) { CHECK(!arg->IsInstance()); @@ -448,7 +447,7 @@ Expr IdentityRealize(const Call& ref_call, CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + return QRealizeIntExpr(ret, n->dom_scale, n->dtype); } 
CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); @@ -472,7 +471,7 @@ Expr CastDtypeInputRealize(const Call& ref_call, if (const auto* n = new_args[0].as()) { Expr data = Cast(n->data, cfg->dtype_input); Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input); + return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_input); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); @@ -493,7 +492,7 @@ Expr AvgPoolRealize(const Call& ref_call, data = Cast(n->data, cfg->dtype_activation); } Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation); + return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_activation); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); @@ -512,7 +511,7 @@ Expr CastHintRealize(const Call& ref_call, CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = Cast(n->data, param->dtype); - return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype); + return QRealizeIntExpr(ret, n->dom_scale, param->dtype); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index fe8862523dda..63c1cb96886d 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -93,7 +93,7 @@ class AlterTransformMemorizer : public TransformMemorizer { } } if (!modified) { - new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs); + new_e = Call(ref_call->op, new_args, ref_call->attrs); } const CallNode* new_call = new_e.as(); diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index f5bc6c2617ba..162cc19a509e 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -58,7 +58,7 @@ class AnnotateTargetWrapper : public ExprMutator { Expr begin = (*begin_op)(it, target_); compiler_begins.push_back(begin); } - Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs); + Expr update_call = Call(call->op, compiler_begins, call->attrs); const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); CHECK(end_op); diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index b65e94805e8c..759a4aea741d 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -83,7 +83,7 @@ class CastCanonicalizer : public ExprMutator { if (unchanged) { return GetRef(call); } - return CallNode::make(call->op, call_args, call->attrs, call->type_args); + return Call(call->op, call_args, call->attrs, call->type_args); } } @@ -112,7 +112,7 @@ class CastCanonicalizer : public ExprMutator { const CallNode* new_call = new_expr.as(); CHECK(new_call); CHECK(new_call->op == cast_op_); - return CallNode::make(new_call->op, new_call->args, new_call->attrs, + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); } } diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index b4ac99709f93..0dbce9bf5dd6 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -67,9 +67,9 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { CHECK(attrs_b); const auto* tweight_a = a->args[1]->type_as(); const auto* tweight_b = b->args[1]->type_as(); - const auto shape_a = 
BijectiveLayoutNode::make( + const auto shape_a = tir::BijectiveLayout( Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); - const auto shape_b = BijectiveLayoutNode::make( + const auto shape_b = tir::BijectiveLayout( Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && @@ -108,7 +108,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { channel_pos_ = layout.find('C'); CHECK_NE(channel_pos_, std::string::npos); - return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); + return Call(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); } bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { @@ -159,11 +159,11 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { tuple.push_back(branch[depth]->args[i]); } - auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); + auto concat = MakeConcatenate(Tuple(tuple), arg_channel_pos); new_args.push_back(std::move(concat)); } - return CallNode::make(call->op, new_args, call->attrs, {}); + return Call(call->op, new_args, call->attrs, {}); } void UpdateGroupOutput(const Expr& data, @@ -203,7 +203,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { } auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); CHECK_NE(index, std::string::npos); - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + return std::make_tuple(MakeConcatenate(Tuple(weights), index), tir::make_const(DataType::Int(32), num_filters)); } }; diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 4252cfdec732..fa63573afd50 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -105,10 +105,10 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) { arg_from_all_branches.push_back(branch[0]->args[i]); } - new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0)); + new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0)); } - return CallNode::make(batch_op, new_args, Attrs(), {}); + return Call(batch_op, new_args, Attrs(), {}); } bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { @@ -153,11 +153,11 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, } } - auto stack = MakeStack(TupleNode::make(tuple), 0); + auto stack = MakeStack(Tuple(tuple), 0); new_args.push_back(std::move(stack)); } - return CallNode::make(call->op, new_args, call->attrs, {}); + return Call(call->op, new_args, call->attrs, {}); } void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, @@ -167,7 +167,7 @@ void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { - auto split_data = TupleGetItemNode::make(split, index++); + auto split_data = TupleGetItem(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); subst_map->insert({GetRef(branch[depth]), squeezed_data}); } diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index be44db0980a9..871969dd1f37 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -99,7 +99,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { } } if (!modified) { - new_e = 
CallNode::make(ref_call->op, new_args, ref_call->attrs); + new_e = Call(ref_call->op, new_args, ref_call->attrs); } const CallNode* new_call = new_e.as(); diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 598289fe3fe9..48b8666856a6 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -44,7 +44,7 @@ Expr DeDup(const Expr& e) { Var Fresh(const Var& v) { CHECK_EQ(rename_.count(v), 0); CHECK_EQ(memo_.count(v), 0) << v.as(); - Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); + Var ret = Var(v->name_hint(), VisitType(v->type_annotation)); rename_[v] = ret; return ret; } @@ -62,7 +62,7 @@ Expr DeDup(const Expr& e) { Expr VisitExpr_(const LetNode* op) final { Var v = Fresh(op->var); - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } Type VisitType(const Type& t) final { @@ -90,7 +90,7 @@ Expr DeDup(const Expr& e) { } Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVarNode::make(Fresh(op->var)); + return PatternVar(Fresh(op->var)); } Type VisitType_(const TypeVarNode* op) final { diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index deb26aac7918..f4058b2ea6ad 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -84,7 +84,7 @@ class Eliminator : private ExprMutator { Expr VisitExpr_(const LetNode* op) final { Var v = op->var; if (HasLet(v)) { - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } else { return VisitExpr(op->body); } diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 75afc9e7b63e..b4d61f110832 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -134,7 +134,7 @@ class RewriteAnnotation : public ExprMutator { if (value.same_as(op->value) && body.same_as(op->body)) { return ExprMutator::VisitExpr_(op); } else { - Expr new_let = LetNode::make(op->var, value, body); + Expr new_let = Let(op->var, value, body); UpdateAnnotationMap(op, new_let.operator->()); return this->VisitExpr(new_let); } @@ -149,7 +149,7 @@ class RewriteAnnotation : public ExprMutator { } if (annotated) { - Expr new_tuple = TupleNode::make(fields); + Expr new_tuple = Tuple(fields); UpdateAnnotationMap(op, new_tuple.operator->()); return this->VisitExpr(new_tuple); } else { @@ -161,7 +161,7 @@ class RewriteAnnotation : public ExprMutator { Expr tuple = op->tuple; if (NeedDeviceCopy(tuple.operator->(), op)) { Expr new_expr = - TupleGetItemNode::make(GetDeviceCopyExpr(tuple, op), op->index); + TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); UpdateAnnotationMap(op, new_expr.operator->()); return this->VisitExpr(new_expr); } else { @@ -178,7 +178,7 @@ class RewriteAnnotation : public ExprMutator { if_node->false_branch.same_as(false_br)) { return ExprMutator::VisitExpr_(if_node); } else { - Expr new_if = IfNode::make(cond, true_br, false_br); + Expr new_if = If(cond, true_br, false_br); UpdateAnnotationMap(if_node, new_if.operator->()); return this->VisitExpr(new_if); } @@ -201,7 +201,7 @@ class RewriteAnnotation : public ExprMutator { } if (annotated) { - Call new_call = CallNode::make(call_node->op, new_args, call_node->attrs, + Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); UpdateAnnotationMap(call_node, 
new_call.operator->()); @@ -284,7 +284,7 @@ class RewriteAnnotation : public ExprMutator { attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); - Call device_copy = CallNode::make(op, {src}, Attrs(attrs), {}); + Call device_copy = Call(op, {src}, Attrs(attrs), {}); annotation_map_.insert({device_copy.operator->(), dst_dev_type}); return device_copy; } @@ -526,7 +526,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } else if (tuple->fields.size() == new_body.size()) { return new_expr; } else { - Tuple tuple_body = TupleNode::make(new_body); + Tuple tuple_body = Tuple(new_body); return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); } @@ -545,7 +545,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { return new_fields.size() == 1 ? new_fields[0] : new_expr; } else { return new_fields.size() == 1 ? new_fields[0] - : TupleNode::make(new_fields); + : Tuple(new_fields); } } else { return new_expr; diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index 978a3a6aa207..c720bdfa14ee 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -89,7 +89,7 @@ class EtaExpander : public ExprMutator { for (const auto& arg : call->args) { new_args.push_back(VisitExpr(arg)); } - return CallNode::make(new_op, new_args, call->attrs, call->type_args); + return Call(new_op, new_args, call->attrs, call->type_args); } Expr VisitExpr_(const ConstructorNode* cons_node) final { @@ -101,14 +101,14 @@ class EtaExpander : public ExprMutator { tvm::Array params; for (const auto& type : cons->inputs) { Type param_type = type_var_replacer_.VisitType(type); - params.push_back(VarNode::make("eta_expand_param", param_type)); + params.push_back(Var("eta_expand_param", param_type)); } tvm::Array type_params; TypeData adt_def = mod_->LookupTypeDef(cons->belong_to); for (const auto& type_var : adt_def->type_vars) { type_params.push_back(type_var_replacer_.VisitType(type_var)); } - Expr body = CallNode::make(cons, params, Attrs()); + Expr body = Call(cons, params, Attrs()); Type ret_type = TypeCall(cons->belong_to, type_params); return Function( @@ -130,14 +130,14 @@ class EtaExpander : public ExprMutator { tvm::Array params; tvm::Array args; for (size_t i = 0; i < func->params.size(); ++i) { - auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation); + auto var = Var("eta_expand_param", func->params[i]->type_annotation); params.push_back(var); args.push_back(var); } return Function( args, - CallNode::make(gvar, params), + Call(gvar, params), func->ret_type, func->type_params); } else { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 8fcef2f15c49..a52f42054c3e 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -103,7 +103,7 @@ class ConstantFolder : public ExprMutator { body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(var, value, body); + return Let(var, value, body); } } } @@ -187,14 +187,14 @@ class ConstantFolder : public ExprMutator { CHECK_GT(dim, 0) << "invalid dimension after constant eval"; } - return ConstantNode::make(nd_array); + return Constant(nd_array); } else if (const auto* val = value.as()) { runtime::ADT adt = GetRef(val); Array fields; for (size_t i = 0; i < adt.size(); ++i) { fields.push_back(ObjectToExpr(adt[i])); } - return TupleNode::make(fields); + return 
Tuple(fields); } else { LOG(FATAL) << "Cannot handle " << value->GetTypeKey(); return Expr(); @@ -267,13 +267,13 @@ class ConstantFolder : public ExprMutator { if (shape->data.Shape().size() == 0 && GetScalarFromConstant(shape) == 0) { auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx); - shape = ConstantNode::make(ndarray); + shape = Constant(ndarray); } // Cast the constant into correct dtype auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; - Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {}); + Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {}); return ConstEvaluate(ret); } }; diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 4bfe270cc044..c3114c78bd1d 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -92,22 +92,28 @@ class MessageNode : public RelayNode { */ bool require_positive; - static Message make(const AxesSet& axes, bool require_positive); - static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode); }; class Message : public ObjectRef { public: + /*! + * \brief The constructor + * \param axes Axes for scaling + * \param require_positive If folding requires the scales to be positive + * values. + */ + Message(const AxesSet& axes, bool require_positive); + TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); }; -Message MessageNode::make(const AxesSet& axes, bool require_positive) { +Message::Message(const AxesSet& axes, bool require_positive) { auto n = make_object(); n->axes = axes; n->require_positive = require_positive; - return Message(n); + data_ = std::move(n); } /*! @@ -150,7 +156,7 @@ Message Intersect(const Message& lhs, const Message& rhs) { if (!lhs.defined()) return lhs; if (!rhs.defined()) return rhs; auto axes = Intersect(lhs->axes, rhs->axes); - return MessageNode::make(axes, lhs->require_positive || rhs->require_positive); + return Message(axes, lhs->require_positive || rhs->require_positive); } /*! 
@@ -315,7 +321,7 @@ class ForwardPrep : private ExprVisitor { // Intermediate operators Array ReluForwardPrep(const Call& call, const Message& out_message) { if (out_message.defined()) { - return {MessageNode::make(out_message->axes, true)}; + return {Message(out_message->axes, true)}; } return {out_message}; } @@ -327,7 +333,7 @@ Expr ReluForwardRewrite(const Call& ref_call, if (input == nullptr) return Expr(nullptr); // return transformed conv2d auto rnode = make_object(); - rnode->value = CallNode::make( + rnode->value = Call( ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; rnode->axes = input->axes; @@ -377,7 +383,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, Expr scale = ExpandBiasToMatchAxis( slhs->scale, tlhs->shape.size(), slhs->axes); Expr rhs = Divide(new_args[1], scale); - rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs}, + rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; rnode->axes = slhs->axes; @@ -387,7 +393,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, Expr scale = ExpandBiasToMatchAxis( srhs->scale, trhs->shape.size(), srhs->axes); Expr lhs = Divide(new_args[0], scale); - rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value}, + rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; rnode->axes = srhs->axes; @@ -476,7 +482,7 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { data_axes = {c_big_axis}; } if (data_axes.defined()) { - return {MessageNode::make(data_axes, false), none}; + return {Message(data_axes, false), none}; } return {none, none}; } @@ -521,7 +527,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call, weight = Multiply(weight, scale); } // return transformed conv2d - return CallNode::make( + return Call( ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); } @@ -726,7 +732,7 @@ Expr BackwardTransformerNode::Transform( // Intermediate operators Message ReluBackwardPrep(const Call& call, const Array& in_messages) { if (in_messages[0].defined()) { - return MessageNode::make(in_messages[0]->axes, true); + return Message(in_messages[0]->axes, true); } return in_messages[0]; } @@ -740,7 +746,7 @@ Expr ReluBackwardTransform(const Call& call, } Expr input = transformer->Transform( call->args[0], message, scale); - return CallNode::make(call->op, {input}, call->attrs, call->type_args); + return Call(call->op, {input}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.relu") @@ -796,7 +802,7 @@ Expr AddSubBackwardTransform(const Call& call, CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); Expr rhs = transformer->Transform(call->args[1], message, scale); - return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); + return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); @@ -805,7 +811,7 @@ Expr AddSubBackwardTransform(const Call& call, Expr rhs_scale = ExpandBiasToMatchAxis( scale, tlhs->shape.size(), message->axes); rhs = Multiply(rhs, rhs_scale); - return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); + return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, 
rhs_message->axes)); Expr lhs = transformer->Transform( @@ -814,7 +820,7 @@ Expr AddSubBackwardTransform(const Call& call, Expr lhs_scale = ExpandBiasToMatchAxis( scale, trhs->shape.size(), message->axes); lhs = Multiply(lhs, lhs_scale); - return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args); + return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else { LOG(FATAL) << "outstanding scale"; return Expr(); @@ -890,7 +896,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 && (param->groups == 1 || is_depthwise_conv2d)) { - return MessageNode::make({c_big_axis}, false); + return Message({c_big_axis}, false); } else { return NullValue(); } @@ -930,7 +936,7 @@ Expr Conv2DBackwardTransform(const Call& call, Expr wscale = ExpandBiasToMatchAxis( scale, kernel_layout.ndim(), {big_oc_axis}); weight = Multiply(weight, wscale); - return CallNode::make( + return Call( call->op, {data, weight}, call->attrs, call->type_args); } diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index fe0df010b626..1d9d2b62a5c3 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -125,7 +125,7 @@ class ForwardRewriter : private ExprMutator { if (tuple.same_as(op->tuple)) { return GetRef(op); } else { - return TupleGetItemNode::make(tuple, op->index); + return TupleGetItem(tuple, op->index); } } } @@ -142,7 +142,7 @@ class ForwardRewriter : private ExprMutator { if (all_fields_unchanged) { return GetRef(op); } else { - return TupleNode::make(fields); + return Tuple(fields); } } @@ -185,7 +185,7 @@ class ForwardRewriter : private ExprMutator { } } if (unchanged) return ref_call; - return CallNode::make( + return Call( new_op, call_args, call_node->attrs, call_node->type_args); } }; diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index d9bb62e932ab..6e95441ea162 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -837,7 +837,7 @@ class FuseMutator : private ExprMutator { // create a new parameter. std::ostringstream os; os << "p" << params.size(); - auto var = VarNode::make(os.str(), type); + auto var = Var(os.str(), type); params.push_back(var); arguments.push_back(expr); return var; @@ -878,7 +878,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(call)->FindRoot(); Array new_args = GetNewArguments(call->args, ret_group); - auto new_call = CallNode::make( + auto new_call = Call( call->op, new_args, call->attrs, call->type_args); if (ret_group->root_ref == call) { @@ -902,13 +902,13 @@ class FuseMutator : private ExprMutator { } // This tuple is an intermediate node in the group Array new_fields = GetNewArguments(tuple->fields, ret_group); - return TupleNode::make(new_fields); + return Tuple(new_fields); } Expr VisitExpr_(const TupleGetItemNode* tuple_get) { auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; - auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index); + auto new_node = TupleGetItem(new_tuple, tuple_get->index); if (ret_group->root_ref == tuple_get) { if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { // Isolated. 
This case occurs when tuple is created by an Opaque op @@ -934,7 +934,7 @@ class FuseMutator : private ExprMutator { const GroupInfo& ginfo = ginfo_[group]; auto func = Function(ginfo.params, body, ret_type, {}); func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call)); - return CallNode::make(func, ginfo.arguments, Attrs()); + return Call(func, ginfo.arguments, Attrs()); } Array GetNewArguments(const tvm::Array& args, diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 9f401ed303b7..a3728e905922 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -154,7 +154,7 @@ struct FirstOrderReverseAD : ExprFunctor { for (const ADValue& adval : args) { call_args.push_back(adval->get().forward); } - auto orig = CallNode::make(op_ref, call_args, attrs, type_args); + auto orig = Call(op_ref, call_args, attrs, type_args); orig->checked_type_ = orig_type; auto ret = std::make_shared(ll, orig); backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { @@ -250,7 +250,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { for (const auto& a : args) { grad_res.push_back(a->get().reverse); } - return TupleNode::make(grad_res); + return Tuple(grad_res); }); return Pair(res.forward, grad); }); @@ -297,7 +297,7 @@ Expr LiftTensor(const std::function& f, fields.push_back(field); types.push_back(field->checked_type_); } - auto ret = TupleNode::make(fields); + auto ret = Tuple(fields); ret->checked_type_ = TupleType(types); return std::move(ret); } else { @@ -316,14 +316,14 @@ void TransferGrads(const Type& forward_type, CHECK(IsAtomic(from)) << from; CHECK(IsAtomic(to)) << to; if (forward_type.as()) { - auto from_ref = TupleGetItemNode::make(from, 1); - auto to_ref = TupleGetItemNode::make(to, 1); - ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref))); + auto from_ref = TupleGetItem(from, 1); + auto to_ref = TupleGetItem(to, 1); + ll->Push(RefWrite(to_ref, RefRead(from_ref))); } else if (auto* tt = forward_type.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { TransferGrads(tt->fields[i], - ll->Push(TupleGetItemNode::make(from, i)), - ll->Push(TupleGetItemNode::make(to, i)), + ll->Push(TupleGetItem(from, i)), + ll->Push(TupleGetItem(to, i)), ll); } } else { @@ -335,7 +335,7 @@ void TransferGrads(const Type& forward_type, /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { auto rev = [&](const Expr& e) { - return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e)))); + return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); }; auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); @@ -357,7 +357,7 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { /*! \brief ReverseType(t) -> t. Get the gradient. 
*/ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { auto grad = [&](const Expr& e) { - return ll->Push(RefReadNode::make(GetField(e, 1))); + return ll->Push(RefRead(GetField(e, 1))); }; auto grad_type = [&](const Type& forward_type) { return forward_type; @@ -367,8 +367,8 @@ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { if (t.as()) { - ll->Push(RefWriteNode::make(GetField(arg, 1), - Add(ll->Push(RefReadNode::make(GetField(arg, 1))), + ll->Push(RefWrite(GetField(arg, 1), + Add(ll->Push(RefRead(GetField(arg, 1))), grad))); } else if (auto* tt = t.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { @@ -384,8 +384,8 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { } Expr BPEmpty() { - Expr unitF = Function({}, TupleNode::make({}), TupleType::Empty(), {}); - return RefCreateNode::make(unitF); + Expr unitF = Function({}, Tuple(tvm::Array({})), TupleType::Empty(), {}); + return RefCreate(unitF); } struct ReverseAD : ExprMutator { @@ -412,7 +412,7 @@ struct ReverseAD : ExprMutator { return LetList::With([&](LetList* ll) { auto x_var = ll->Push(x); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); - auto bpv = ll->Push(RefReadNode::make(bp)); + auto bpv = ll->Push(RefRead(bp)); Expr nbp = Function( {}, LetList::With([&](LetList* ll) { @@ -422,12 +422,12 @@ struct ReverseAD : ExprMutator { auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); TransferGrads(call->checked_type(), ret, dup_ad, ll); - ll->Push(CallNode::make(RefReadNode::make(dup_bp), {})); - return CallNode::make(bpv, {}); + ll->Push(Call(RefRead(dup_bp), {})); + return Call(bpv, {}); }), TupleType::Empty(), {}); - ll->Push(RefWriteNode::make(bp, nbp)); + ll->Push(RefWrite(bp, nbp)); return ret; }); } @@ -451,12 +451,12 @@ struct ReverseAD : ExprMutator { for (size_t i = 0; i < args.size(); i++) { orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll)); } - Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args); + Expr orig = Call(call->op, orig_args, call->attrs, call->type_args); orig->checked_type_ = call->checked_type(); Var orig_var = ll->Push(orig); orig_var->checked_type_ = call->checked_type(); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); - auto bpv = ll->Push(RefReadNode::make(bp)); + auto bpv = ll->Push(RefRead(bp)); Expr nbp = Function( {}, LetList::With([&](LetList* ll) { @@ -465,11 +465,11 @@ struct ReverseAD : ExprMutator { for (size_t i = 0; i < args.size(); ++i) { UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); } - return CallNode::make(bpv, {}); + return Call(bpv, {}); }), TupleType::Empty(), {}); - ll->Push(RefWriteNode::make(bp, nbp)); + ll->Push(RefWrite(bp, nbp)); return ret; }); } @@ -478,11 +478,11 @@ struct ReverseAD : ExprMutator { Expr VisitExpr_(const ConstantNode* op) final { Expr e = GetRef(op); - return Pair(e, RefCreateNode::make(ZerosLike(e))); + return Pair(e, RefCreate(ZerosLike(e))); } Expr VisitExpr_(const IfNode* op) final { - return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0), + return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch), VisitExpr(op->false_branch)); } @@ -545,13 +545,13 @@ Expr Gradient(const Expr& re, const IRModule& mod) { Expr rev = ReverseAD(bp, std::make_shared())(e); std::vector args; for (const auto& p : f->params) { - args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p))))); + 
args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p))))); } - auto c = ll->Push(CallNode::make(rev, args)); + auto c = ll->Push(Call(rev, args)); std::function init_grad; init_grad = [&](const Expr& e, const Type& t) { if (t.as()) { - ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0)))); + ll->Push(RefWrite(GetField(e, 1), OnesLike(GetField(e, 0)))); } else if (auto tt = t.as()) { CHECK_GT(tt->fields.size(), 0); init_grad(ll->Push(GetField(e, 0)), tt->fields[0]); @@ -561,10 +561,10 @@ Expr Gradient(const Expr& re, const IRModule& mod) { } }; init_grad(c, f->body->checked_type()); - ll->Push(CallNode::make(RefReadNode::make(bp), {})); + ll->Push(Call(RefRead(bp), {})); std::vector ret; for (const auto& a : args) { - ret.push_back(RefReadNode::make(GetField(a, 1))); + ret.push_back(RefRead(GetField(a, 1))); } std::function get_final_result; get_final_result = [&](const Expr& e, const Type& t) -> Expr { @@ -575,13 +575,13 @@ Expr Gradient(const Expr& re, const IRModule& mod) { for (size_t i = 0; i < tt->fields.size(); ++i) { fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i])); } - return TupleNode::make(fields); + return Tuple(fields); } else { LOG(FATAL) << "unhandled type " << t; throw; } }; - return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret)); + return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); return Function(f->params, body, GradRetType(GetRef(f)), {}); } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 9e118ba8f87c..ef3c51f86105 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -151,7 +151,7 @@ class Inliner : ExprMutator { return Bind(func->body, bind_map); } } else if (const auto* call_node = callee.as()) { - return CallNode::make(func, args, call_node->attrs, call_node->type_args); + return Call(func, args, call_node->attrs, call_node->type_args); } else { return std::move(func); } diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index 1fb09478095b..f195c3060e2f 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -78,7 +78,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr, Type ty) { - return Push(VarNode::make("x", ty), expr); + return Push(Var("x", ty), expr); } /*! 
@@ -103,7 +103,7 @@ class LetList { CHECK(!used_); Expr ret = body; for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { - ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret); + ret = Let(std::get<0>(*rit), std::get<1>(*rit), ret); } used_ = true; return ret; @@ -120,10 +120,10 @@ class LetList { * // Automatically call Get with LetList::With * return LetList::With([&](LetList* ll) { * // Turn a call to plus into a variable to avoid duplication of code - * Var b = ll->Push(CallNode::make(plus, {a, a})); - * Var c = ll->Push(CallNode::make(plus, {b, b})); - * Var d = ll->Push(CallNode::make(plus, {c, c})); - * return CallNode::make(plus, {d, d}); + * Var b = ll->Push(Call(plus, {a, a})); + * Var c = ll->Push(Call(plus, {b, b})); + * Var d = ll->Push(Call(plus, {c, c})); + * return Call(plus, {d, d}); * }); * } * \endcode @@ -136,7 +136,7 @@ class LetList { return ll.Get(f(&ll)); } - static Expr Let(const Expr& e, const std::function& f) { + static Expr LetBind(const Expr& e, const std::function& f) { return With([&](LetList* ll) { return f(ll->Push(e)); }); diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 1157789df7a3..6506015e4a44 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -45,7 +45,7 @@ class MergeCompositeWrapper : public ExprMutator { if (var_map->find(pattern->name_hint()) == var_map->end()) { // if we haven't encountered this var yet, make a new free var and associate // it with the value at 'root' - auto free_var = VarNode::make(pattern->name_hint(), Type()); + auto free_var = Var(pattern->name_hint(), Type()); var_map->Set(pattern->name_hint(), Array({free_var, root})); return std::move(free_var); } else { @@ -132,7 +132,7 @@ class MergeCompositeWrapper : public ExprMutator { new_args.push_back(new_arg); i++; } - return CallNode::make(root->op, new_args, root->attrs); + return Call(root->op, new_args, root->attrs); } Expr VisitExpr_(const CallNode* cn) { @@ -149,7 +149,7 @@ class MergeCompositeWrapper : public ExprMutator { auto new_e = this->Mutate(arg); new_args.push_back(new_e); } - return CallNode::make(call->op, new_args, call->attrs); + return Call(call->op, new_args, call->attrs); } } @@ -175,7 +175,7 @@ class MergeCompositeWrapper : public ExprMutator { for (const auto& free_var : free_vars) { args.push_back(args_map[free_var->name_hint()][1]); } - auto new_call = CallNode::make(f, args); + auto new_call = Call(f, args); return std::move(new_call); } return std::move(call); diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 2e048f002283..cd1f40c28767 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -603,7 +603,7 @@ static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); Expr MkWithFuncId(const Expr& expr, FuncId fid) { auto attrs = make_object(); attrs->fid = fid; - return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {}); + return Call(with_funcid_op, {expr}, Attrs(attrs), {}); } Expr StripWithFuncId(const Expr& e); @@ -658,7 +658,7 @@ class PartialEvaluator : public ExprFunctor value.push_back(ps); expr.push_back(ps->dynamic); } - return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr))); + return HasStatic(MkSTuple(value), ll->Push(Tuple(expr))); } PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final { @@ -666,7 +666,7 @@ class PartialEvaluator : public ExprFunctor if (ps->pstatic.defined()) { return 
Downcast(ps->pstatic)->fields[op->index]; } else { - return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index))); + return NoStatic(ll->Push(TupleGetItem(ps->dynamic, op->index))); } } @@ -724,7 +724,7 @@ class PartialEvaluator : public ExprFunctor }); }); store_.Invalidate(); - return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f))); + return NoStatic(ll->Push(If(c->dynamic, t, f))); } } @@ -732,7 +732,7 @@ class PartialEvaluator : public ExprFunctor PStatic ps = VisitExpr(op->value, ll); Static r = MkSRef(); store_.Insert(r.as(), ps); - return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic))); + return HasStatic(r, ll->Push(RefCreate(ps->dynamic))); } PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final { @@ -743,7 +743,7 @@ class PartialEvaluator : public ExprFunctor } else { store_.Invalidate(); } - return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic))); + return HasStatic(MkSTuple({}), ll->Push(RefWrite(r->dynamic, v->dynamic))); } PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final { @@ -754,7 +754,7 @@ class PartialEvaluator : public ExprFunctor return ret; } } - return NoStatic(ll->Push(RefReadNode::make(r->dynamic))); + return NoStatic(ll->Push(RefRead(r->dynamic))); } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { @@ -774,7 +774,7 @@ class PartialEvaluator : public ExprFunctor return Downcast(f->pstatic)->func(f, x, op->attrs, op->type_args, ll); } else { store_.Invalidate(); - return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args))); + return NoStatic(ll->Push(Call(f->dynamic, x_dyn, op->attrs, op->type_args))); } } @@ -872,7 +872,7 @@ class PartialEvaluator : public ExprFunctor for (const auto& v : pv) { dyn.push_back(v->dynamic); } - return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args))); + return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); } }); }; @@ -898,7 +898,7 @@ class PartialEvaluator : public ExprFunctor PStatic VisitFunc(const Function& func, LetList* ll, - const Var& name = VarNode::make("x", Type())) { + const Var& name = Var("x", Type())) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. 
@@ -919,13 +919,13 @@ class PartialEvaluator : public ExprFunctor if (!st->pstatic.defined()) { throw ReflectError(); } else if (const STensorNode* op = st->pstatic.as()) { - return ConstantNode::make(op->data); + return Constant(op->data); } else if (const STupleNode* op = st->pstatic.as()) { tvm::Array fields; for (const PStatic& field : op->fields) { fields.push_back(Reflect(field)); } - return TupleNode::make(fields); + return Tuple(fields); } else { LOG(FATAL) << "Unknown case: " << st->dynamic; throw; @@ -935,7 +935,7 @@ class PartialEvaluator : public ExprFunctor PStatic Reify(const ObjectRef& v, LetList* ll) const { if (v->IsInstance()) { auto nd_array = Downcast(v); - return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array))); + return HasStatic(MkSTensor(nd_array), ll->Push(Constant(nd_array))); } else if (const runtime::ADTObj* op = v.as()) { std::vector fields; tvm::Array fields_dyn; @@ -945,7 +945,7 @@ class PartialEvaluator : public ExprFunctor fields.push_back(ps); fields_dyn.push_back(ps->dynamic); } - return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn))); + return HasStatic(MkSTuple(fields), ll->Push(Tuple(fields_dyn))); } else { LOG(FATAL) << "Unknown case"; throw; @@ -977,7 +977,7 @@ class PartialEvaluator : public ExprFunctor ns_args.push_back(ps->dynamic); } auto ns = [&]() { - return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args))); + return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); }; if (StatefulOp(expr)) { return ns(); @@ -987,7 +987,7 @@ class PartialEvaluator : public ExprFunctor for (const PStatic& ps : pv) { args.push_back(Reflect(ps)); } - return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll); + return ConstEvaluate(Call(expr, args, attrs, type_args), ll); } catch (const ReflectError&) { return ns(); @@ -1010,7 +1010,7 @@ class PartialEvaluator : public ExprFunctor for (const PStatic& ps : pv) { dyn.push_back(ps->dynamic); } - return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn))); + return HasStatic(MkSConstructor(c, pv), ll->Push(Call(c, dyn))); }; return HasStatic(MkSFunc(f), GetRef(op)); } @@ -1036,10 +1036,10 @@ class PartialEvaluator : public ExprFunctor return VisitExpr(c->rhs, ll)->dynamic; }); }); - clauses.push_back(ClauseNode::make(c->lhs, expr)); + clauses.push_back(Clause(c->lhs, expr)); } store_.Invalidate(); - return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete))); + return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); }(); default: LOG(FATAL) << "Unknown MatchStatus"; diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 17f5cfa0778b..3e4a1820b731 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -169,7 +169,7 @@ class Partitioner : public ExprMutator { auto compiler_attrs = call->attrs.as(); // The type of the created variable is the same as the compiler_begin // node. - auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), + auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), call->checked_type_); // Find the corresponding subgraph and add the argument. @@ -246,7 +246,7 @@ class Partitioner : public ExprMutator { module_->Add(glob_func, subgraph_func); // The return type of callnode is the same as the type of the // compiler_end node. 
- auto ret = CallNode::make(glob_func, args); + auto ret = Call(glob_func, args); ret->checked_type_ = call->checked_type_; return std::move(ret); } @@ -264,7 +264,7 @@ class Partitioner : public ExprMutator { for (auto field : op->fields) { fields.push_back(VisitExpr(field)); } - return TupleNode::make(fields); + return Tuple(fields); } } @@ -275,7 +275,7 @@ class Partitioner : public ExprMutator { } else { AddToSubgraph(subgraph, g->tuple); auto t = VisitExpr(g->tuple); - return TupleGetItemNode::make(t, g->index); + return TupleGetItem(t, g->index); } } @@ -309,7 +309,7 @@ class Partitioner : public ExprMutator { auto value = VisitExpr(op->value); auto body = VisitExpr(op->body); - return LetNode::make(var, value, body); + return Let(var, value, body); } } @@ -324,7 +324,7 @@ class Partitioner : public ExprMutator { auto guard = VisitExpr(op->cond); auto true_b = VisitExpr(op->true_branch); auto false_b = VisitExpr(op->false_branch); - return IfNode::make(guard, true_b, false_b); + return If(guard, true_b, false_b); } } @@ -335,7 +335,7 @@ class Partitioner : public ExprMutator { } else { AddToSubgraph(subgraph, op->value); Expr value = VisitExpr(op->value); - return RefCreateNode::make(value); + return RefCreate(value); } } @@ -346,7 +346,7 @@ class Partitioner : public ExprMutator { } else { AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); - return RefReadNode::make(ref); + return RefRead(ref); } } @@ -358,7 +358,7 @@ class Partitioner : public ExprMutator { AddToSubgraph(subgraph, op->ref); Expr ref = VisitExpr(op->ref); Expr value = VisitExpr(op->value); - return RefWriteNode::make(ref, value); + return RefWrite(ref, value); } } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index e59d5c9bb38b..e86fcdcc23aa 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -129,7 +129,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, } if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) { static const Op& squeeze_op = Op::Get("squeeze"); - *rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {}); + *rhs_value = Call(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {}); } return true; } @@ -155,7 +155,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(num_pad_axis); - bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); + bias = Call(expand_dims, {bias}, Attrs(attrs), {}); } } else { int64_t diff = axes[i]->value - axes[i - 1]->value; @@ -164,7 +164,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(diff); - bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); + bias = Call(expand_dims, {bias}, Attrs(attrs), {}); } } } @@ -182,7 +182,7 @@ inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); - const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW); + const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as()->shape); return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1); @@ -257,7 +257,7 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { *static_cast(arr->data) = value; } }) - return ConstantNode::make(arr); + return Constant(arr); } /*! 
@@ -285,7 +285,7 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector s } } }) - return ConstantNode::make(arr); + return Constant(arr); } /*! @@ -304,31 +304,31 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { } inline Expr GetField(Expr t, size_t i) { - return TupleGetItemNode::make(t, i); + return TupleGetItem(t, i); } inline Expr Pair(Expr l, Expr r) { - return TupleNode::make({l, r}); + return Tuple({l, r}); } inline Expr Exp(Expr e) { static const Op& op = Op::Get("exp"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr FastExp(Expr e) { static const Op& op = Op::Get("fast_exp"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr FastTanh(Expr e) { static const Op& op = Op::Get("fast_tanh"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr Log(Expr e) { static const Op& op = Op::Get("log"); - return CallNode::make(op, {e}); + return Call(op, {e}); } /*! * \brief Get an immediate scalar from a Constant expr. @@ -348,30 +348,30 @@ inline Expr Cast(Expr x, DataType dtype) { static const Op& op = Op::Get("cast"); auto attrs = make_object(); attrs->dtype = dtype; - return CallNode::make(op, {x}, Attrs(attrs), {}); + return Call(op, {x}, Attrs(attrs), {}); } inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); - return CallNode::make(op, {x}, Attrs(), {}); + return Call(op, {x}, Attrs(), {}); } inline Expr Sqrt(Expr x) { static const Op& op = Op::Get("sqrt"); - return CallNode::make(op, {x}, Attrs(), {}); + return Call(op, {x}, Attrs(), {}); } inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); - return CallNode::make(op, {x}, Attrs(), {}); + return Call(op, {x}, Attrs(), {}); } inline Expr Round(Expr x) { static const Op& op = Op::Get("round"); - return CallNode::make(op, {x}, Attrs(), {}); + return Call(op, {x}, Attrs(), {}); } @@ -380,41 +380,41 @@ inline Expr Clip(Expr x, double a_min, double a_max) { auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; - return CallNode::make(op, {x}, Attrs(attrs), {}); + return Call(op, {x}, Attrs(attrs), {}); } inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr Subtract(Expr lhs, Expr rhs) { static const Op& op = Op::Get("subtract"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr Divide(Expr lhs, Expr rhs) { static const Op& op = Op::Get("divide"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr Maximum(Expr lhs, Expr rhs) { static const Op& op = Op::Get("maximum"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr ZerosLike(Expr e) { static const Op& op = Op::Get("zeros_like"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr Zeros(Array shape, DataType dtype) { @@ -422,46 +422,46 @@ inline Expr Zeros(Array shape, DataType dtype) { attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); - return CallNode::make(op, {}, Attrs(attrs), {}); + return Call(op, {}, Attrs(attrs), {}); } inline Expr OnesLike(Expr e) { static 
const Op& op = Op::Get("ones_like"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr CollapseSumLike(Expr e) { static const Op& op = Op::Get("collapse_sum_like"); - return CallNode::make(op, {e}); + return Call(op, {e}); } inline Expr Power(Expr lhs, Expr rhs) { static const Op& op = Op::Get("power"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr RightShift(Expr x, Expr nbit) { static const Op& op = Op::Get("right_shift"); - return CallNode::make(op, {x, nbit}, Attrs(), {}); + return Call(op, {x, nbit}, Attrs(), {}); } inline Expr LeftShift(Expr x, Expr nbit) { static const Op& op = Op::Get("left_shift"); - return CallNode::make(op, {x, nbit}, Attrs(), {}); + return Call(op, {x, nbit}, Attrs(), {}); } inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } inline Expr Copy(Expr data) { static const Op& op = Op::Get("copy"); - return CallNode::make(op, {data}, Attrs(), {}); + return Call(op, {data}, Attrs(), {}); } @@ -471,7 +471,7 @@ inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { attrs->keepdims = keepdims; attrs->exclude = exclude; static const Op& op = Op::Get("mean"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { @@ -480,18 +480,18 @@ inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, b attrs->keepdims = keepdims; attrs->exclude = exclude; static const Op& op = Op::Get("variance"); - return CallNode::make(op, {data, mean}, Attrs(attrs), {}); + return Call(op, {data, mean}, Attrs(attrs), {}); } static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); - return CallNode::make(op, {condition, x, y}); + return Call(op, {condition, x, y}); } static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { static const Op& op = Op::Get("greater_equal"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); } static inline Expr Full(Expr fill_value, @@ -501,7 +501,7 @@ static inline Expr Full(Expr fill_value, attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); - return CallNode::make(op, {fill_value}, Attrs(attrs), {}); + return Call(op, {fill_value}, Attrs(attrs), {}); } static inline Expr Conv2D(Expr data, Expr weight, Array strides, @@ -520,7 +520,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.conv2d"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } static inline Expr Dense(Expr data, @@ -531,7 +531,7 @@ static inline Expr Dense(Expr data, attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); - return CallNode::make(op, {data, weight}, Attrs(attrs), {}); + return Call(op, {data, weight}, Attrs(attrs), {}); } static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclude) { @@ -540,7 +540,7 @@ static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclu attrs->keepdims = keepdims; attrs->exclude = exclude; static const Op& op = Op::Get("sum"); - return 
CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } static inline Expr Reshape(Expr data, Array newshape) { @@ -548,7 +548,7 @@ static inline Expr Reshape(Expr data, Array newshape) { attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, @@ -562,7 +562,7 @@ static inline Expr AvgPool2D(Expr data, Array pool_size, Arrayceil_mode = ceil_mode; attrs->count_include_pad = count_include_pad; static const Op& op = Op::Get("nn.avg_pool2d"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } static inline Expr Pad(Expr data, Array> pad_width, double pad_value, @@ -572,14 +572,14 @@ static inline Expr Pad(Expr data, Array> pad_width, double pad_ attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); static const Op& op = Op::Get("nn.pad"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } static inline Expr Tile(Expr data, Array reps) { auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); - return CallNode::make(op, {data}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } Expr MakeBroadCastTo(Expr data, Array shape); diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index e4722e2c3748..6e35dfbcb158 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -155,7 +155,7 @@ class Fill : ExprFunctor { Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Var var = v.defined() ? 
v : - VarNode::make(std::string("x"), Type()); + Var(std::string("x"), Type()); return GetScope(orig)->ll->Push(var, now); } @@ -165,7 +165,7 @@ class Fill : ExprFunctor { for (const auto& a : c->args) { args.push_back(VisitExpr(a)); } - return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v); + return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); } Expr VisitExpr_(const TupleNode* t, const Var& v) final { @@ -174,32 +174,32 @@ class Fill : ExprFunctor { for (const auto& a : t->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, TupleNode::make(fields), v); + return Compound(e, Tuple(fields), v); } Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { Expr e = GetRef(t); - return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); + return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v); } Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v); + return Compound(e, RefCreate(VisitExpr(r->value)), v); } Expr VisitExpr_(const RefReadNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v); + return Compound(e, RefRead(VisitExpr(r->ref)), v); } Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v); + return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v); } Expr VisitExpr_(const IfNode* i, const Var& v) final { Expr e = GetRef(i); - Expr ret = IfNode::make(VisitExpr(i->cond), + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); return Compound(e, ret, v); @@ -257,11 +257,11 @@ class Fill : ExprFunctor { Expr data = VisitExpr(m->data); std::vector clauses; for (const Clause& c : m->clauses) { - clauses.push_back(ClauseNode::make( + clauses.push_back(Clause( c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } - return Compound(e, MatchNode::make(data, clauses, m->complete), v); + return Compound(e, Match(data, clauses, m->complete), v); } }; diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 49ca8d2ef326..1039a1b6272d 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -137,7 +137,7 @@ Function ToCPS(const Function& f, Expr VisitExpr_(const LetNode* op, const MCont& k) final { return VisitExpr(op->value, [&](const Expr& v) { - return LetNode::make(remap(op->var), v, VisitExpr(op->body, k)); + return Let(remap(op->var), v, VisitExpr(op->body, k)); }); } @@ -155,7 +155,7 @@ Function ToCPS(const Function& f, } Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVarNode::make(remap(op->var)); + return PatternVar(remap(op->var)); } Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { @@ -177,18 +177,18 @@ Function ToCPS(const Function& f, } Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final { - return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreateNode::make(v)); }); + return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreate(v)); }); } Expr reify(const MCont& k) { - Var arg = VarNode::make("arg", Type()); + Var arg = Var("arg", Type()); return Function({arg}, k(arg), Type(), {}, {}); } Expr reify(const MCont& k, const std::function& cont) { - return LetList::Let(reify(k), + 
return LetList::LetBind(reify(k), [&](const Var& f) { - return cont([&](const Expr& e) { return CallNode::make(f, {e}); }); + return cont([&](const Expr& e) { return Call(f, {e}); }); }); } @@ -196,7 +196,7 @@ Function ToCPS(const Function& f, return reify(k, [&](const MCont& kf) { return VisitExpr(op->cond, [&](const Expr& v) { - return IfNode::make(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); + return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); }); }); } @@ -206,9 +206,9 @@ Function ToCPS(const Function& f, return VisitExpr(op->data, [&](const Expr& v) { tvm::Array clauses; for (const auto& c : op->clauses) { - clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); + clauses.push_back(Clause(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); } - return MatchNode::make(v, clauses, op->complete); + return Match(v, clauses, op->complete); }); }); } @@ -216,7 +216,7 @@ Function ToCPS(const Function& f, Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { return VisitExpr(op->ref, [&](const Expr& r) { - return LetList::Let(RefReadNode::make(r), k); + return LetList::LetBind(RefRead(r), k); }); } @@ -225,7 +225,7 @@ Function ToCPS(const Function& f, [&](const Expr& r) { return VisitExpr(op->value, [&](const Expr& v) { - return LetList::Let(RefWriteNode::make(r, v), k); + return LetList::LetBind(RefWrite(r, v), k); }); }); } @@ -235,7 +235,7 @@ Function ToCPS(const Function& f, std::function next; next = [&]() { return (fields.size() == op->fields.size()) ? - k(TupleNode::make(fields)) : + k(Tuple(fields)) : VisitExpr(op->fields[fields.size()], [&](const Expr& v) { fields.push_back(v); return next(); @@ -246,7 +246,7 @@ Function ToCPS(const Function& f, Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { return VisitExpr(op->tuple, [&](const Expr& v) { - return k(TupleGetItemNode::make(v, op->index)); + return k(TupleGetItem(v, op->index)); }); } @@ -256,7 +256,7 @@ Function ToCPS(const Function& f, std::function next; next = [&]() { if (args.size() == op->args.size()) { - return LetList::Let(CallNode::make(op->op, args, op->attrs, op->type_args), k); + return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k); } else { return VisitExpr(op->args[args.size()], [&](const Expr& v) { args.push_back(v); @@ -272,7 +272,7 @@ Function ToCPS(const Function& f, next = [&]() { if (args.size() == op->args.size()) { args.push_back(reify(k)); - return Expr(CallNode::make(f, args, op->attrs, op->type_args)); + return Expr(Call(f, args, op->attrs, op->type_args)); } else { return VisitExpr(op->args[args.size()], [&](const Expr& v) { args.push_back(v); @@ -287,7 +287,7 @@ Function ToCPS(const Function& f, } } } mut(remap, answer, m, vm, cm); - Var k = VarNode::make("k", Arrow(CPSType(function_type->ret_type, answer), answer)); + Var k = Var("k", Arrow(CPSType(function_type->ret_type, answer), answer)); tvm::Array new_params; for (const Var& v : f->params) { new_params.push_back(remap(v)); @@ -295,7 +295,7 @@ Function ToCPS(const Function& f, new_params.push_back(k); return Function(new_params, mut.VisitExpr(f->body, - [&](const Expr& e) { return CallNode::make(k, {e}); }), + [&](const Expr& e) { return Call(k, {e}); }), answer, f->type_params, f->attrs); @@ -311,7 +311,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { void VisitExpr_(const VarNode* vn) final { Var v = GetRef(vn); if (vm->count(v) == 0) { - auto ret = VarNode::make(v->name_hint(), CPSType(v->checked_type(), 
+      auto ret = Var(v->name_hint(), CPSType(v->checked_type(), answer));
       vm->insert({v, ret});
     }
   }
@@ -340,7 +340,7 @@ Function UnCPS(const Function& f) {
   CHECK_GT(f->params.size(), 0);
   std::vector<Var> new_params;
   for (const auto& p : f->params) {
-    new_params.push_back(VarNode::make(p->name_hint(), p->checked_type()));
+    new_params.push_back(Var(p->name_hint(), p->checked_type()));
   }
   auto cont_type = Downcast<FuncType>(new_params.back()->type_annotation);
   new_params.pop_back();
@@ -354,7 +354,7 @@ Function UnCPS(const Function& f) {
   new_type_params.pop_back();
   // TODO(@M.K.): make alphaequal work on free term
   // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
-  auto x = VarNode::make("x", new_ret_type);
+  auto x = Var("x", new_ret_type);
   auto cont = Function({x}, x, new_ret_type, {}, {});
   tvm::Array<Expr> args;
   for (const auto& p : new_params) {
@@ -367,7 +367,7 @@ Function UnCPS(const Function& f) {
   }
   type_args.push_back(new_ret_type);
   return Function(new_params,
-                  CallNode::make(f, args, {}, type_args),
+                  Call(f, args, {}, type_args),
                   new_ret_type,
                   new_type_params,
                   f->attrs);
diff --git a/src/relay/transforms/to_graph_normal_form.cc b/src/relay/transforms/to_graph_normal_form.cc
index b6ff2490aba4..8bf41a4610c0 100644
--- a/src/relay/transforms/to_graph_normal_form.cc
+++ b/src/relay/transforms/to_graph_normal_form.cc
@@ -63,7 +63,7 @@ class GNF : public ExprMutator {
   }
 
   static Expr WrapRec(const Var& var, const Expr& val) {
-    return UseVar(var, val) ? LetNode::make(var, val, var) : val;
+    return UseVar(var, val) ? Let(var, val, var) : val;
   }
 
   Expr VisitExpr_(const LetNode* ln) override {
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index cee3aaf01731..b6e75ae4f585 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -138,7 +138,7 @@ class TransformMemorizer : public ObjectRef {
     // 2) Insert layout transform on the transformed src.
     CHECK(new_src_layout.defined() && dst_layout.defined())
         << "Cannot insert layout transform because there are undefined layouts";
-    CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
+    CHECK(tir::BijectiveLayout(new_src_layout, dst_layout).defined())
         << "Cannot insert layout transform because there are inconvertible layouts: "
        << new_src_layout << " v.s. " << dst_layout;
" << dst_layout; return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); @@ -258,7 +258,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj Expr tmp = push_back_one_arg(x); fields.push_back(tmp); } - normal_new_args.push_back(TupleNode::make(fields)); + normal_new_args.push_back(Tuple(fields)); } else { Expr tmp = push_back_one_arg(new_arg); normal_new_args.push_back(tmp); @@ -325,7 +325,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); pt++; } - transformed_args.push_back(TupleNode::make(transformed_tuple_arg)); + transformed_args.push_back(Tuple(transformed_tuple_arg)); } else { transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt])); pt++; @@ -336,21 +336,21 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj // state[node] = (old_out, new_out) // (handle tuple output) if (ref_call->checked_type()->IsInstance()) { - Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs); + Expr tuple_output = Call(new_call->op, transformed_args, new_call->attrs); Array fields; for (size_t i = 0; i < new_out.size(); ++i) { auto rnode = make_object>(); - rnode->value = TupleGetItemNode::make(tuple_output, i); + rnode->value = TupleGetItem(tuple_output, i); rnode->old_layout = old_out[i]; rnode->new_layout = new_out[i]; rnode->memorizer = memorizer; fields.push_back(Expr(rnode)); } - return TupleNode::make(fields); + return Tuple(fields); } else { auto rnode = make_object>(); CHECK_EQ(new_out.size(), 1); - rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs); + rnode->value = Call(new_call->op, transformed_args, new_call->attrs); rnode->old_layout = old_out[0]; rnode->new_layout = new_out[0]; rnode->memorizer = memorizer; diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 9cc07a8a3b81..85842a0b9dcf 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -350,20 +350,18 @@ Array BijectiveLayout::BackwardShape(const Array& shape) con self->src_layout->axes, self->backward_rule); } -BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, - const Layout& dst_layout) { +BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { auto n = make_object(); - n->src_layout = src_layout; - n->dst_layout = dst_layout; + n->src_layout = std::move(src_layout); + n->dst_layout = std::move(dst_layout); - if (!GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) { - // not convertible - return BijectiveLayout(); + // To be consistent with previous behavior, a nullptr layout is created + // when argument is invalid. 
+  if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
+    CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
+    data_ = std::move(n);
   }
-  CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
-
-  return BijectiveLayout(n);
 }
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -398,7 +396,9 @@ TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
 });
 
 TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
-.set_body_typed(BijectiveLayoutNode::make);
+.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
+  return BijectiveLayout(src_layout, dst_layout);
+});
 
 TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
 .set_body_method(&BijectiveLayout::ForwardIndex);
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index cae71889d7ad..fa94271ad413 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -76,12 +76,12 @@ TVM_REGISTER_GLOBAL("relay.backend.lower_call")
 
 TEST(Relay, BuildModule) {
   auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
-  auto a = relay::VarNode::make("a", tensor_type);
-  auto b = relay::VarNode::make("b", tensor_type);
+  auto a = relay::Var("a", tensor_type);
+  auto b = relay::Var("b", tensor_type);
   auto add_op = relay::Op::Get("add");
-  auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
-  auto c = relay::VarNode::make("c", tensor_type);
-  auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
+  auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {});
+  auto c = relay::Var("c", tensor_type);
+  auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {});
   auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
   auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
   auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc
index bc5e65e59b74..f951a8f386a6 100644
--- a/tests/cpp/relay_pass_type_infer_test.cc
+++ b/tests/cpp/relay_pass_type_infer_test.cc
@@ -27,11 +27,11 @@ TEST(Relay, SelfReference) {
   using namespace tvm;
   auto tensor_type = relay::TensorType({}, DataType::Bool());
-  auto x = relay::VarNode::make("x", relay::Type());
+  auto x = relay::Var("x", relay::Type());
   auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
   CHECK(f->IsInstance<BaseFuncNode>());
-  auto y = relay::VarNode::make("y", tensor_type);
-  auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
+  auto y = relay::Var("y", tensor_type);
+  auto call = relay::Call(f, Array<relay::Expr>{ y });
   auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
   auto mod = IRModule::FromExpr(fx);
   mod = relay::transform::InferType()(mod);
diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc
index d8a0bde5fa6d..756468c9b110 100644
--- a/tests/cpp/relay_transform_sequential.cc
+++ b/tests/cpp/relay_transform_sequential.cc
@@ -41,17 +41,17 @@ TEST(Relay, Sequential) {
       tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
 
   // Create a function for optimization.
-  auto c = relay::ConstantNode::make(c_data);
-  auto a = relay::VarNode::make("a", tensor_type);
-  auto x = relay::VarNode::make("x", tensor_type);
+  auto c = relay::Constant(c_data);
+  auto a = relay::Var("a", tensor_type);
+  auto x = relay::Var("x", tensor_type);
   auto add_op = relay::Op::Get("add");
-  auto y = relay::CallNode::make(add_op, {c, c});
-  y = relay::CallNode::make(add_op, {x, y});
-  auto z = relay::CallNode::make(add_op, {y, c});
-  auto z1 = relay::CallNode::make(add_op, {y, c});
-  auto z2 = relay::CallNode::make(add_op, {z, z1});
+  auto y = relay::Call(add_op, {c, c});
+  y = relay::Call(add_op, {x, y});
+  auto z = relay::Call(add_op, {y, c});
+  auto z1 = relay::Call(add_op, {y, c});
+  auto z2 = relay::Call(add_op, {z, z1});
   // Let expression and varaible a should be dead-code eliminated.
-  auto z3 = relay::LetNode::make(a, c, z2);
+  auto z3 = relay::Let(a, c, z2);
   relay::Function func =
       relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
 
@@ -89,12 +89,12 @@ TEST(Relay, Sequential) {
   CHECK(f.defined());
 
   // Expected function
-  auto c1 = relay::ConstantNode::make(c_data);
-  auto x1 = relay::VarNode::make("x", tensor_type);
-  auto y1 = relay::CallNode::make(add_op, {c1, c1});
-  y1 = relay::CallNode::make(add_op, {x1, y1});
-  auto zz = relay::CallNode::make(add_op, {y1, c1});
-  zz = relay::CallNode::make(add_op, {zz, zz});
+  auto c1 = relay::Constant(c_data);
+  auto x1 = relay::Var("x", tensor_type);
+  auto y1 = relay::Call(add_op, {c1, c1});
+  y1 = relay::Call(add_op, {x1, y1});
+  auto zz = relay::Call(add_op, {y1, c1});
+  zz = relay::Call(add_op, {zz, zz});
   relay::Function expected_func =
       relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 40bdcc673cb0..431ace5bc11e 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -1258,7 +1258,7 @@ inline Tensor layout_transform(const Tensor& src,
   CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
     << "cannot convert from/to undefined layout";
 
-  auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
+  auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
   CHECK(layout_converter.defined())
     << "cannot convert from " << src_layout << " to " << dst_layout;