From a8c369218e87979020f732b9b2ad373fce4895f2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 31 Dec 2019 09:35:03 -0800 Subject: [PATCH] [REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal to Object (#4603) * [REFACTOR][OBJECT] Consoldiate NodePtr/Ref/Hash/Equal and macros to Object. Historically, we have classes like NodePtr/Ref/HashEqual. After unified object protocol, these names are just alias of the object counterpart. Moreover, there are helper macros defined over the places for defining these object. This PR consoldiate the terminologies into the corresponding ones in the Object system so we have a clean and consistent API moving forward. * Update include/tvm/attrs.h Co-Authored-By: Wei Chen * fix compilation Co-authored-by: Wei Chen --- include/tvm/api_registry.h | 12 +- include/tvm/arithmetic.h | 26 ++-- include/tvm/attrs.h | 39 +++--- include/tvm/buffer.h | 8 +- include/tvm/build_module.h | 24 ++-- include/tvm/data_layout.h | 16 +-- include/tvm/expr.h | 51 ++++---- include/tvm/ir.h | 88 ++++++------- include/tvm/ir_functor_ext.h | 4 +- include/tvm/ir_pass.h | 2 +- include/tvm/ir_visitor.h | 4 +- include/tvm/lowered_func.h | 4 +- include/tvm/node/container.h | 48 +++---- include/tvm/node/node.h | 100 --------------- include/tvm/operation.h | 40 +++--- include/tvm/packed_func_ext.h | 8 +- include/tvm/relay/adt.h | 64 +++++++--- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/relay/base.h | 72 +++-------- include/tvm/relay/error.h | 8 +- include/tvm/relay/expr.h | 108 +++++++++++----- include/tvm/relay/expr_functor.h | 6 +- include/tvm/relay/interpreter.h | 48 ++++--- include/tvm/relay/module.h | 6 +- include/tvm/relay/op.h | 4 +- include/tvm/relay/op_attr_types.h | 2 +- include/tvm/relay/pattern_functor.h | 4 +- include/tvm/relay/transform.h | 25 ++-- include/tvm/relay/type.h | 93 +++++++++----- include/tvm/runtime/object.h | 52 ++++++-- include/tvm/schedule.h | 42 +++---- include/tvm/target_info.h | 9 +- include/tvm/tensor.h | 12 +- include/tvm/tensor_intrin.h | 16 +-- nnvm/include/nnvm/graph.h | 6 +- nnvm/include/nnvm/node.h | 22 ++-- nnvm/include/nnvm/op_attr_types.h | 8 +- nnvm/include/nnvm/symbolic.h | 6 +- nnvm/src/c_api/c_api_symbolic.cc | 6 +- nnvm/src/core/graph.cc | 8 +- nnvm/src/core/node.cc | 8 +- nnvm/src/core/symbolic.cc | 46 +++---- nnvm/src/pass/correct_layout.cc | 20 +-- nnvm/src/pass/gradient.cc | 20 +-- nnvm/src/pass/infer_shape_type.cc | 6 +- nnvm/src/pass/order_mutation.cc | 16 +-- nnvm/src/pass/place_device.cc | 12 +- nnvm/src/pass/saveload_json.cc | 10 +- src/api/api_base.cc | 2 +- src/api/api_lang.cc | 10 +- src/api/api_pass.cc | 9 +- src/arithmetic/bound_deducer.cc | 6 +- src/arithmetic/canonical_simplify.cc | 26 ++-- src/arithmetic/const_int_bound.cc | 4 +- src/arithmetic/detect_linear_equation.cc | 4 +- src/arithmetic/int_set.cc | 4 +- src/arithmetic/int_set.h | 6 +- src/arithmetic/modular_set.cc | 4 +- src/arithmetic/pattern_match.h | 16 +-- src/codegen/build_module.cc | 14 +-- src/contrib/hybrid/codegen_hybrid.cc | 4 +- src/contrib/hybrid/codegen_hybrid.h | 4 +- src/lang/api_registry.cc | 2 +- src/lang/attrs.cc | 8 +- src/lang/buffer.cc | 4 +- src/lang/data_layout.cc | 6 +- src/lang/expr.cc | 12 +- src/lang/ir.cc | 64 +++++----- src/lang/tensor.cc | 8 +- src/node/serialization.cc | 2 +- src/op/compute_op.cc | 22 ++-- src/op/extern_op.cc | 10 +- src/op/hybrid_op.cc | 19 +-- src/op/placeholder_op.cc | 2 +- src/op/scan_op.cc | 10 +- src/op/tensor_compute_op.cc | 10 +- src/op/tensorize.cc | 6 +- src/pass/combine_context_call.cc | 6 +- src/pass/coproc_sync.cc | 16 +-- src/pass/hoist_if_then_else.cc | 44 +++---- src/pass/infer_fragment.cc | 2 +- src/pass/inject_virtual_thread.cc | 2 +- src/pass/ir_deep_compare.cc | 2 +- src/pass/ir_util.cc | 18 +-- src/pass/ir_visitor.cc | 10 +- src/pass/lift_attr_scope.cc | 14 +-- src/pass/loop_partition.cc | 32 ++--- src/pass/lower_custom_datatypes.cc | 2 +- src/pass/lower_intrin.cc | 2 +- src/pass/lower_thread_allreduce.cc | 2 +- src/pass/lower_tvm_builtin.cc | 2 +- src/pass/lower_warp_memory.cc | 2 +- src/pass/make_api.cc | 6 +- src/pass/remap_thread_axis.cc | 2 +- src/pass/simple_passes.cc | 8 +- src/pass/skip_assert.cc | 2 +- src/pass/split_host_device.cc | 6 +- src/pass/ssa.cc | 6 +- src/pass/storage_access.cc | 2 +- src/pass/storage_access.h | 2 +- src/pass/storage_flatten.cc | 4 +- src/pass/storage_rewrite.cc | 22 ++-- src/pass/storage_sync.cc | 10 +- src/pass/tensor_core.cc | 26 ++-- src/pass/verify_memory.cc | 2 +- src/relay/backend/build_module.cc | 4 +- src/relay/backend/compile_engine.cc | 34 ++--- src/relay/backend/compile_engine.h | 34 ++--- .../backend/contrib/codegen_c/codegen.cc | 6 +- .../backend/contrib/codegen_c/codegen_c.h | 2 +- src/relay/backend/contrib/dnnl/codegen.cc | 4 +- src/relay/backend/graph_runtime_codegen.cc | 16 +-- src/relay/backend/interpreter.cc | 23 ++-- src/relay/backend/param_dict.cc | 2 +- src/relay/backend/param_dict.h | 10 +- src/relay/backend/vm/compiler.cc | 34 ++--- src/relay/backend/vm/compiler.h | 4 +- src/relay/backend/vm/inline_primitives.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 4 +- src/relay/backend/vm/removed_unused_funcs.cc | 4 +- src/relay/ir/adt.cc | 16 +-- src/relay/ir/alpha_equal.cc | 26 ++-- src/relay/ir/base.cc | 8 +- src/relay/ir/error.cc | 6 +- src/relay/ir/expr.cc | 46 +++---- src/relay/ir/expr_functor.cc | 6 +- src/relay/ir/hash.cc | 22 ++-- src/relay/ir/module.cc | 6 +- src/relay/ir/op.cc | 12 +- src/relay/ir/pretty_printer.cc | 48 +++---- src/relay/ir/type.cc | 18 +-- src/relay/ir/type_functor.h | 2 +- src/relay/op/algorithm/argsort.cc | 2 +- src/relay/op/algorithm/topk.cc | 2 +- src/relay/op/annotation/annotation.cc | 4 +- src/relay/op/debug.cc | 2 +- src/relay/op/device_copy.cc | 6 +- src/relay/op/image/resize.cc | 2 +- src/relay/op/memory/memory.cc | 6 +- src/relay/op/nn/bitserial.cc | 6 +- src/relay/op/nn/convolution.cc | 24 ++-- src/relay/op/nn/nn.cc | 30 ++--- src/relay/op/nn/pad.cc | 4 +- src/relay/op/nn/pooling.cc | 16 +-- src/relay/op/nn/sparse.cc | 4 +- src/relay/op/nn/upsampling.cc | 4 +- src/relay/op/op_common.h | 2 +- src/relay/op/tensor/reduce.cc | 4 +- src/relay/op/tensor/transform.cc | 52 ++++---- src/relay/op/tensor/unary.cc | 6 +- src/relay/op/vision/multibox_op.cc | 4 +- src/relay/op/vision/nms.cc | 8 +- src/relay/op/vision/rcnn_op.cc | 10 +- src/relay/op/vision/yolo.cc | 2 +- src/relay/pass/alter_op_layout.cc | 4 +- src/relay/pass/canonicalize_cast.cc | 7 +- src/relay/pass/combine_parallel_conv2d.cc | 6 +- src/relay/pass/combine_parallel_op.h | 16 +-- src/relay/pass/convert_layout.cc | 4 +- src/relay/pass/de_duplicate.cc | 4 +- src/relay/pass/dead_code.cc | 4 +- src/relay/pass/dependency_graph.cc | 2 +- src/relay/pass/dependency_graph.h | 2 +- src/relay/pass/device_annotation.cc | 2 +- src/relay/pass/eliminate_common_subexpr.cc | 6 +- src/relay/pass/eta_expand.cc | 2 +- src/relay/pass/expr_subst.cc | 9 +- src/relay/pass/expr_subst.h | 7 +- src/relay/pass/feature.cc | 2 +- src/relay/pass/fold_constant.cc | 4 +- src/relay/pass/fold_scale_axis.cc | 45 +++---- src/relay/pass/forward_rewrite.cc | 14 +-- src/relay/pass/fuse_ops.cc | 18 +-- src/relay/pass/gradient.cc | 4 +- src/relay/pass/partial_eval.cc | 117 +++++++++++------- src/relay/pass/pass_manager.cc | 32 +++-- src/relay/pass/pass_util.h | 46 +++---- src/relay/pass/pattern_util.h | 36 +++--- src/relay/pass/quantize/annotate.cc | 9 +- src/relay/pass/quantize/calibrate.cc | 2 +- src/relay/pass/quantize/partition.cc | 9 +- src/relay/pass/quantize/quantize.cc | 4 +- src/relay/pass/quantize/quantize.h | 12 +- src/relay/pass/quantize/realize.cc | 44 ++++--- src/relay/pass/simplify_inference.cc | 2 +- src/relay/pass/to_a_normal_form.cc | 2 +- src/relay/pass/to_cps.cc | 4 +- src/relay/pass/to_graph_normal_form.cc | 6 +- src/relay/pass/transform_layout.h | 33 ++--- src/relay/pass/type_infer.cc | 28 ++--- src/relay/pass/type_solver.cc | 28 ++--- src/relay/pass/type_solver.h | 10 +- src/relay/pass/util.cc | 14 +-- src/relay/pass/well_formed.cc | 8 +- src/relay/qnn/op/concatenate.cc | 2 +- src/relay/qnn/op/convolution.cc | 2 +- src/relay/qnn/op/dense.cc | 2 +- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/op_common.h | 6 +- src/relay/qnn/op/quantize.cc | 2 +- src/relay/qnn/op/requantize.cc | 2 +- src/relay/qnn/util.h | 2 +- src/runtime/vm/memory_manager.h | 2 +- src/schedule/auto_inline_elem_wise.cc | 6 +- src/schedule/bound.cc | 2 +- src/schedule/graph.cc | 30 ++--- src/schedule/schedule_dataflow_rewrite.cc | 12 +- src/schedule/schedule_lang.cc | 44 +++---- src/schedule/schedule_ops.cc | 8 +- tests/cpp/container_test.cc | 4 +- tests/cpp/expr_test.cc | 4 +- tests/cpp/ir_visitor_test.cc | 2 +- topi/include/topi/detail/extern.h | 2 +- topi/include/topi/nn/softmax.h | 2 +- topi/src/topi.cc | 2 +- 215 files changed, 1623 insertions(+), 1517 deletions(-) diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index c41c3087f4ac..292e4948d211 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -49,7 +49,7 @@ namespace tvm { * \brief Node container of EnvFunc * \sa EnvFunc */ -class EnvFuncNode : public Node { +class EnvFuncNode : public Object { public: /*! \brief Unique name of the global function */ std::string name; @@ -63,7 +63,7 @@ class EnvFuncNode : public Node { } static constexpr const char* _type_key = "EnvFunc"; - TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); }; /*! @@ -73,10 +73,10 @@ class EnvFuncNode : public Node { * An EnvFunc is saved by its name in the global registry * under the assumption that the same function is registered during load. */ -class EnvFunc : public NodeRef { +class EnvFunc : public ObjectRef { public: EnvFunc() {} - explicit EnvFunc(NodePtr n) : NodeRef(n) {} + explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { return static_cast(get()); @@ -119,12 +119,12 @@ class TypedEnvFunc; * \sa EnvFunc */ template -class TypedEnvFunc : public NodeRef { +class TypedEnvFunc : public ObjectRef { public: /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index bda6ac647f55..e5f75673a9cb 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -55,7 +55,7 @@ class Analyzer; * * set = [min_value, max_value] */ -class ConstIntBoundNode : public Node { +class ConstIntBoundNode : public Object { public: int64_t min_value; int64_t max_value; @@ -74,14 +74,14 @@ class ConstIntBoundNode : public Node { static const constexpr int64_t kNegInf = -kPosInf; static constexpr const char* _type_key = "arith.ConstIntBound"; - TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object); }; /*! * \brief reference class to ConstIntBoundNode * \sa ConstIntBoundNode */ -class ConstIntBound : public NodeRef { +class ConstIntBound : public ObjectRef { public: /*! * \brief constructor by fields. @@ -92,7 +92,7 @@ class ConstIntBound : public NodeRef { static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); + TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode); }; /*! @@ -155,7 +155,7 @@ class ConstIntBoundAnalyzer { * This is useful to decide if the index is dividable by certain value. * For example, if index = 0 + 4 x, then we know it can be divided by 4. */ -class ModularSetNode : public Node { +class ModularSetNode : public Object { public: /*! \brief linear co-efficient */ int64_t coeff; @@ -168,18 +168,18 @@ class ModularSetNode : public Node { } static constexpr const char* _type_key = "arith.ModularSet"; - TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object); }; /*! * \brief reference of ModularSetNode * \sa ModularSetNode */ -class ModularSet : public NodeRef { +class ModularSet : public ObjectRef { public: TVM_DLL ModularSet(int64_t coeff, int64_t base); - TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); + TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode); }; /*! @@ -349,20 +349,20 @@ enum SignType { /*! * \brief Base class of all IntSet containers. */ -struct IntSetNode : public Node { +struct IntSetNode : public Object { static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object); }; /*! * \brief Integer set class, represent a set of integers in one dimension. */ -class IntSet : public NodeRef { +class IntSet : public ObjectRef { public: /*! \brief constructor */ IntSet() {} // constructor from not container. - explicit IntSet(ObjectPtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -598,7 +598,7 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ -using ExprIntSetMap = std::unordered_map; +using ExprIntSetMap = std::unordered_map; /*! * \brief Find the integer set of every sub-expression, given the * domain of each iteration variables. diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 8810c4e4a0df..0178eabe02eb 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -65,7 +65,7 @@ namespace tvm { */ #define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_NODE_TYPE_INFO(ClassName, ::tvm::BaseAttrsNode) \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) @@ -83,9 +83,9 @@ namespace tvm { * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ -template -inline TNodeRef NullValue() { - return TNodeRef(NodePtr(nullptr)); +template +inline TObjectRef NullValue() { + return TObjectRef(ObjectPtr(nullptr)); } template<> @@ -106,7 +106,7 @@ struct AttrError : public dmlc::Error { /*! * \brief Information about attribute fields in string representations. */ -class AttrFieldInfoNode : public Node { +class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ std::string name; @@ -121,11 +121,14 @@ class AttrFieldInfoNode : public Node { v->Visit("description", &description); } static constexpr const char* _type_key = "AttrFieldInfo"; - TVM_DECLARE_NODE_TYPE_INFO(AttrFieldInfoNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); }; /*! \brief AttrFieldInfo */ -TVM_DEFINE_NODE_REF(AttrFieldInfo, AttrFieldInfoNode); +class AttrFieldInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); +}; class AttrsHashHandler; class AttrsEqualHandler; @@ -217,7 +220,7 @@ class AttrsHash { * subclass AttrsNode instead. * \sa AttrsNode */ -class BaseAttrsNode : public Node { +class BaseAttrsNode : public Object { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; @@ -271,16 +274,16 @@ class BaseAttrsNode : public Node { TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0; static constexpr const char* _type_key = "Attrs"; - TVM_DECLARE_BASE_NODE_INFO(BaseAttrsNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; /*! \brief Base attribute container for all attributes */ -class Attrs : public NodeRef { +class Attrs : public ObjectRef { public: // normal constructor Attrs() {} // construct from shared ptr. - explicit Attrs(NodePtr n) : NodeRef(n) {} + explicit Attrs(ObjectPtr n) : ObjectRef(n) {} /*! \return The attribute node */ const BaseAttrsNode* operator->() const { @@ -305,13 +308,13 @@ class Attrs : public NodeRef { class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ - Map dict; + Map dict; /*! * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. * \return The dict attributes. */ - TVM_DLL static Attrs make(Map dict); + TVM_DLL static Attrs make(Map dict); // implementations void VisitAttrs(AttrVisitor* v) final; void VisitNonDefaultAttrs(AttrVisitor* v) final; @@ -321,7 +324,7 @@ class DictAttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; - TVM_DECLARE_NODE_TYPE_INFO(DictAttrsNode, BaseAttrsNode); + TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); }; @@ -639,7 +642,7 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(NodePtr info) + explicit AttrDocEntry(ObjectPtr info) : info_(info) { } TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { @@ -663,15 +666,15 @@ class AttrDocEntry { } private: - NodePtr info_; + ObjectPtr info_; }; class AttrDocVisitor { public: template AttrDocEntry operator()(const char* key, T* v) { - NodePtr info - = make_node(); + ObjectPtr info + = make_object(); info->name = key; info->type_info = TypeName::value; fields_.push_back(AttrFieldInfo(info)); diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index fac18a9b1753..44c791863153 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -48,10 +48,10 @@ enum BufferType : int { * It is a composition of primitive symbolic types, * used to specify the memory layout of the Tensor used in program input. */ -class Buffer : public NodeRef { +class Buffer : public ObjectRef { public: Buffer() {} - explicit Buffer(ObjectPtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -101,7 +101,7 @@ class Buffer : public NodeRef { }; /*! \brief Node to represent a buffer */ -class BufferNode : public Node { +class BufferNode : public Object { public: // Data fields. /*! @@ -169,7 +169,7 @@ class BufferNode : public Node { BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; - TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; inline const BufferNode* Buffer::operator->() const { diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index fba929cda1be..5078621e4bda 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -39,7 +39,7 @@ namespace tvm { * \brief Container for target device information. * Use target::llvm, target::cuda etc functions instead of constructing directly. */ -class TargetNode : public Node { +class TargetNode : public Object { public: /*! \brief The name of the target device */ std::string target_name; @@ -82,7 +82,7 @@ class TargetNode : public Node { TVM_DLL std::unordered_set libs() const; static constexpr const char* _type_key = "Target"; - TVM_DECLARE_NODE_TYPE_INFO(TargetNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: /*! \brief Internal string repr. */ @@ -90,10 +90,10 @@ class TargetNode : public Node { }; /*! \brief reference cpass to the target. */ -class Target : public NodeRef { +class Target : public ObjectRef { public: Target() {} - explicit Target(ObjectPtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Create a Target given a string * \param target_str the string to parse @@ -178,7 +178,7 @@ TVM_DLL Target ext_dev(const std::vector& options = /*! * \brief Container for build configuration options */ -class BuildConfigNode : public Node { +class BuildConfigNode : public Object { public: /*! * \brief The data alignment to use when constructing buffers. If this is set to @@ -254,16 +254,16 @@ class BuildConfigNode : public Node { } static constexpr const char* _type_key = "BuildConfig"; - TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object); }; /*! * \brief Build configuration for compilations. */ -class BuildConfig : public ::tvm::NodeRef { +class BuildConfig : public ::tvm::ObjectRef { public: BuildConfig() {} - explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : ObjectRef(n) {} const BuildConfigNode* operator->() const { return static_cast(get()); } @@ -375,10 +375,10 @@ class GenericFuncNode; /*! * \brief Generic function that can be specialized on a per-target basis. */ -class GenericFunc : public NodeRef { +class GenericFunc : public ObjectRef { public: GenericFunc() {} - explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Set the default function implementaiton. @@ -471,7 +471,7 @@ inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { /*! * \brief Represents a generic function that can be specialized on a per-target basis. */ -class GenericFuncNode : public Node { +class GenericFuncNode : public Object { public: /*! \brief name of the function */ std::string name_; @@ -483,7 +483,7 @@ class GenericFuncNode : public Node { void VisitAttrs(AttrVisitor* v) {} static constexpr const char* _type_key = "GenericFunc"; - TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object); }; inline GenericFuncNode* GenericFunc::operator->() { diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index 5e2cc08660db..8c7247ff860b 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -92,7 +92,7 @@ class LayoutAxis { class Layout; // Internal node container Buffer -class LayoutNode : public Node { +class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ std::string name; @@ -112,7 +112,7 @@ class LayoutNode : public Node { TVM_DLL static Layout make(const std::string& layout); static constexpr const char* _type_key = "Layout"; - TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; /*! @@ -125,9 +125,9 @@ class LayoutNode : public Node { * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). * Layout for scalar is defined, while both its name and axes have size 0. */ -class Layout : public NodeRef { +class Layout : public ObjectRef { public: - explicit Layout(ObjectPtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : ObjectRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -311,7 +311,7 @@ class Layout : public NodeRef { class BijectiveLayout; // Internal node container BijectiveLayout -class BijectiveLayoutNode : public Node { +class BijectiveLayoutNode : public Object { public: /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n @@ -333,7 +333,7 @@ class BijectiveLayoutNode : public Node { } static constexpr const char* _type_key = "BijectiveLayout"; - TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); TVM_DLL static BijectiveLayout make(const Layout& src_layout, const Layout& dst_layout); @@ -344,10 +344,10 @@ class BijectiveLayoutNode : public Node { * provides API to transform N-dimention tensor from the source indices (i0, i1, …, im) * to the destination indices (j0, j1, … jm). */ -class BijectiveLayout : public NodeRef { +class BijectiveLayout : public ObjectRef { public: BijectiveLayout() = default; - explicit BijectiveLayout(NodePtr n) : NodeRef(n) {} + explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} // Given the source shape, infer the destination shape. TVM_DLL Array ForwardShape(const Array& shape) const; diff --git a/include/tvm/expr.h b/include/tvm/expr.h index f27cb9879fb7..0605cc512690 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -38,20 +38,20 @@ namespace tvm { /*! \brief Base node of all expressions. */ -class ExprNode : public Node { +class ExprNode : public Object { public: /*! \brief The data type of the expression. */ DataType dtype; static constexpr const char* _type_key = "Expr"; - TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object); }; /*! \brief Container of all expressions. */ -class Expr : public NodeRef { +class Expr : public ObjectRef { public: Expr() {} - explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : ObjectRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -78,16 +78,16 @@ class Expr : public NodeRef { }; /*! \brief Base node of all statements. */ -class StmtNode : public Node { +class StmtNode : public Object { public: static constexpr const char* _type_key = "Stmt"; - TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); }; /*! \brief Container of all statements */ -class Stmt : public NodeRef { +class Stmt : public ObjectRef { public: - TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode); + TVM_DEFINE_OBJECT_REF_METHODS(Stmt, ObjectRef, StmtNode); }; class Var; @@ -118,7 +118,7 @@ class Variable : public ExprNode { } static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode); }; /*! \brief a named variable in TVM */ @@ -156,8 +156,8 @@ class Var : public Expr { // Backward compatibility, will be removed later. using VarExpr = Var; using BaseExprNode = ExprNode; -using ExprHash = NodeHash; -using ExprEqual = NodeEqual; +using ExprHash = ObjectHash; +using ExprEqual = ObjectEqual; class Integer; /*! \brief ExprNode: constant integer. */ @@ -174,7 +174,7 @@ class IntImm : public ExprNode { TVM_DLL static Integer make(DataType t, int64_t value); static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode); }; /*! @@ -222,7 +222,7 @@ class Integer : public Expr { }; /*! \brief range over one dimension */ -class RangeNode : public Node { +class RangeNode : public Object { public: /*! \brief beginning of the node */ Expr min; @@ -238,11 +238,11 @@ class RangeNode : public Node { } static constexpr const char* _type_key = "Range"; - TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); }; /*! \brief Range constainer */ -class Range : public NodeRef { +class Range : public ObjectRef { public: /*! * \brief constructor by begin and end @@ -261,7 +261,7 @@ class Range : public NodeRef { */ static Range make_by_min_extent(Expr min, Expr extent); // declare range. - TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode); + TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); }; /*! \brief container class of iteration variable. */ @@ -343,12 +343,12 @@ enum IterVarType : int { * \brief Iteration Variable, * represents an iteration over an integer interval. */ -class IterVar : public NodeRef { +class IterVar : public ObjectRef { public: // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(ObjectPtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -384,14 +384,14 @@ using Domain = Array; * \brief Dump the node to stderr, used for debug purposes. * \param node The input node */ -TVM_DLL void Dump(const NodeRef& node); +TVM_DLL void Dump(const ObjectRef& node); // definition of Node. /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. */ -class IterVarNode : public Node { +class IterVarNode : public Object { public: /*! * \brief the domain of iteration, if known, can be None @@ -420,7 +420,7 @@ class IterVarNode : public Node { std::string thread_tag = ""); static constexpr const char* _type_key = "IterVar"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); }; // inline implementations @@ -490,17 +490,22 @@ class IRPrinter { using FType = NodeFunctor; TVM_DLL static FType& vtable(); }; +} // namespace tvm -// default print function for all nodes +namespace tvm { +namespace runtime { +// default print function for all objects +// provide in the runtime namespace as this is where objectref originally comes from. inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) IRPrinter(os).Print(n); return os; } +} // namespace runtime } // namespace tvm namespace std { template <> -struct hash<::tvm::IterVar> : public ::tvm::NodeHash { +struct hash<::tvm::IterVar> : public ::tvm::ObjectHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 33aa72b50805..c55a4695de4d 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -53,7 +53,7 @@ class UIntImm : public ExprNode { TVM_DLL static Expr make(DataType t, uint64_t value); static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(UIntImm, ExprNode); }; /*! \brief Floating point constants. */ @@ -70,7 +70,7 @@ class FloatImm : public ExprNode { TVM_DLL static Expr make(DataType t, double value); static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImm, ExprNode); }; /*! \brief String constants, only used in asserts. */ @@ -87,7 +87,7 @@ class StringImm : public ExprNode { TVM_DLL Expr static make(std::string value); static constexpr const char* _type_key = "StringImm"; - TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(StringImm, ExprNode); }; /*! @@ -107,7 +107,7 @@ class Cast : public ExprNode { TVM_DLL static Expr make(DataType t, Expr v); static constexpr const char* _type_key = "Cast"; - TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode); }; /*! @@ -132,14 +132,14 @@ class BinaryOpNode : public ExprNode { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = a.dtype(); node->a = std::move(a); node->b = std::move(b); return Expr(node); } - TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); }; /*! \brief a + b */ @@ -224,14 +224,14 @@ class CmpOpNode : public ExprNode { CHECK(a.defined()) << "ValueError: a is undefined\n"; CHECK(b.defined()) << "ValueError: b is undefined\n"; CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } - TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode); }; /*! \brief a == b */ @@ -287,7 +287,7 @@ class And : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "And"; - TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode); }; /*! \brief a || b */ @@ -307,7 +307,7 @@ class Or : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "Or"; - TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode); }; /*! \brief !a */ @@ -324,7 +324,7 @@ class Not : public ExprNode { TVM_DLL static Expr make(Expr a); static constexpr const char* _type_key = "Not"; - TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode); }; /*! @@ -353,7 +353,7 @@ class Select : public ExprNode { TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); static constexpr const char* _type_key = "Select"; - TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode); }; /*! @@ -390,7 +390,7 @@ class Load : public ExprNode { TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate); static constexpr const char* _type_key = "Load"; - TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Load, ExprNode); }; /*! @@ -421,7 +421,7 @@ class Ramp : public ExprNode { TVM_DLL static Expr make(Expr base, Expr stride, int lanes); static constexpr const char* _type_key = "Ramp"; - TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Ramp, ExprNode); }; /*! \brief Create a vector where all the elements are value. */ @@ -441,7 +441,7 @@ class Broadcast : public ExprNode { TVM_DLL static Expr make(Expr value, int lanes); static constexpr const char* _type_key = "Broadcast"; - TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Broadcast, ExprNode); }; /*! @@ -466,7 +466,7 @@ class Let : public ExprNode { TVM_DLL static Expr make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "Let"; - TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Let, ExprNode); }; // Call node, represent a function call or a multi-dimensional array load. @@ -477,7 +477,7 @@ class Let : public ExprNode { // We should move most information into function itself and remove name. /*! \brief Base node of internal functions. */ -class FunctionBaseNode : public Node { +class FunctionBaseNode : public Object { public: /*! \return the name of the function */ virtual const std::string& func_name() const = 0; @@ -486,9 +486,9 @@ class FunctionBaseNode : public Node { }; /*! \brief reference to a function */ -class FunctionRef : public NodeRef { +class FunctionRef : public ObjectRef { public: - TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode); + TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode); }; /*! @@ -560,7 +560,7 @@ class Call : public ExprNode { bool is_vectorizable() const; static constexpr const char* _type_key = "Call"; - TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Call, ExprNode); // Build-in intrinsics static constexpr const char* reinterpret = "reinterpret"; @@ -602,16 +602,16 @@ class Shuffle : public ExprNode { TVM_DLL static Expr make_extract_element(Expr vector, int index); static constexpr const char* _type_key = "Shuffle"; - TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Shuffle, ExprNode); }; // Reduce operator class CommReducerNode; -class CommReducer : public NodeRef { +class CommReducer : public ObjectRef { public: CommReducer() {} - explicit CommReducer(NodePtr n) : NodeRef(n) {} + explicit CommReducer(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -630,7 +630,7 @@ class CommReducer : public NodeRef { * \brief A commutative reducer node to represent a commutative * binary operator with identity element */ -class CommReducerNode : public Node { +class CommReducerNode : public Object { public: /*! \brief The left argument of reducer */ Array lhs; @@ -660,7 +660,7 @@ class CommReducerNode : public Node { } static constexpr const char* _type_key = "CommReducer"; - TVM_DECLARE_NODE_TYPE_INFO(CommReducerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; inline const CommReducerNode* CommReducer::get() const { @@ -704,7 +704,7 @@ class Reduce : public ExprNode { } static constexpr const char* _type_key = "Reduce"; - TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Reduce, ExprNode); }; /*! \brief Any shape. */ @@ -719,7 +719,7 @@ class Any : public ExprNode { TVM_DLL static Expr make(); static constexpr const char* _type_key = "Any"; - TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Any, ExprNode); }; // Statements @@ -744,7 +744,7 @@ class LetStmt : public StmtNode { TVM_DLL static Stmt make(Var var, Expr value, Stmt body); static constexpr const char* _type_key = "LetStmt"; - TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetStmt, StmtNode); }; /*! @@ -760,7 +760,7 @@ class LetStmt : public StmtNode { class AttrStmt : public StmtNode { public: /*! \brief this is attribute about certain node */ - NodeRef node; + ObjectRef node; /*! \brief the type key of the attribute */ std::string attr_key; /*! \brief The attribute value, value is well defined at current scope. */ @@ -775,13 +775,13 @@ class AttrStmt : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(NodeRef node, + TVM_DLL static Stmt make(ObjectRef node, std::string type_key, Expr value, Stmt body); static constexpr const char* _type_key = "AttrStmt"; - TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmt, StmtNode); }; /*! @@ -808,7 +808,7 @@ class AssertStmt : public StmtNode { TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body); static constexpr const char* _type_key = "AssertStmt"; - TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmt, StmtNode); }; // TODO(tvm-team): consider consolidate with AttrStmt. @@ -831,7 +831,7 @@ class ProducerConsumer : public StmtNode { TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); static constexpr const char* _type_key = "ProducerConsumer"; - TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumer, StmtNode); }; /*! @@ -876,7 +876,7 @@ class Store : public StmtNode { Expr predicate); static constexpr const char* _type_key = "Store"; - TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Store, StmtNode); }; /*! @@ -906,7 +906,7 @@ class Provide : public StmtNode { Array args); static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Provide, StmtNode); }; /*! @@ -963,7 +963,7 @@ class Allocate : public StmtNode { const Array& extents); static constexpr const char* _type_key = "Allocate"; - TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Allocate, StmtNode); }; /*! \brief Free the resources in the buffer before the scope ends. */ @@ -979,7 +979,7 @@ class Free : public StmtNode { TVM_DLL static Stmt make(Var buffer_var); static constexpr const char* _type_key = "Free"; - TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Free, StmtNode); }; /*! @@ -1018,7 +1018,7 @@ class Realize : public StmtNode { Stmt body); static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode); }; /*! @@ -1040,7 +1040,7 @@ class Block : public StmtNode { TVM_DLL static Stmt make(const std::vector &stmts); static constexpr const char* _type_key = "Block"; - TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode); }; /*! @@ -1064,7 +1064,7 @@ class IfThenElse : public StmtNode { TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt()); static constexpr const char* _type_key = "IfThenElse"; - TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElse, StmtNode); }; /*! @@ -1085,7 +1085,7 @@ class Evaluate : public StmtNode { TVM_DLL static Stmt make(Expr v); static constexpr const char* _type_key = "Evaluate"; - TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Evaluate, StmtNode); }; /*! \brief Additional annotation of for loop. */ @@ -1152,7 +1152,7 @@ class For : public StmtNode { } static constexpr const char* _type_key = "For"; - TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(For, StmtNode); }; /*! @@ -1182,7 +1182,7 @@ class Prefetch : public StmtNode { Region bounds); static constexpr const char* _type_key = "Prefetch"; - TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(Prefetch, StmtNode); }; /*! @@ -1636,7 +1636,7 @@ namespace std { template <> struct hash<::tvm::ir::TensorKey> { std::size_t operator()(const ::tvm::ir::TensorKey& k) const { - size_t lhs = ::tvm::NodeHash()(k.f); + size_t lhs = ::tvm::ObjectHash()(k.f); size_t rhs = static_cast(k.value_index); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 04ce7934ff2f..9b2632f87b3c 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -164,7 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Node* op, Args ...) { + virtual R VisitExprDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -255,7 +255,7 @@ class StmtFunctor { virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmtDefault_(const Node* op, Args ...) { + virtual R VisitStmtDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index b0b13df729cc..6e1fed5a8542 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -418,7 +418,7 @@ Stmt HoistIfThenElse(Stmt stmt); */ LoweredFunc MakeAPI(Stmt body, std::string name, - Array api_args, + Array api_args, int num_unpacked_args, bool is_restricted); diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index b85cf233a42f..cffcdcbdf5b8 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -87,7 +87,7 @@ class TVM_DLL IRVisitor { /*! * \brief recursively visit an IR node */ - virtual void Visit(const NodeRef& node) { + virtual void Visit(const ObjectRef& node) { static const FVisit& f = vtable(); if (node.defined()) f(node, this); } @@ -152,7 +152,7 @@ class TVM_DLL IRVisitor { * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ -TVM_DLL void PostOrderVisit(const NodeRef& node, std::function fvisit); +TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); } // namespace ir } // namespace tvm diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 6709f545cb39..3de6bfdbb087 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -131,7 +131,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { } static constexpr const char* _type_key = "LoweredFunc"; - TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object); }; // Implementations of inline functions @@ -143,7 +143,7 @@ inline const LoweredFuncNode* LoweredFunc::operator->() const { namespace std { template <> -struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { +struct hash<::tvm::LoweredFunc> : public tvm::ObjectHash { }; } diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 1a276ae695fc..d20fb288039c 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -35,7 +35,7 @@ namespace tvm { /*! \brief array node content in array */ -class ArrayNode : public Node { +class ArrayNode : public Object { public: /*! \brief the data content */ std::vector data; @@ -44,11 +44,11 @@ class ArrayNode : public Node { } static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); }; /*! \brief map node content */ -class MapNode : public Node { +class MapNode : public Object { public: void VisitAttrs(AttrVisitor* visitor) { } @@ -63,12 +63,12 @@ class MapNode : public Node { ContainerType data; static constexpr const char* _type_key = "Map"; - TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); }; /*! \brief specialized map node with string as key */ -class StrMapNode : public Node { +class StrMapNode : public Object { public: /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map; @@ -80,7 +80,7 @@ class StrMapNode : public Node { ContainerType data; static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object); }; /*! @@ -138,13 +138,13 @@ class IterAdapter { */ template::value>::type > -class Array : public NodeRef { +class Array : public ObjectRef { public: /*! * \brief default constructor */ Array() { - data_ = make_node(); + data_ = make_object(); } /*! * \brief move constructor @@ -164,7 +164,7 @@ class Array : public NodeRef { * \brief constructor from pointer * \param n the container pointer */ - explicit Array(ObjectPtr n) : NodeRef(n) {} + explicit Array(ObjectPtr n) : ObjectRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -195,7 +195,7 @@ class Array : public NodeRef { * \param val The init value */ explicit Array(size_t n, const T& val) { - auto tmp_node = make_node(); + auto tmp_node = make_object(); for (size_t i = 0; i < n; ++i) { tmp_node->data.push_back(val); } @@ -227,7 +227,7 @@ class Array : public NodeRef { */ template void assign(IterType begin, IterType end) { - auto n = make_node(); + auto n = make_object(); for (IterType it = begin; it != end; ++it) { n->data.push_back(T(*it)); } @@ -257,7 +257,7 @@ class Array : public NodeRef { */ inline ArrayNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } @@ -333,13 +333,13 @@ template::value || std::is_base_of::value >::type, typename = typename std::enable_if::value>::type> -class Map : public NodeRef { +class Map : public ObjectRef { public: /*! * \brief default constructor */ Map() { - data_ = make_node(); + data_ = make_object(); } /*! * \brief move constructor @@ -352,13 +352,13 @@ class Map : public NodeRef { * \brief copy constructor * \param other source */ - Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) + Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Map(ObjectPtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : ObjectRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -410,7 +410,7 @@ class Map : public NodeRef { */ template void assign(IterType begin, IterType end) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); for (IterType i = begin; i != end; ++i) { n->data.emplace(std::make_pair(i->first, i->second)); } @@ -454,7 +454,7 @@ class Map : public NodeRef { */ inline MapNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } @@ -507,18 +507,18 @@ class Map : public NodeRef { // specialize of string map template -class Map : public NodeRef { +class Map : public ObjectRef { public: // for code reuse Map() { - data_ = make_node(); + data_ = make_object(); } Map(Map && other) { // NOLINT(*) data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) + Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) } - explicit Map(ObjectPtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : ObjectRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -541,7 +541,7 @@ class Map : public NodeRef { } template void assign(IterType begin, IterType end) { - auto n = make_node(); + auto n = make_object(); for (IterType i = begin; i != end; ++i) { n->data.emplace(std::make_pair(i->first, i->second)); } @@ -565,7 +565,7 @@ class Map : public NodeRef { } inline StrMapNode* CopyOnWrite() { if (data_.get() == nullptr || !data_.unique()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); } diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 4014c3700596..bb5da415c463 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -56,105 +56,5 @@ using runtime::ObjectHash; using runtime::ObjectEqual; using runtime::make_object; -using NodeHash = ObjectHash; -using NodeEqual = ObjectEqual; -using Node = Object; - -/*! - * \brief Base class of all references to AST/IR nodes. - */ -class NodeRef : public ObjectRef { - public: - NodeRef() {} - explicit NodeRef(ObjectPtr n) : ObjectRef(n) {} -}; - -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - * \note This function is an alias of make_object. - */ -template -inline NodePtr make_node(Args&&... args) { - return runtime::make_object(std::forward(args)...); -} - -/*! - * \brief helper macro to declare type information in a base node. - */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) - -/*! - * \brief helper macro to declare type information in a terminal node - */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); - - -/*! - * \brief Macro to define common node ref methods. - * \param TypeName The name of the NodeRef. - * \param BaseTypeName The Base type. - * \param NodeName The node container type. - */ -#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ - TypeName() {} \ - explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ - : BaseTypeName(n) {} \ - const NodeName* operator->() const { \ - return static_cast(data_.get()); \ - } \ - operator bool() const { return this->defined(); } \ - using ContainerType = NodeName; - -/*! - * \brief Macro to define CopyOnWrite function in a NodeRef. - * \param NodeName The Type of the Node. - * - * CopyOnWrite will generate a unique copy of the internal node. - * The node will be copied if it is referenced by multiple places. - * The function returns the raw pointer to the node to allow modification - * of the content. - * - * \code - * - * MyCOWNodeRef ref, ref2; - * ref2 = ref; - * ref.CopyOnWrite()->value = new_value; - * assert(ref2->value == old_value); - * assert(ref->value == new_value); - * - * \endcode - */ -#define TVM_DEFINE_NODE_REF_COW(NodeName) \ - NodeName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - NodePtr n = make_node(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } - -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ - }; \ - -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_NODE_REF_COW(NodeName); \ - }; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 34f584b63261..681d06897355 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -60,7 +60,7 @@ class OperationNode : public ir::FunctionBaseNode { /*! \brief optional tag of the operation */ std::string tag; /*! \brief additional attributes of the operation*/ - Map attrs; + Map attrs; /*! \return name of the operation */ const std::string& func_name() const final { return name; @@ -149,7 +149,7 @@ class OperationNode : public ir::FunctionBaseNode { static constexpr const char* _type_key = "Operation"; - TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); }; /*! @@ -200,7 +200,7 @@ class PlaceholderOpNode : public OperationNode { DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! @@ -228,7 +228,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; - TVM_DECLARE_BASE_NODE_INFO(BaseComputeOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; @@ -269,12 +269,12 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { } static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; - TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, BaseComputeOpNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; /*! @@ -334,7 +334,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; - TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); }; /*! @@ -407,7 +407,7 @@ class ScanOpNode : public OperationNode { } static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, IterVar axis, Array init, Array update, @@ -415,7 +415,7 @@ class ScanOpNode : public OperationNode { Array input); static constexpr const char* _type_key = "ScanOp"; - TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; /*! @@ -472,14 +472,14 @@ class ExternOpNode : public OperationNode { } TVM_DLL static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body); static constexpr const char* _type_key = "ExternOp"; - TVM_DECLARE_NODE_TYPE_INFO(ExternOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; /*! @@ -540,13 +540,13 @@ class HybridOpNode : public OperationNode { } TVM_DLL static Operation make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; - TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; /*! \brief The compute function to specify the input source of a Tensor */ @@ -578,7 +578,7 @@ TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", std::string tag = "", - Map attrs = {}); + Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -593,7 +593,7 @@ TVM_DLL Array compute(Array shape, FBatchCompute fcompute, std::string name = "tensor", std::string tag = "", - Map attrs = {}); + Map attrs = {}); /*! * \brief Construct new tensors by scan. @@ -613,14 +613,14 @@ TVM_DLL Array scan(Array init, Array inputs = Array(), std::string name = "scan", std::string tag = "", - Map attrs = {}); + Map attrs = {}); // same as compute, specialized for different fcompute function inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } @@ -628,7 +628,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } @@ -636,7 +636,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; return compute(shape, fc, name, tag, attrs); } @@ -644,7 +644,7 @@ inline Tensor compute(Array shape, std::function f, std::string name = "tensor", std::string tag = "", - Map attrs = {}) { + Map attrs = {}) { FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index c9f7a580621f..b301a18ea313 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -115,15 +115,15 @@ inline TVMPODValue_::operator tvm::Expr() const { Object* ptr = static_cast(value_.v_handle); if (ptr->IsInstance()) { - return IterVar(ObjectPtr(ptr))->var; + return IterVar(ObjectPtr(ptr))->var; } if (ptr->IsInstance()) { - return Tensor(ObjectPtr(ptr))(); + return Tensor(ObjectPtr(ptr))(); } CHECK(ObjectTypeChecker::Check(ptr)) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); - return Expr(ObjectPtr(ptr)); + return Expr(ObjectPtr(ptr)); } inline TVMPODValue_::operator tvm::Integer() const { @@ -138,7 +138,7 @@ inline TVMPODValue_::operator tvm::Integer() const { CHECK(ObjectTypeChecker::Check(ptr)) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); + return Integer(ObjectPtr(ptr)); } } // namespace runtime } // namespace tvm diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index a74353239a00..dac39e014cc7 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -38,7 +38,7 @@ namespace relay { class PatternNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Pattern"; - TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); }; /*! @@ -49,10 +49,10 @@ class PatternNode : public RelayNode { * * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. */ -class Pattern : public NodeRef { +class Pattern : public ObjectRef { public: Pattern() {} - explicit Pattern(ObjectPtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : ObjectRef(p) {} using ContainerType = PatternNode; }; @@ -71,10 +71,13 @@ class PatternWildcardNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternWildcard"; - TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern); +class PatternWildcard : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode); +}; /*! \brief A var pattern. Accept all input and bind to a var. */ class PatternVar; @@ -94,10 +97,13 @@ class PatternVarNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternVar"; - TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern); +class PatternVar : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); +}; /*! * \brief ADT constructor. @@ -132,10 +138,13 @@ class ConstructorNode : public ExprNode { } static constexpr const char* _type_key = "relay.Constructor"; - TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr); +class Constructor : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode); +}; /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ class PatternConstructor; @@ -158,10 +167,13 @@ class PatternConstructorNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternConstructor"; - TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); +class PatternConstructor : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode); +}; /*! \brief A tuple pattern. Matches a tuple, binds recursively. */ class PatternTuple; @@ -181,10 +193,13 @@ class PatternTupleNode : public PatternNode { } static constexpr const char* _type_key = "relay.PatternTuple"; - TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); }; -RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern); +class PatternTuple : public Pattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); +}; /*! * \brief Stores all data for an Algebraic Data Type (ADT). @@ -225,15 +240,18 @@ class TypeDataNode : public TypeNode { tvm::Array constructors); static constexpr const char* _type_key = "relay.TypeData"; - TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type); +class TypeData : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); +}; /*! \brief A clause in a match expression. */ class Clause; /*! \brief Clause container node. */ -class ClauseNode : public Node { +class ClauseNode : public Object { public: /*! \brief The pattern the clause matches. */ Pattern lhs; @@ -248,10 +266,13 @@ class ClauseNode : public Node { TVM_DLL static Clause make(Pattern lhs, Expr rhs); static constexpr const char* _type_key = "relay.Clause"; - TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); }; -RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef); +class Clause : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); +}; /*! \brief ADT pattern matching exression. */ class Match; @@ -280,10 +301,13 @@ class MatchNode : public ExprNode { TVM_DLL static Match make(Expr data, tvm::Array pattern, bool complete = true); static constexpr const char* _type_key = "relay.Match"; - TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr); +class Match : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode); +}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ccdc871e8a78..1c7fc1c45480 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -196,7 +196,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - NodeRef indices_or_sections; + ObjectRef indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 32f9c32f468a..d64d05f119bb 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -53,53 +53,11 @@ namespace relay { (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ } -/*! - * \brief We always used NodeRef for referencing nodes. - * - * By default, NodeRef is a std::shared_ptr of node - */ -using NodeRef = tvm::NodeRef; - -/*! - * \brief Content data type. - */ -using DataType = ::tvm::DataType; - /*! * \brief Symbolic expression for tensor shape. */ using IndexExpr = ::tvm::Expr; -/*! - * \brief Hash function for nodes. - * e.g. std::unordered_map - */ -using NodeHash = ::tvm::NodeHash; -/*! - * \brief Equality check function for nodes. - */ -using NodeEqual = ::tvm::NodeEqual; - -/*! - * \brief Macro to make it easy to define node ref type given node - * \param TypeName The name of the reference type. - * \param NodeName The internal container name. - * \param NodeRefBase The base type. - */ -#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ - class TypeName : public NodeRefBase { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ - : NodeRefBase(n) { \ - } \ - const NodeName* operator->() const { \ - return static_cast(get()); \ - } \ - operator bool() { return this->defined(); } \ - using ContainerType = NodeName; \ - }; - /*! * \brief The source name in the Span * \sa SourceNameNode, Span @@ -108,7 +66,7 @@ class SourceName; /*! * \brief The name of a source fragment. */ -class SourceNameNode : public Node { +class SourceNameNode : public Object { public: /*! \brief The source name. */ std::string name; @@ -116,20 +74,20 @@ class SourceNameNode : public Node { void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "relay.SourceName"; - TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; /*! * \brief The source name of a file span. * \sa SourceNameNode, Span */ -class SourceName : public NodeRef { +class SourceName : public ObjectRef { public: /*! \brief default constructor */ SourceName() {} /*! \brief constructor from node pointer */ - explicit SourceName(NodePtr n) : NodeRef(n) {} + explicit SourceName(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -157,7 +115,7 @@ class Span; /*! * \brief Stores locations in frontend source that generated a node. */ -class SpanNode : public Node { +class SpanNode : public Object { public: /*! \brief The source name */ SourceName source; @@ -175,22 +133,25 @@ class SpanNode : public Node { TVM_DLL static Span make(SourceName source, int lineno, int col_offset); static constexpr const char* _type_key = "relay.Span"; - TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; -RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef); +class Span : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); +}; /*! * \brief This is the base node container of all relay structures. */ -class RelayNode : public Node { +class RelayNode : public Object { public: /*! \brief The location of the program in a SourceFragment can be null, * check with span.defined() */ mutable Span span; static constexpr const char* _type_key = "relay.Node"; - TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(RelayNode, Object); }; /*! @@ -201,7 +162,7 @@ class RelayNode : public Node { * * \note Do not create Id directly, they are created in Var. */ -class IdNode : public Node { +class IdNode : public Object { public: /*! * \brief The name of the variable, @@ -215,10 +176,13 @@ class IdNode : public Node { } static constexpr const char* _type_key = "relay.Id"; - TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); }; -RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef); +class Id : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); +}; struct Module; diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index ef3387b1893b..4cd999fb4480 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -118,7 +118,7 @@ class ErrorReporter { * \param node The expression or type to report the error at. * \param err The error message to report. */ - inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { + inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) { std::string err_msg = err.str(); this->ReportAt(global, node, Error(err_msg)); } @@ -134,7 +134,7 @@ class ErrorReporter { * \param node The expression or type to report the error at. * \param err The error to report. */ - void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); + void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err); /*! \brief Render all reported errors and exit the program. * @@ -154,8 +154,8 @@ class ErrorReporter { private: std::vector errors_; - std::unordered_map, NodeHash, NodeEqual> node_to_error_; - std::unordered_map node_to_gv_; + std::unordered_map, ObjectHash, ObjectEqual> node_to_error_; + std::unordered_map node_to_gv_; }; } // namespace relay diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 01a73d5396cc..47c83696c3e5 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -67,10 +67,13 @@ class ExprNode : public RelayNode { inline const TTypeNode* type_as() const; static constexpr const char* _type_key = "relay.Expr"; - TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef); +class Expr : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode); +}; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. @@ -104,10 +107,13 @@ class ConstantNode : public ExprNode { TVM_DLL static Constant make(runtime::NDArray data); static constexpr const char* _type_key = "relay.Constant"; - TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr); +class Constant : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode); +}; /*! \brief Tuple of multiple Exprs */ class Tuple; @@ -126,10 +132,13 @@ class TupleNode : public ExprNode { TVM_DLL static Tuple make(tvm::Array fields); static constexpr const char* _type_key = "relay.Tuple"; - TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); +class Tuple : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); +}; /*! * \brief Local variables used in the let expression. @@ -179,10 +188,13 @@ class VarNode : public ExprNode { Type type_annotation); static constexpr const char* _type_key = "relay.Var"; - TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Var, VarNode, Expr); +class Var : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); +}; /*! * \brief Global variable that leaves in the top-level module. @@ -206,10 +218,13 @@ class GlobalVarNode : public ExprNode { TVM_DLL static GlobalVar make(std::string name_hint); static constexpr const char* _type_key = "relay.GlobalVar"; - TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr); +class GlobalVar : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode); +}; /*! * \brief Function (subgraph in computational graph) @@ -297,14 +312,19 @@ class FunctionNode : public ExprNode { tvm::Map GetParams() const; static constexpr const char* _type_key = "relay.Function"; - TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); +class Function : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); +}; -TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); -TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); +TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key); +TVM_DLL Function FunctionSetAttr(const Function& func, + const std::string& key, + const ObjectRef& data); /*! * \brief Call corresponds to operator invocation. @@ -363,10 +383,13 @@ class CallNode : public ExprNode { Array type_args = Array()); static constexpr const char* _type_key = "relay.Call"; - TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Call, CallNode, Expr); +class Call : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); +}; /*! * \brief Let binding that binds a local var and optionally a type annotation. @@ -401,10 +424,13 @@ class LetNode : public ExprNode { TVM_DLL static Let make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "relay.Let"; - TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); +class Let : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode); +}; /*! * \brief Condition expression @@ -439,10 +465,13 @@ class IfNode : public ExprNode { TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch); static constexpr const char* _type_key = "relay.If"; - TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(If, IfNode, Expr); +class If : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); +}; /*! \brief Get index-th field out of a tuple. */ class TupleGetItem; @@ -463,10 +492,13 @@ class TupleGetItemNode : public ExprNode { TVM_DLL static TupleGetItem make(Expr tuple, int index); static constexpr const char* _type_key = "relay.TupleGetItem"; - TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +class TupleGetItem : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); +}; /*! \brief Create a new Reference out of initial value. */ class RefCreate; @@ -484,10 +516,13 @@ class RefCreateNode : public ExprNode { TVM_DLL static RefCreate make(Expr value); static constexpr const char* _type_key = "relay.RefCreate"; - TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr); +class RefCreate : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode); +}; /*! \brief Get value out of Reference. */ class RefRead; @@ -505,10 +540,13 @@ class RefReadNode : public ExprNode { TVM_DLL static RefRead make(Expr ref); static constexpr const char* _type_key = "relay.RefRead"; - TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); +class RefRead : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode); +}; /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; class RefWriteNode : public ExprNode { @@ -528,10 +566,13 @@ class RefWriteNode : public ExprNode { TVM_DLL static RefWrite make(Expr ref, Expr value); static constexpr const char* _type_key = "relay.RefWrite"; - TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefWriteNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); +class RefWrite : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode); +}; /*! * \brief Base class of the temporary expression. @@ -554,10 +595,13 @@ class TempExprNode : public ExprNode { virtual Expr Realize() const = 0; static constexpr const char* _type_key = "relay.TempExpr"; - TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); +class TempExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode); +}; // implementataions inline const Type& ExprNode::checked_type() const { @@ -583,7 +627,7 @@ inline const TTypeNode* ExprNode::type_as() const { } /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ -std::string PrettyPrint(const NodeRef& node); +std::string PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the Relay text format. @@ -593,7 +637,7 @@ std::string PrettyPrint(const NodeRef& node); * additional comment block to an expr. * \return The text representation. */ -std::string AsText(const NodeRef& node, +std::string AsText(const ObjectRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 722f73f03826..f1d7152f48c0 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -116,7 +116,7 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Node* op, Args...) { + virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } @@ -177,7 +177,7 @@ class ExprVisitor protected: // Internal visiting counter - std::unordered_map visit_counter_; + std::unordered_map visit_counter_; }; /*! @@ -227,7 +227,7 @@ class ExprMutator protected: /*! \brief Internal map used for memoization. */ - std::unordered_map memo_; + std::unordered_map memo_; }; /*! diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d5d783d4804a..8ef7f6e4ed89 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -72,13 +72,13 @@ CreateInterpreter(Module mod, DLContext context, Target target); class ValueNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Value"; - TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(ValueNode, RelayNode); }; -class Value : public NodeRef { +class Value : public ObjectRef { public: Value() {} - explicit Value(ObjectPtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : ObjectRef(n) {} const ValueNode* operator->() const { return static_cast(get()); } @@ -114,10 +114,13 @@ class ClosureNode : public ValueNode { TVM_DLL static Closure make(tvm::Map env, Function func); static constexpr const char* _type_key = "relay.Closure"; - TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ClosureNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); +class Closure : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Closure, Value, ClosureNode); +}; /*! \brief A Relay Recursive Closure. A closure that has a name. */ class RecClosure; @@ -140,10 +143,13 @@ class RecClosureNode : public ValueNode { TVM_DLL static RecClosure make(Closure clos, Var bind); static constexpr const char* _type_key = "relay.RecClosure"; - TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); +class RecClosure : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, Value, RecClosureNode); +}; /*! \brief A tuple value. */ class TupleValue; @@ -159,10 +165,13 @@ struct TupleValueNode : ValueNode { TVM_DLL static TupleValue make(tvm::Array value); static constexpr const char* _type_key = "relay.TupleValue"; - TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); +class TupleValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleValue, Value, TupleValueNode); +}; /*! \brief A tensor value. */ class TensorValue; @@ -179,10 +188,13 @@ struct TensorValueNode : ValueNode { TVM_DLL static TensorValue make(runtime::NDArray data); static constexpr const char* _type_key = "relay.TensorValue"; - TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); +class TensorValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TensorValue, Value, TensorValueNode); +}; /*! \brief A reference value. */ class RefValue; @@ -199,10 +211,13 @@ struct RefValueNode : ValueNode { TVM_DLL static RefValue make(Value val); static constexpr const char* _type_key = "relay.RefValue"; - TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); +class RefValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefValue, Value, RefValueNode); +}; /*! \brief An ADT constructor value. */ class ConstructorValue; @@ -226,10 +241,13 @@ struct ConstructorValueNode : ValueNode { Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; - TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueNode, ValueNode); }; -RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value); +class ConstructorValue : public Value { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, Value, ConstructorValueNode); +}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 0d3f46cd3cc0..262c82df5c5d 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -258,7 +258,7 @@ class ModuleNode : public RelayNode { const tvm::Map& type_definitions = {}); static constexpr const char* _type_key = "relay.Module"; - TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); private: /*! \brief Helper function for registering a typedef's constructors */ @@ -285,9 +285,9 @@ class ModuleNode : public RelayNode { std::unordered_set import_set_; }; -struct Module : public NodeRef { +struct Module : public ObjectRef { Module() {} - explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {} ModuleNode* operator->() const { return static_cast(get_mutable()); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 90f2937c929b..b4495191dd24 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -106,7 +106,7 @@ class OpNode : public relay::ExprNode { } static constexpr const char* _type_key = "relay.Op"; - TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, ExprNode); private: // friend class @@ -431,7 +431,7 @@ inline OpRegistry& OpRegistry::describe( inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, const std::string& description) { - auto n = make_node(); + auto n = make_object(); n->name = name; n->type_info = type; n->description = description; diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 54ea707905e5..9cfa755ef813 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -180,7 +180,7 @@ using FTVMLegalize = runtime::TypedPackedFunc< using FForwardRewrite = runtime::TypedPackedFunc< Expr(const Call& ref_call, const Array& new_args, - const NodeRef& ctx)>; + const ObjectRef& ctx)>; /*! * \brief Gradient for a specific op. diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index d84d43af82a7..71a024f37a19 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -102,7 +102,7 @@ class PatternFunctor { Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPatternDefault_(const Node* op, Args...) { + virtual R VisitPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } @@ -162,7 +162,7 @@ class PatternMutator /*! \brief Used to visit the vars inside of patterns. */ virtual Constructor VisitConstructor(const Constructor& c); private: - std::unordered_map var_map_; + std::unordered_map var_map_; }; } // namespace relay diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 52be6a0f3781..2d1e45f8ee0f 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -109,7 +109,7 @@ class PassContextNode : public RelayNode { } static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode); }; /*! @@ -125,10 +125,10 @@ class PassContextNode : public RelayNode { * * \endcode */ -class PassContext : public NodeRef { +class PassContext : public ObjectRef { public: PassContext() {} - explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} /*! * \brief const accessor. * \return const access pointer. @@ -207,10 +207,13 @@ class PassInfoNode : public RelayNode { tvm::Array required); static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode); }; -TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) +class PassInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); +}; class Pass; @@ -251,10 +254,10 @@ class PassNode : public RelayNode { void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode); }; -class Pass : public NodeRef { +class Pass : public ObjectRef { public: /*! * \brief Transform mod using the default PassContext in the current scope. @@ -283,7 +286,7 @@ class Pass : public NodeRef { return node->operator()(mod, pass_ctx); } - TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode); + TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); }; class SequentialNode; @@ -309,7 +312,7 @@ class Sequential : public Pass { TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); Sequential() = default; - explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} + explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = Sequential; @@ -638,7 +641,7 @@ TVM_DLL Function InferType(const Function& f, */ TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, + std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); /*! @@ -655,7 +658,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, */ TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, + std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); /*! diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e0c056c1216b..08fe957d8a78 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -41,7 +41,7 @@ using Any = tvm::ir::Any; class TypeNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Type"; - TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; /*! @@ -55,10 +55,10 @@ class TypeNode : public RelayNode { * There are also advanced types to support generic(polymorphic types), * which can be ignored when first reading the code base. */ -class Type : public NodeRef { +class Type : public ObjectRef { public: Type() {} - explicit Type(ObjectPtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : ObjectRef(p) {} using ContainerType = TypeNode; }; @@ -70,10 +70,13 @@ class Type : public NodeRef { class BaseTensorTypeNode : public TypeNode { public: static constexpr const char* _type_key = "relay.BaseTensorType"; - TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode); + TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type); +class BaseTensorType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode); +}; /*! * \brief This is the most commonly used type in relay. @@ -113,10 +116,13 @@ class TensorTypeNode : public BaseTensorTypeNode { TVM_DLL static TensorType Scalar(DataType dtype); static constexpr const char* _type_key = "relay.TensorType"; - TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode); }; -RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); +class TensorType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); +}; /*! \brief Possible kinds of Type. */ enum Kind : int { @@ -168,10 +174,13 @@ class TypeVarNode : public TypeNode { TVM_DLL static TypeVar make(std::string name, Kind kind); static constexpr const char* _type_key = "relay.TypeVar"; - TVM_DECLARE_NODE_TYPE_INFO(TypeVarNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); +class TypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); +}; /*! * \brief A global type variable that is used for defining new types or type aliases. @@ -197,10 +206,13 @@ class GlobalTypeVarNode : public TypeNode { TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); static constexpr const char* _type_key = "relay.GlobalTypeVar"; - TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type); +class GlobalTypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); +}; /*! * \brief Type application. @@ -225,10 +237,13 @@ class TypeCallNode : public TypeNode { TVM_DLL static TypeCall make(Type func, tvm::Array args); static constexpr const char* _type_key = "relay.TypeCall"; - TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type); +class TypeCall : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeCall, Type, TypeCallNode); +}; /*! * \brief IncompleteType. @@ -253,10 +268,13 @@ class IncompleteTypeNode : public TypeNode { TVM_DLL static IncompleteType make(Kind kind); static constexpr const char* _type_key = "relay.IncompleteType"; - TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type); +class IncompleteType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); +}; /*! * \brief Potential Constraints in the type. @@ -267,10 +285,13 @@ class TypeConstraint; class TypeConstraintNode : public TypeNode { public: static constexpr const char* _type_key = "relay.TypeConstraint"; - TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode); + TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type); +class TypeConstraint : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); +}; class FuncType; /*! @@ -311,10 +332,13 @@ class FuncTypeNode : public TypeNode { tvm::Array type_constraints); static constexpr const char* _type_key = "relay.FuncType"; - TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type); +class FuncType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); +}; /*! * \brief The type of tuple values. @@ -338,10 +362,13 @@ class TupleTypeNode : public TypeNode { TVM_DLL static TupleType make(tvm::Array fields); static constexpr const char* _type_key = "relay.TupleType"; - TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); +class TupleType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TupleType, Type, TupleTypeNode); +}; /*! * \brief The type of reference values. @@ -365,10 +392,13 @@ class RefTypeNode : public TypeNode { TVM_DLL static RefType make(Type value); static constexpr const char* _type_key = "relay.RefType"; - TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode); }; -RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type); +class RefType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode); +}; class TypeReporter; @@ -376,7 +406,7 @@ class TypeReporter; * \brief reporter that reports back to the * type resolution information. */ -class TypeReporterNode : public Node { +class TypeReporterNode : public Object { public: /*! * \brief Create a type equality constraint. @@ -408,7 +438,7 @@ class TypeReporterNode : public Node { * \brief Set the location at which to report unification errors. * \param ref The program node to report the error. */ - TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0; + TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0; /*! * \brief Retrieve the current global module. @@ -420,17 +450,17 @@ class TypeReporterNode : public Node { void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.TypeReporter"; - TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); }; /*! * \brief Container class of TypeReporter. * \sa TypeReporterNode */ -class TypeReporter : public NodeRef { +class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { } TypeReporterNode* operator->() const { return const_cast( @@ -502,10 +532,13 @@ class TypeRelationNode : public TypeConstraintNode { Attrs attrs); static constexpr const char* _type_key = "relay.TypeRelation"; - TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode); + TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); }; -RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint); +class TypeRelation : public TypeConstraint { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); +}; // The following fields contains advanced typing // Only keep the class name and reserved for future usage. diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 96215daf4a7a..7d1494707af8 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -700,7 +700,12 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() - +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() {} \ explicit TypeName( \ @@ -712,17 +717,54 @@ struct ObjectEqual { operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; -#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ +/* + * \brief Define object reference methods of whose content is mutable. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + * \note We recommend making objects immutable when possible. + * This macro is only reserved for objects that stores runtime states. + */ +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ TypeName() {} \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ - ObjectName* operator->() { \ + ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; +/*! + * \brief Define CopyOnWrite function in an ObjectRef. + * \param ObjectName The Type of the Node. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. + * + * \code + * + * MyCOWObjectRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + ObjectName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } + // Implementations details below // Object reference counting. #if TVM_OBJECT_ATOMIC_REF_COUNTER @@ -832,10 +874,6 @@ inline SubRef Downcast(BaseRef ref) { } } // namespace runtime - -template -using NodePtr = runtime::ObjectPtr; - } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 3f4ee38a7695..01caf5a02c91 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -53,10 +53,10 @@ enum AttachType : int { }; /*! \brief Stage, contains scheduling for a stage of computation. */ -class Stage : public NodeRef { +class Stage : public ObjectRef { public: Stage() {} - explicit Stage(ObjectPtr n) : NodeRef(n) {} + explicit Stage(ObjectPtr n) : ObjectRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -277,10 +277,10 @@ class Stage : public NodeRef { * For operations and all the operations they depend on. * The schedule per Operation is named as stage. */ -class Schedule : public NodeRef { +class Schedule : public ObjectRef { public: Schedule() {} - explicit Schedule(ObjectPtr n) : NodeRef(n) {} + explicit Schedule(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -400,10 +400,10 @@ class Schedule : public NodeRef { * \brief The schedule relation between IterVars * can be Split, Fuse. */ -class IterVarRelation : public NodeRef { +class IterVarRelation : public ObjectRef { public: IterVarRelation() {} - explicit IterVarRelation(ObjectPtr n) : NodeRef(n) {} + explicit IterVarRelation(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -414,10 +414,10 @@ class IterVarRelation : public NodeRef { /*! * \brief Additional scheduable attributes about IterVar. */ -class IterVarAttr : public NodeRef { +class IterVarAttr : public ObjectRef { public: IterVarAttr() {} - explicit IterVarAttr(ObjectPtr n) : NodeRef(n) {} + explicit IterVarAttr(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -440,7 +440,7 @@ class IterVarAttr : public NodeRef { * * The group stage node can be attached to IterVars as in normal stage. */ -class StageNode : public Node { +class StageNode : public Object { public: /*! * \brief The operation of stage, can be different from original op. @@ -515,11 +515,11 @@ class StageNode : public Node { } static constexpr const char* _type_key = "Stage"; - TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; /*! \brief node container for schedule */ -class ScheduleNode : public Node { +class ScheduleNode : public Object { public: /*! \brief The output operations in original data flow graph */ Array outputs; @@ -538,7 +538,7 @@ class ScheduleNode : public Node { * \brief Internal stage map to map internal ops to stages. * This is created on demand and can be invalidated. */ - std::unordered_map op2stage_cache_; + std::unordered_map op2stage_cache_; void VisitAttrs(AttrVisitor* v) { v->Visit("outputs", &outputs); @@ -576,7 +576,7 @@ class ScheduleNode : public Node { TVM_DLL static Schedule make(Array ops); static constexpr const char* _type_key = "Schedule"; - TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object); }; /*! @@ -589,7 +589,7 @@ inline Schedule create_schedule(Array ops) { } /*! \brief node container for IterVar attr */ -class IterVarAttrNode : public Node { +class IterVarAttrNode : public Object { public: /*! \brief The iteration type. */ IterVarType iter_type{kDataPar}; @@ -630,14 +630,14 @@ class IterVarAttrNode : public Node { } static constexpr const char* _type_key = "IterVarAttr"; - TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object); }; /*! \brief base node of iteration var */ -class IterVarRelationNode : public Node { +class IterVarRelationNode : public Object { public: static constexpr const char* _type_key = "IterVarRelation"; - TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node); + TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object); }; /*! @@ -672,7 +672,7 @@ class SplitNode : public IterVarRelationNode { Expr nparts); static constexpr const char* _type_key = "Split"; - TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); }; /*! @@ -697,7 +697,7 @@ class FuseNode : public IterVarRelationNode { IterVar outer, IterVar inner, IterVar fused); static constexpr const char* _type_key = "Fuse"; - TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); }; /*! @@ -720,7 +720,7 @@ class RebaseNode : public IterVarRelationNode { static IterVarRelation make(IterVar parent, IterVar rebased); static constexpr const char* _type_key = "Rebase"; - TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; @@ -739,7 +739,7 @@ class SingletonNode : public IterVarRelationNode { static IterVarRelation make(IterVar iter); static constexpr const char* _type_key = "Singleton"; - TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); }; diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h index 86cb0e275609..25fb7243eaf2 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target_info.h @@ -34,7 +34,7 @@ namespace tvm { * \brief Memory information of special memory region. * Use MemoryInfo as its container type */ -struct MemoryInfoNode : public Node { +struct MemoryInfoNode : public Object { /*! \brief The addressable unit */ int unit_bits; /*! \brief Maximum number of bits supported in the memory */ @@ -55,11 +55,14 @@ struct MemoryInfoNode : public Node { } static constexpr const char* _type_key = "MemoryInfo"; - TVM_DECLARE_NODE_TYPE_INFO(MemoryInfoNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); }; /*! \brief Defines memory info */ -TVM_DEFINE_NODE_REF(MemoryInfo, MemoryInfoNode); +class MemoryInfo : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(MemoryInfo, ObjectRef, MemoryInfoNode); +}; /*! * \brief get memory info given scope diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f44498a0aa7a..d6e93f567e50 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -46,11 +46,11 @@ class OperationNode; * \brief Tensor structure representing a possible input, * or intermediate computation result. */ -class Tensor : public NodeRef { +class Tensor : public ObjectRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(ObjectPtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -158,7 +158,7 @@ class Operation : public ir::FunctionRef { }; /*! \brief Node to represent a tensor */ -class TensorNode : public Node { +class TensorNode : public Object { public: /*! \brief The shape of the tensor */ Array shape; @@ -183,7 +183,7 @@ class TensorNode : public Node { int value_index); static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); }; @@ -250,13 +250,13 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> : public ::tvm::NodeHash { +struct hash<::tvm::Operation> : public ::tvm::ObjectHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { - ::tvm::NodeHash hasher; + ::tvm::ObjectHash hasher; if (k.defined() && k->op.defined()) { return hasher(k->op); } else{ diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 0d4795ad5440..f973909ae398 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -34,10 +34,10 @@ namespace tvm { class TensorIntrinNode; /*! \brief Tensor intrinsic node. */ -class TensorIntrin : public NodeRef { +class TensorIntrin : public ObjectRef { public: TensorIntrin() {} - explicit TensorIntrin(NodePtr n) : NodeRef(n) {} + explicit TensorIntrin(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -49,7 +49,7 @@ class TensorIntrin : public NodeRef { }; /*! \brief Node to represent a Tensor intrinsic operator */ -class TensorIntrinNode : public Node { +class TensorIntrinNode : public Object { public: /*! \brief The name of the intrinsic */ std::string name; @@ -108,7 +108,7 @@ class TensorIntrinNode : public Node { Stmt reduce_update); static constexpr const char* _type_key = "TensorIntrin"; - TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; inline const TensorIntrinNode* TensorIntrin::operator->() const { @@ -119,10 +119,10 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const { class TensorIntrinCallNode; /*! \brief Tensor intrinsic calling node. */ -class TensorIntrinCall : public NodeRef { +class TensorIntrinCall : public ObjectRef { public: TensorIntrinCall() {} - explicit TensorIntrinCall(NodePtr n) : NodeRef(n) {} + explicit TensorIntrinCall(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -133,7 +133,7 @@ class TensorIntrinCall : public NodeRef { using ContainerType = TensorIntrinCallNode; }; -class TensorIntrinCallNode : public Node { +class TensorIntrinCallNode : public Object { public: /*! \brief the tensor intrinsic */ TensorIntrin intrin; @@ -166,7 +166,7 @@ class TensorIntrinCallNode : public Node { Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; - TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 6bda2f57c4bf..1911a0337ac2 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -306,7 +306,7 @@ void PostOrderDFSVisit(const std::vector& heads, template inline void DFSVisit(const std::vector& heads, FVisit fvisit) { - typedef const NodePtr* GNode; + typedef const ObjectPtr* GNode; std::vector head_nodes(heads.size()); std::transform(heads.begin(), heads.end(), head_nodes.begin(), [](const NodeEntry& e)->GNode { diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 1dc9d8337587..2155481373fd 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -40,22 +40,22 @@ class Node; class Symbol; /*! - * \brief we always used NodePtr for a reference pointer + * \brief we always used ObjectPtr for a reference pointer * to the node, so this alias can be changed in case. * - * By default, NodePtr is a std::shared_ptr of node + * By default, ObjectPtr is a std::shared_ptr of node */ -using NodePtr = std::shared_ptr; +using ObjectPtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { - NodeEntry(NodePtr node, uint32_t index, uint32_t version): + NodeEntry(ObjectPtr node, uint32_t index, uint32_t version): node(std::move(node)), index(index), version(version) {} - explicit NodeEntry(NodePtr node): + explicit NodeEntry(ObjectPtr node): node(std::move(node)), index(), version() @@ -72,7 +72,7 @@ struct NodeEntry { {} /*! \brief the source node of this data */ - NodePtr node; + ObjectPtr node; /*! \brief index of output from the source. */ uint32_t index; /*! @@ -167,7 +167,7 @@ class NNVM_DLL Node { * \brief Optional control flow dependencies * Gives operation must be performed before this operation. */ - std::vector control_deps; + std::vector control_deps; /*! \brief additional fields for this node */ any info; /*! \brief destructor of node */ @@ -189,7 +189,7 @@ class NNVM_DLL Node { * \return a created empty node. */ template - static NodePtr Create(Args&&... args) { + static ObjectPtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } }; @@ -208,7 +208,7 @@ inline NodeEntry MakeNode( std::vector inputs, std::unordered_map attrs = std::unordered_map()) { - NodePtr p = Node::Create(); + ObjectPtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); p->attrs.dict = attrs; diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 8c330c0c44cb..c2af989ba3e0 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -192,7 +192,7 @@ using FIgnoreInputs = std::function< * \note Register under "FGradient" */ using FGradient = std::function( - const NodePtr& nodeptr, + const ObjectPtr& nodeptr, const std::vector& out_grads)>; /*! @@ -204,7 +204,7 @@ using FGradient = std::function( */ using FSetInputVarAttrOnCompose = std::function; /*! diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index dda79d468173..d3555ec726b2 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -97,7 +97,7 @@ class NNVM_DLL Symbol { * \return The arguments list of this symbol, they can be either named or unnamed (empty string). * \sa ListInputOption */ - std::vector ListInputs(ListInputOption option) const; + std::vector ListInputs(ListInputOption option) const; /*! * \brief List the input names. * diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index ae819480eff8..7ca56035acae 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -259,7 +259,7 @@ int NNSymbolListInputVariables(SymbolHandle symbol, Symbol *s = static_cast(symbol); NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); + std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); ret->ret_handles.resize(0); ret->ret_handles.reserve(vs.size()); for (size_t i = 0; i < vs.size(); ++i) { diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 829924ea7d5c..8930e49ecc58 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,7 +50,7 @@ static void SubgraphSanityCheck(const std::vector> &subg next_level.clear(); for (const std::vector *graph_ptr : curr_level) { const std::vector &graph = *graph_ptr; - DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) { + DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { nnvm::Node *node = n.get(); // if the node is visited, but on a different level, then check failed // if check failed here or before, we stop doing anything, but raise an error @@ -74,7 +74,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { std::vector> subgraphs; DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] - (const NodePtr& n) { + (const ObjectPtr& n) { const auto& is_ghost = Op::GetAttr("TIsGhost"); if (!n->is_variable() && is_ghost.get(n->op(), false)) return; CHECK_LT(nodes_.size(), std::numeric_limits::max()); diff --git a/nnvm/src/core/node.cc b/nnvm/src/core/node.cc index 59e35243d8f8..32d5e7f913b3 100644 --- a/nnvm/src/core/node.cc +++ b/nnvm/src/core/node.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ Node::~Node() { // explicit deletion via DFS // this is used to avoid stackoverflow caused by chain of deletions std::vector stack{this}; - std::vector to_delete; + std::vector to_delete; while (!stack.empty()) { Node* n = stack.back(); stack.pop_back(); @@ -42,7 +42,7 @@ Node::~Node() { e.node.reset(); } } - for (NodePtr& sp : n->control_deps) { + for (ObjectPtr& sp : n->control_deps) { if (sp.unique()) { stack.push_back(sp.get()); to_delete.emplace_back(std::move(sp)); diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 884dae7372f8..86dc7e63c403 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -36,8 +36,8 @@ struct VariableParam { uint32_t version{0}; }; -NodePtr CreateVariableNode(const std::string& name) { - NodePtr n = Node::Create(); +ObjectPtr CreateVariableNode(const std::string& name) { + ObjectPtr n = Node::Create(); n->attrs.op = nullptr; n->attrs.name = name; n->attrs.parsed = VariableParam(); @@ -114,10 +114,10 @@ inline bool IsAtomic(const std::vector& outputs) { // public functions Symbol Symbol::Copy() const { - std::unordered_map old_new; + std::unordered_map old_new; // use DFSVisit to copy all the nodes - DFSVisit(this->outputs, [&old_new](const NodePtr& node) { - NodePtr np = Node::Create(); + DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) { + ObjectPtr np = Node::Create(); np->attrs = node->attrs; old_new[node.get()] = std::move(np); }); @@ -127,7 +127,7 @@ Symbol Symbol::Copy() const { Node *ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } - for (const NodePtr& p : kv.first->control_deps) { + for (const ObjectPtr& p : kv.first->control_deps) { kv.second->control_deps.emplace_back(old_new[p.get()]); } } @@ -155,7 +155,7 @@ void Symbol::Print(std::ostream &os) const { os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index << ")\n"; } - DFSVisit(this->outputs, [&os](const NodePtr& node) { + DFSVisit(this->outputs, [&os](const ObjectPtr& node) { if (node->is_variable()) { os << "Variable:" << node->attrs.name << '\n'; } else { @@ -204,21 +204,21 @@ Symbol Symbol::operator[] (size_t index) const { } } -std::vector Symbol::ListInputs(ListInputOption option) const { - std::vector ret; +std::vector Symbol::ListInputs(ListInputOption option) const { + std::vector ret; if (option == kAll) { ret.reserve(this->outputs.size()); - DFSVisit(this->outputs, [&ret](const NodePtr &node) { + DFSVisit(this->outputs, [&ret](const ObjectPtr &node) { if (node->is_variable()) { ret.push_back(node); } }); } else { std::unordered_set mutable_set; - std::vector vlist; + std::vector vlist; vlist.reserve(this->outputs.size()); static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - DFSVisit(this->outputs, [&mutable_set, &vlist](const NodePtr &node) { + DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) { if (node->is_variable()) { vlist.push_back(node); } else if (fmutate_inputs.count(node->op())) { @@ -228,7 +228,7 @@ std::vector Symbol::ListInputs(ListInputOption option) const { } }); ret.reserve(vlist.size()); - for (const NodePtr& node : vlist) { + for (const ObjectPtr& node : vlist) { if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || (option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) { ret.emplace_back(node); @@ -239,7 +239,7 @@ std::vector Symbol::ListInputs(ListInputOption option) const { } std::vector Symbol::ListInputNames(ListInputOption option) const { - std::vector inputs = ListInputs(option); + std::vector inputs = ListInputs(option); std::vector ret(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { ret[i] = inputs[i]->attrs.name; @@ -416,7 +416,7 @@ void Symbol::Compose(const array_view& args, std::unordered_map replace_map; // replace map stores the existing replacement plan for arguments node auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] - (const NodePtr &node) { + (const ObjectPtr &node) { if (node->is_variable()) { if (arg_counter < args.size()) { replace_map[node.get()] = &(args[arg_counter]->outputs[0]); @@ -437,7 +437,7 @@ void Symbol::Compose(const array_view& args, std::vector update_nodes; std::vector > replace_plan; auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] - (const NodePtr &node) { + (const ObjectPtr &node) { // visit all the childs, find possible replacement bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { @@ -499,7 +499,7 @@ void Symbol::AddControlDeps(const Symbol& src) { Symbol Symbol::GetInternals() const { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; - DFSVisit(this->outputs, [&ret](const NodePtr& node) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { Node* n = node.get(); if (n->is_variable()) { // grab version from variable. @@ -582,7 +582,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const { std::unordered_map Symbol::ListAttrs(ListAttrOption option) const { if (option == kRecursive) { std::unordered_map ret; - DFSVisit(this->outputs, [&ret](const NodePtr& n) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { for (const auto& it : n->attrs.dict) { ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; } @@ -596,7 +596,7 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op std::vector > Symbol::ListAttrsRecursive() const { std::vector > ret; - DFSVisit(this->outputs, [&ret](const NodePtr& n) { + DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { for (const auto& it : n->attrs.dict) { ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); } @@ -608,7 +608,7 @@ Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; - NodePtr n = Node::Create(); + ObjectPtr n = Node::Create(); n->attrs.op = op; n->attrs.dict = std::move(attrs); if (n->op()->attr_parser != nullptr) { @@ -628,7 +628,7 @@ Symbol Symbol::CreateFunctor(const Op* op, Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; - NodePtr n = Node::Create(); + ObjectPtr n = Node::Create(); n->attrs = attrs; uint32_t nout = n->num_outputs(); diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc index 3058c6f7b976..bdb7dbab6aba 100644 --- a/nnvm/src/pass/correct_layout.cc +++ b/nnvm/src/pass/correct_layout.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,11 +30,11 @@ namespace nnvm { namespace pass { -nnvm::NodePtr CreateLayoutTransformNode(const Layout& src, +nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) { static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static int count = 0; - nnvm::NodePtr n = nnvm::Node::Create(); + nnvm::ObjectPtr n = nnvm::Node::Create(); n->attrs.op = trans_op; n->attrs.name = src.name() + "_to_" + dst.name() + std::to_string(count++); n->attrs.dict["src_layout"] = src.name(); @@ -54,14 +54,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { nnvm::Op::GetAttr("FCorrectLayout"); const IndexedGraph& idx = src.indexed_graph(); - std::vector mirror_vec(idx.num_nodes(), nullptr); + std::vector mirror_vec(idx.num_nodes(), nullptr); - // (new) NodePtr -> output_layouts + // (new) ObjectPtr -> output_layouts LayoutAttrDict new_layouts; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - nnvm::NodePtr new_node = nnvm::Node::Create(); + nnvm::ObjectPtr new_node = nnvm::Node::Create(); *new_node = *(inode.source); if (new_node->is_variable()) { // Variable node. No operator. Only one output entry. @@ -85,7 +85,7 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { std::vector request_ilayouts(num_inputs, Layout::Undef()); for (size_t i = 0; i < num_inputs; ++i) { const IndexedGraph::NodeEntry& input_entry = inode.inputs[i]; - const NodePtr& new_input_node = mirror_vec[input_entry.node_id]; + const ObjectPtr& new_input_node = mirror_vec[input_entry.node_id]; CHECK(new_input_node != nullptr); // fill inputs by previous node (DFS order) inferred layouts. @@ -122,14 +122,14 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { for (uint32_t i = 0; i < inode.inputs.size(); ++i) { const auto& e = inode.inputs[i]; - const nnvm::NodePtr& in = mirror_vec[e.node_id]; + const nnvm::ObjectPtr& in = mirror_vec[e.node_id]; new_node->inputs[i] = nnvm::NodeEntry{in, e.index, e.version}; // insert layout_transform if necessary const Layout& produce = produce_ilayouts[i]; const Layout& request = request_ilayouts[i]; if (produce != request && produce.defined()) { - nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request); + nnvm::ObjectPtr tnode = CreateLayoutTransformNode(produce, request); tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name(); tnode->inputs.emplace_back(new_node->inputs[i]); nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0); diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 3e925222504c..9c30a785cac2 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -37,13 +37,13 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { if (v.size() == 1) { return std::move(v[0]); } else if (v.size() == 0) { - NodePtr zero_node = Node::Create(); + ObjectPtr zero_node = Node::Create(); zero_node->attrs.op = Op::Get("zeros"); zero_node->attrs.name = "zero_grad"; zero_node->attrs.op->attr_parser(&(zero_node->attrs)); return NodeEntry{zero_node, 0, 0}; } else { - NodePtr sum_node = Node::Create(); + ObjectPtr sum_node = Node::Create(); sum_node->attrs.op = Op::Get("elemwise_sum"); sum_node->inputs = std::move(v); sum_node->attrs.name = "grad_sum"; @@ -119,10 +119,10 @@ Graph Gradient(Graph src) { nullptr; // topo sort - std::vector topo_order; + std::vector topo_order; std::unordered_map > output_grads; - DFSVisit(ys, [&](const NodePtr& node) { + DFSVisit(ys, [&](const ObjectPtr& node) { if (output_grads.count(node.get()) == 0) { output_grads[node.get()].resize(node->num_outputs()); } @@ -143,11 +143,11 @@ Graph Gradient(Graph src) { } // construct mirror as memory reduction strategy if needed - std::unordered_map mirror_map; + std::unordered_map mirror_map; if (mirror_fun != nullptr) { - for (const NodePtr& node_ptr : topo_order) { + for (const ObjectPtr& node_ptr : topo_order) { if (mirror_fun(*node_ptr)) { - NodePtr new_node = Node::Create(); + ObjectPtr new_node = Node::Create(); *new_node = *node_ptr; new_node->attrs.name += "_mirror"; for (auto& e : new_node->inputs) { @@ -169,7 +169,7 @@ Graph Gradient(Graph src) { std::vector out_agg_grads; for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - const NodePtr& ptr = *rit; + const ObjectPtr& ptr = *rit; if (ptr->is_variable()) continue; out_agg_grads.clear(); auto& out_grad_vec = output_grads.at(ptr.get()); @@ -182,7 +182,7 @@ Graph Gradient(Graph src) { out_agg_grads.push_back(e.sum); } if ((*rit)->inputs.size() != 0) { - NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); + ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); std::vector input_grads; // Check for FGradient if (grad_fun_map.contains(ptr->op())) { @@ -244,7 +244,7 @@ Graph Gradient(Graph src) { if (kv == unique_grads.end()) { unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter)); } else { - NodePtr copy_node = Node::Create(); + ObjectPtr copy_node = Node::Create(); std::ostringstream os; os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy"; kv->second.first++; diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index a5797736209f..876dce1c113d 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -112,7 +112,7 @@ Graph InferAttr(Graph &&ret, CHECK_GE(inode.control_deps.size(), 1U) << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; - NodePtr fwd_ptr = inode.source->control_deps[0]; + ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; // use gradient function to find out the correspondence. std::vector ograd(fwd_ptr->num_outputs()); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index 6f43da282ee4..b2fa2ca33e07 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -45,7 +45,7 @@ inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { Graph OrderMutation(const Graph& src) { std::unordered_map > version_hist; - DFSVisit(src.outputs, [&version_hist](const NodePtr& n) { + DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) { for (const NodeEntry& e : n->inputs) { if (e.node->is_variable()) { if (e.version != 0 && version_hist.count(e.node.get()) == 0) { @@ -57,8 +57,8 @@ Graph OrderMutation(const Graph& src) { // no mutation happens, everything if fine. if (version_hist.size() == 0) return src; // start preparing for remapping the nodes. - std::unordered_map old_new; - auto prepare = [&version_hist, &old_new] (const NodePtr& n) { + std::unordered_map old_new; + auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op())) { @@ -80,11 +80,11 @@ Graph OrderMutation(const Graph& src) { if (old_new.count(e.node.get()) != 0) need_repl = true; } } - for (const NodePtr& p : n->control_deps) { + for (const ObjectPtr& p : n->control_deps) { if (old_new.count(p.get()) != 0) need_repl = true; } if (need_repl) { - NodePtr np = Node::Create(); + ObjectPtr np = Node::Create(); np->attrs = n->attrs; old_new[n.get()] = std::move(np); } @@ -111,7 +111,7 @@ Graph OrderMutation(const Graph& src) { kv.second->inputs.push_back(e); } } - for (const NodePtr& p : kv.first->control_deps) { + for (const ObjectPtr& p : kv.first->control_deps) { kv.second->control_deps.emplace_back( get_with_default(old_new, p.get(), p)); } diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index a0c0fb2f534a..6d6866e472d6 100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -105,8 +105,8 @@ Graph PlaceDevice(Graph src) { src.attrs["device"] = std::make_shared(std::move(device)); return src; } - std::map, NodePtr> copy_map; - std::vector new_node_map(idx.num_nodes(), nullptr); + std::map, ObjectPtr> copy_map; + std::vector new_node_map(idx.num_nodes(), nullptr); std::unordered_map new_device_map; static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -142,7 +142,7 @@ Graph PlaceDevice(Graph src) { CHECK(!need_mutate) << "consistency check"; } if (need_mutate) { - NodePtr new_node = Node::Create(); + ObjectPtr new_node = Node::Create(); new_node->attrs = inode.source->attrs; new_node->inputs.reserve(inode.inputs.size()); for (size_t i = 0; i < inode.inputs.size(); ++i) { @@ -154,7 +154,7 @@ Graph PlaceDevice(Graph src) { new_node->inputs.emplace_back( NodeEntry{it->second, 0, 0}); } else { - NodePtr copy_node = Node::Create(); + ObjectPtr copy_node = Node::Create(); std::ostringstream os; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; copy_node->attrs.op = copy_op; diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 69d4a05f66e8..9389995d0521 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -86,7 +86,7 @@ struct JSONNode { }; // pointer to the graph node - NodePtr node; + ObjectPtr node; // inputs std::vector inputs; // control flow dependencies @@ -190,7 +190,7 @@ struct JSONGraph { void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { std::unordered_map node2index; jgraph->node_row_ptr.push_back(0); - DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) { + DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) { uint32_t nid = static_cast(jgraph->nodes.size()); node2index[n.get()] = nid; if (n->is_variable()) { @@ -202,7 +202,7 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { for (const NodeEntry& e : n->inputs) { jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version); } - for (const NodePtr& c : n->control_deps) { + for (const ObjectPtr& c : n->control_deps) { jnode.control_deps.push_back(node2index.at(c.get())); } jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index cbefaa464ded..bcfd82bee7fe 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -32,7 +32,7 @@ TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; - os << args[0].operator NodeRef(); + os << args[0].operator ObjectRef(); *ret = os.str(); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 8a74fe5cdb7d..00ceaf72118c 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -65,7 +65,7 @@ TVM_REGISTER_API("_Array") data.push_back(ObjectRef(nullptr)); } } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Array(node); }); @@ -105,7 +105,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Map(node); } else { @@ -119,7 +119,7 @@ TVM_REGISTER_API("_Map") data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } - auto node = make_node(); + auto node = make_object(); node->data = std::move(data); *ret = Map(node); } @@ -186,7 +186,7 @@ TVM_REGISTER_API("_MapItems") if (ptr->IsInstance()) { auto* n = static_cast(ptr); - auto rkvs = make_node(); + auto rkvs = make_object(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.second); @@ -194,7 +194,7 @@ TVM_REGISTER_API("_MapItems") *ret = Array(rkvs); } else { auto* n = static_cast(ptr); - auto rkvs = make_node(); + auto rkvs = make_object(); for (const auto& kv : n->data) { rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index c62cc8ad16a0..339b25a51894 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -100,12 +100,13 @@ TVM_REGISTER_API("ir_pass.RewriteForTensorCore") }); TVM_REGISTER_API("ir_pass.AttrsEqual") -.set_body_typed([](const NodeRef& lhs, const NodeRef& rhs) { +.set_body_typed( + [](const ObjectRef& lhs, const ObjectRef& rhs) { return AttrsEqual()(lhs, rhs); }); TVM_REGISTER_API("ir_pass.AttrsHash") -.set_body_typed([](const NodeRef &node) { +.set_body_typed([](const ObjectRef &node) { return AttrsHash()(node); }); @@ -118,7 +119,7 @@ TVM_REGISTER_API("ir_pass.ExprUseVar") TVM_REGISTER_API("ir_pass.PostOrderVisit") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc f = args[1]; - ir::PostOrderVisit(args[0], [f](const NodeRef& n) { + ir::PostOrderVisit(args[0], [f](const ObjectRef& n) { f(n); }); }); @@ -126,7 +127,7 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") TVM_REGISTER_API("ir_pass.LowerStorageAccess") .set_body([](TVMArgs args, TVMRetValue *ret) { LoweredFunc f = args[0]; - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); *ret = LoweredFunc(n); }); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 19f045241915..0b84be291f71 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -42,7 +42,7 @@ class VariablePathFinder: public IRVisitor { public: explicit VariablePathFinder(Expr target) : target_(target) {} - void Visit(const NodeRef& node) final { + void Visit(const ObjectRef& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); @@ -82,7 +82,7 @@ class BoundDeducer: public IRVisitor { void Deduce(); - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); @@ -202,7 +202,7 @@ class BoundDeduceInputChecker: public IRVisitor { return target_count == 1; } - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (e.same_as(deducer_->target_)) ++target_count; IRVisitor::Visit(e); } diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 022dd8e94dbb..6a19a7aeb3f2 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -56,7 +56,7 @@ class CanonicalExprNode : public BaseExprNode { } static constexpr const char* _type_key = "arith.CanonicalExpr"; - TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode); }; enum DivMode { @@ -147,10 +147,14 @@ class SplitExprNode : public CanonicalExprNode { /*! \brief positive infty */ static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; static constexpr const char* _type_key = "arith.SplitExpr"; - TVM_DECLARE_NODE_TYPE_INFO(SplitExprNode, CanonicalExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); }; -TVM_DEFINE_COW_NODE_REF(SplitExpr, Expr, SplitExprNode); +class SplitExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode); +}; inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { if (index.same_as(other->index)) return true; @@ -272,7 +276,7 @@ class SumExprNode : public CanonicalExprNode { void AddToSelf(const SumExpr& other, int64_t scale); static constexpr const char* _type_key = "arith.SumExpr"; - TVM_DECLARE_NODE_TYPE_INFO(SumExprNode, CanonicalExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); private: /*! @@ -405,7 +409,11 @@ class SumExprNode : public CanonicalExprNode { } }; -TVM_DEFINE_COW_NODE_REF(SumExpr, Expr, SumExprNode); +class SumExpr : public Expr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode); +}; void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { // NOTE: it is rare to have a balanced long expression, @@ -507,7 +515,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { expr = op->Normalize(); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dtype = expr.dtype(); n->index = std::move(expr); n->div_mode = kTruncDiv; @@ -544,7 +552,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { return GetRef(op); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dtype = expr.dtype(); if (const auto* op = expr.as()) { n->base = op->value; @@ -655,8 +663,8 @@ SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible) { - auto divisible = make_node(); - auto non_divisible = make_node(); + auto divisible = make_object(); + auto non_divisible = make_object(); divisible->dtype = psum->dtype; non_divisible->dtype = psum->dtype; diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index c0519107d5b8..16e489a9c818 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -35,7 +35,7 @@ TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); ConstIntBound::ConstIntBound( int64_t min_value, int64_t max_value) { - auto node = make_node(); + auto node = make_object(); node->min_value = min_value; node->max_value = max_value; data_ = std::move(node); @@ -123,7 +123,7 @@ class ConstIntBoundAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Node* op) final { + Entry VisitExprDefault_(const Object* op) final { return Everything( static_cast(op)->dtype); } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index cf37545502ba..c4ee40f12da8 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -106,7 +106,7 @@ class LinearEqDetector } return ret; } - LinearEqEntry VisitExprDefault_(const Node* op, const Expr& e) final { + LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final { if (fail_) return LinearEqEntry(); if (ExprUseVar(e, var_)) { fail_ = true; @@ -171,7 +171,7 @@ bool DetectClipBound( std::unordered_map* bmap) { int flag = 0; Var var; - auto fvisit = [&bmap, &flag, &var](const NodeRef& n) { + auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { if (const Variable* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index e4f2042a19d7..79b39748426d 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -37,7 +37,7 @@ Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); IntervalSet::IntervalSet(Expr min_value, Expr max_value) { - auto node = make_node(); + auto node = make_object(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); data_ = std::move(node); @@ -505,7 +505,7 @@ class IntervalSetEvaluator : return Union(analyzer_, false_set, true_set); } - IntervalSet VisitExprDefault_(const Node* op) final { + IntervalSet VisitExprDefault_(const Object* op) final { DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 831b44409030..2e072127b449 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -75,7 +75,7 @@ class IntervalSetNode : public IntSetNode { } static constexpr const char* _type_key = "arith.IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); }; /*! @@ -116,8 +116,8 @@ class IntervalSet : public IntSet { return IntervalSet(pos_inf(), neg_inf()); } - TVM_DEFINE_NODE_REF_COW(IntervalSetNode); - TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); + TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); }; /*! diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 25c7391fd9c4..5ab1bd386748 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -37,7 +37,7 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ModularSetNode); ModularSet::ModularSet(int64_t coeff, int64_t base) { - auto node = make_node(); + auto node = make_object(); node->coeff = coeff; node->base = base; // finish construction. @@ -120,7 +120,7 @@ class ModularSetAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Node* op) final { + Entry VisitExprDefault_(const Object* op) final { return Everything(); } diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index fd07a377e955..bff956473c87 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -250,7 +250,7 @@ class PBinaryExpr : b_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const NodeType* ptr = node.as()) { if (!a_.Match_(ptr->a)) return false; if (!b_.Match_(ptr->b)) return false; @@ -282,7 +282,7 @@ class PConstWithTypeLike : void InitMatch_() const {} - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::IntImm* ptr = node.as()) { return ptr->value == value_; } else { @@ -364,7 +364,7 @@ class PNotExpr : public Pattern > { value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Not* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; @@ -410,7 +410,7 @@ class PSelectExpr : false_value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Select* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; @@ -472,7 +472,7 @@ class PCastExpr : value_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Cast* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; @@ -530,7 +530,7 @@ class PRampExpr : lanes_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Ramp* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; @@ -592,7 +592,7 @@ class PBroadcastExpr : lanes_.InitMatch_(); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Broadcast* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; @@ -704,7 +704,7 @@ class PCallExpr : detail::tuple_for_each(finit, args_); } - bool Match_(const NodeRef& node) const { + bool Match_(const ObjectRef& node) const { if (const ir::Call* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->name != Op::kName) return false; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index ca25731cafef..3ea2cb77d316 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -53,7 +53,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) */ Target CreateTarget(const std::string& target_name, const std::vector& options) { - auto t = make_node(); + auto t = make_object(); t->target_name = target_name; std::string libs_flag = "-libs="; @@ -366,7 +366,7 @@ void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, Map* out_binds, - Array* out_arg_list, + Array* out_arg_list, const BuildConfig& config) { *out_binds = binds; @@ -396,7 +396,7 @@ Stmt BuildStmt(Schedule sch, const Array& args, const std::unordered_map& binds, bool loop_partition, - Array *out_arg_list, + Array *out_arg_list, const BuildConfig& config) { sch = sch.normalize(); @@ -445,7 +445,7 @@ Array lower(Schedule sch, const std::string& name, const std::unordered_map& binds, const BuildConfig& config) { - Array out_arg_list; + Array out_arg_list; auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); return Array({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } @@ -618,7 +618,7 @@ runtime::Module build(const Array& funcs, } BuildConfig BuildConfig::Create() { - return BuildConfig(make_node()); + return BuildConfig(make_object()); } /*! \brief Entry to hold the BuildConfig context stack. */ @@ -701,7 +701,7 @@ GenericFunc GenericFunc::Get(const std::string& name) { std::lock_guard(m->mutex); auto it = m->fmap.find(name); if (it == m->fmap.end()) { - auto f = make_node(); + auto f = make_object(); f->name_ = name; auto gf = GenericFunc(f); m->fmap[name] = gf; @@ -825,7 +825,7 @@ TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo") TVM_REGISTER_API("_GenericFuncCreate") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = GenericFunc(make_node()); + *ret = GenericFunc(make_object()); }); TVM_REGISTER_API("_GenericFuncGetGlobal") diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 2bb86093e2f8..c723a2284ebf 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() { std::string CodeGenHybrid::GetVarID(const Variable *v) { if (binds_.count(v)) return binds_[v]; - auto key = std::make_pair(static_cast(v), 0); + auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; } @@ -472,7 +472,7 @@ void CodeGenHybrid::ReserveKeywords() { } void CodeGenHybrid::DumpStmt(const Stmt &stmt, - const Array &inputs, + const Array &inputs, const Array &outputs, const std::string &name) { ReserveKeywords(); diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 2c719b0b3ecf..647ef77fc534 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -56,7 +56,7 @@ class CodeGenHybrid : * \param outputs Output tensors of this schedule. * \param name The name of the function. */ - void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, + void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, const std::string &name = "hybrid_func"); /*! * \brief Finalize the compilation and return the code. @@ -152,7 +152,7 @@ class CodeGenHybrid : /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ - std::map, std::string> id_map_; + std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ std::map binds_; /*! diff --git a/src/lang/api_registry.cc b/src/lang/api_registry.cc index d6a413e987cf..68d42a2c1433 100644 --- a/src/lang/api_registry.cc +++ b/src/lang/api_registry.cc @@ -33,7 +33,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); CHECK(f != nullptr) << "Cannot find global function \'" << name << '\''; - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = *f; n->name = name; return n; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index b83734beacb3..1c341d53168e 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -39,8 +39,8 @@ void DictAttrsNode::InitByPackedArgs( for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; - if (val.type_code() == kObjectHandle) { - dict.Set(key, val.operator NodeRef()); + if (val.IsObjectRef()) { + dict.Set(key, val.operator ObjectRef()); } else if (val.type_code() == kStr) { dict.Set(key, Expr(val.operator std::string())); } else { @@ -53,8 +53,8 @@ Array DictAttrsNode::ListFieldInfo() const { return {}; } -Attrs DictAttrsNode::make(Map dict) { - NodePtr n = make_node(); +Attrs DictAttrsNode::make(Map dict) { + ObjectPtr n = make_object(); n->dict = std::move(dict); return Attrs(n); } diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index eb5d87efbbfa..9bbd8d62105f 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -334,7 +334,7 @@ Buffer Buffer::MakeStrideView() const { if ((*this)->strides.size() != 0) return *this; if ((*this)->shape.size() == 0) return *this; std::vector temp; - auto n = make_node(*operator->()); + auto n = make_object(*operator->()); Expr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0 ; --i) { temp.push_back(acc); @@ -419,7 +419,7 @@ Buffer BufferNode::make(Var data, int data_alignment, int offset_factor, BufferType buffer_type) { - auto n = make_node(); + auto n = make_object(); n->data = std::move(data); n->dtype = dtype; n->shape = std::move(shape); diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index 5393bbffb148..58f033b69e51 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -68,7 +68,7 @@ const LayoutAxis& LayoutAxis::make(const std::string& name) { } Layout::Layout(const Array& axes) { - auto node = make_node(); + auto node = make_object(); node->axes = axes; std::ostringstream repr; for (const IterVar& axis : axes) { @@ -89,7 +89,7 @@ Layout::Layout(const Array& axes) { Layout::Layout(const std::string& name) { // NOLINT(*) if (name == "__undef__") return; - auto node = make_node(); + auto node = make_object(); node->name = name; if (name.empty()) return; // scalar @@ -347,7 +347,7 @@ Array BijectiveLayout::BackwardShape(const Array& shape) const { BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout, const Layout& dst_layout) { - auto n = make_node(); + auto n = make_object(); n->src_layout = src_layout; n->dst_layout = dst_layout; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 997c15177546..5a54f0407c8d 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -42,14 +42,14 @@ Var::Var(std::string name_hint, DataType t) : Var(Variable::make(t, name_hint)) {} Var Variable::make(DataType t, std::string name_hint) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->name_hint = std::move(name_hint); return Var(node); } Range::Range(Expr begin, Expr end) - : Range(make_node( + : Range(make_object( begin, is_zero(begin) ? end : (end - begin))) { } @@ -57,21 +57,21 @@ Range::Range(Expr begin, Expr end) Integer IntImm::make(DataType t, int64_t value) { CHECK(t.is_int() && t.is_scalar()) << "ValueError: IntImm can only take scalar."; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Integer(node); } Range Range::make_by_min_extent(Expr min, Expr extent) { - return Range(make_node(min, extent)); + return Range(make_object(min, extent)); } IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->dom = dom; n->var = var; n->iter_type = t; @@ -89,7 +89,7 @@ IterVar reduce_axis(Range dom, std::string name) { dom, Var(name), kCommReduce); } -void Dump(const NodeRef& n) { +void Dump(const ObjectRef& n) { std::cerr << n << "\n"; } diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 427e026bc728..d5cc285ac861 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -34,7 +34,7 @@ namespace ir { Expr UIntImm::make(DataType t, uint64_t value) { CHECK(t.is_uint() && t.lanes() == 1) << "ValueError: UIntImm can only take scalar"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); @@ -43,14 +43,14 @@ Expr UIntImm::make(DataType t, uint64_t value) { Expr FloatImm::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); } Expr StringImm::make(std::string value) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); return Expr(node); @@ -59,7 +59,7 @@ Expr StringImm::make(std::string value) { Expr Cast::make(DataType t, Expr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); return Expr(node); @@ -72,7 +72,7 @@ Expr And::make(Expr a, Expr b) { CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -86,7 +86,7 @@ Expr Or::make(Expr a, Expr b) { CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); @@ -97,7 +97,7 @@ Expr Not::make(Expr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); return Expr(node); @@ -111,7 +111,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes()); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; - NodePtr(); + ObjectPtr(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -126,7 +126,7 @@ Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { CHECK_EQ(dtype.lanes(), index.dtype().lanes()); CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = dtype; node->buffer_var = std::move(buffer_var); node->index = std::move(index); @@ -143,7 +143,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { CHECK_GT(lanes, 1); CHECK_EQ(stride.dtype(), base.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = base.dtype().with_lanes(lanes); node->base = base; node->stride = stride; @@ -156,7 +156,7 @@ Expr Broadcast::make(Expr value, int lanes) { CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; @@ -168,7 +168,7 @@ Expr Let::make(Var var, Expr value, Expr body) { CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); @@ -208,7 +208,7 @@ Expr Call::make(DataType dtype, } } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = dtype; node->name = std::move(name); node->args = std::move(args); @@ -232,7 +232,7 @@ Expr Shuffle::make(Array vectors, } CHECK_LE(indices.size(), static_cast(total_lanes)); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); @@ -262,7 +262,7 @@ CommReducer CommReducerNode::make(Array lhs, Array rhs, Array result, Array identity_element) { - auto node = make_node(); + auto node = make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; @@ -293,7 +293,7 @@ Expr Reduce::make(CommReducer combiner, Array source, if (!condition.defined()) { condition = const_true(); } - auto n = make_node(); + auto n = make_object(); CHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); @@ -308,7 +308,7 @@ Expr Reduce::make(CommReducer combiner, Array source, } Expr Any::make() { - auto n = make_node(); + auto n = make_object(); return Expr(n); } @@ -317,18 +317,18 @@ Stmt LetStmt::make(Var var, Expr value, Stmt body) { CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); return Stmt(node); } -Stmt AttrStmt::make(NodeRef node, +Stmt AttrStmt::make(ObjectRef node, std::string attr_key, Expr value, Stmt body) { - auto n = make_node(); + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -343,7 +343,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); @@ -353,7 +353,7 @@ Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) { CHECK(body.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->is_producer = is_producer; node->body = std::move(body); @@ -373,7 +373,7 @@ Stmt For::make(Var loop_var, CHECK(loop_var.dtype().is_scalar()); CHECK(body.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -390,7 +390,7 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->value = std::move(value); node->index = std::move(index); @@ -407,7 +407,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar CHECK(args[i].defined()) << "Provide to undefined location\n"; } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->value = std::move(value); @@ -430,7 +430,7 @@ Stmt Allocate::make(Var buffer_var, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -457,7 +457,7 @@ int32_t Allocate::constant_allocation_size(const Array& extents) { } Stmt Free::make(Var buffer_var) { - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->buffer_var = buffer_var; return Stmt(node); } @@ -478,7 +478,7 @@ Stmt Realize::make(FunctionRef func, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -496,7 +496,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo CHECK(bounds[i]->extent.dtype().is_scalar()); } - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -507,7 +507,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo Stmt Block::make(Stmt first, Stmt rest) { CHECK(first.defined()); CHECK(rest.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); // canonicalize. if (const Block* b = first.as()) { @@ -536,7 +536,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { CHECK(then_case.defined()); // else_case may be null. - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); @@ -546,7 +546,7 @@ Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { Stmt Evaluate::make(Expr value) { CHECK(value.defined()); - NodePtr node = make_node(); + ObjectPtr node = make_object(); node->value = std::move(value); return Stmt(node); } diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index 1c110936b3ef..e9ca89a4b31e 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -47,7 +47,7 @@ Expr Tensor::operator()(Array indices) const { } Tensor Operation::output(size_t i) const { - auto node = make_node(); + auto node = make_object(); node->op = *this; node->value_index = i; node->dtype = (*this)->output_dtype(i); @@ -59,7 +59,7 @@ Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { - auto n = make_node(); + auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; @@ -87,7 +87,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, Stmt body, Stmt reduce_init, Stmt reduce_update) { - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->op = std::move(op); n->inputs = std::move(inputs); @@ -115,7 +115,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array regions, Array reduce_axis, Array scalar_inputs) { - auto n = make_node(); + auto n = make_object(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 5a991aa3ad1b..5e8a0f709f81 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -79,7 +79,7 @@ class NodeIndexer : public AttrVisitor { // make index of all the children of node void MakeIndex(Object* node) { if (node == nullptr) return; - CHECK(node->IsInstance()); + CHECK(node->IsInstance()); if (node_index_.count(node)) return; CHECK_EQ(node_index_.size(), node_list_.size()); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index bd129ac33058..c0cae269ffc3 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -90,8 +90,8 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, - Map attrs) { - auto op_node = make_node(); + Map attrs) { + auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -112,8 +112,8 @@ Array compute(Array shape, FBatchCompute fcompute, std::string name, std::string tag, - Map attrs) { - auto op_node = make_node(); + Map attrs) { + auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); std::vector axis; @@ -136,13 +136,13 @@ Array compute(Array shape, Operation ComputeOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array axis, Array body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -161,7 +161,7 @@ Array ComputeOpNode::InputTensors() const { Array ret; std::unordered_set visited; for (auto& e : body) { - ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) { + ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -188,7 +188,7 @@ Operation ComputeOpNode::ReplaceInputs( if (!new_reduce.same_as(this->body[0])) { const ir::Reduce* r = new_reduce.as(); for (size_t k = 0; k < this->body.size(); ++k) { - auto n = make_node(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); arr.push_back(Expr(n)); @@ -215,7 +215,7 @@ void ComputeOpNode::PropBoundToInputs( const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); - auto fvisit = [&dom_map, out_dom_map, analyzer](const NodeRef& n) { + auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -574,7 +574,7 @@ class ComputeVerifier final : protected ir::IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { ++level_; ir::IRVisitor::Visit(n); --level_; diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index 883ebdc4a0f7..b921c86f3556 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -57,15 +57,15 @@ Array ExternOpNode::output_shape(size_t i) const { Operation ExternOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -93,7 +93,7 @@ Operation ExternOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); n->body = op::ReplaceTensor(this->body, rmap); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; @@ -161,7 +161,7 @@ Stmt ExternOpNode::BuildProvide( CHECK_EQ(stage->op.operator->(), this); Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { - Array bind_spec; + Array bind_spec; Array tuple; bind_spec.push_back(buffer); bind_spec.push_back(tensor); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 1e1a81423b69..061929a31ef1 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -63,14 +63,14 @@ Array HybridOpNode::output_shape(size_t i) const { Operation HybridOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); @@ -91,7 +91,7 @@ Array HybridOpNode::InputTensors() const { } std::unordered_set visited; Array curr_inputs; - ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) { + ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { Tensor t = Downcast(call->func).output(call->value_index); @@ -108,7 +108,7 @@ Operation HybridOpNode::ReplaceInputs( const Operation &self, const std::unordered_map &rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); n->body = op::ReplaceTensor(this->body, rmap); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; @@ -185,7 +185,7 @@ Stmt HybridOpNode::BuildProvide( for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); } - auto n = make_node(*this); + auto n = make_object(*this); /* This is a story little bit complicated. * The following two lines of codes replace output tensors' usage. * This is the simplest way I (@were) can come up with to glue @@ -369,7 +369,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, expected = IterVarTypeToForType(attr->iter_type); } - PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) { + PostOrderVisit(stmt, + [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { if (const For *op = node.as()) { if (op->loop_var.get() == var) { ++found; @@ -390,7 +391,7 @@ Stmt ApplyLoopOrder(const Stage &stage, const std::unordered_map &dom_map, const std::unordered_map &rebased, Stmt stmt) { std::vector current_order; - PostOrderVisit(stmt, [¤t_order](const NodeRef &node) { + PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { if (const For *op = node.as()) current_order.push_back(op->loop_var.get()); }); @@ -466,7 +467,7 @@ Stmt ApplySchedule(const Stage &stage, std::vector GatherLoopVars(Stmt stmt) { // TODO(@were): Write a comprehensive pass to analyze iter var types std::vector res_; - PostOrderVisit(stmt, [&res_](const NodeRef &node) { + PostOrderVisit(stmt, [&res_](const ObjectRef& node) { if (const For *op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 6910f63b44d3..7863c8a52265 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -55,7 +55,7 @@ Array PlaceholderOpNode::output_shape(size_t i) const { Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { - auto n = make_node(); + auto n = make_object(); n->name = name; n->shape = shape; n->dtype = dtype; diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index e83a23194cf8..57f16f82c54b 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -64,16 +64,16 @@ Array ScanOpNode::output_shape(size_t i) const { Operation ScanOpNode::make(std::string name, std::string tag, - Map attrs, + Map attrs, IterVar axis, Array init, Array update, Array state_placeholder, Array inputs) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } - auto n = make_node(); + auto n = make_object(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); @@ -126,7 +126,7 @@ Array scan(Array init, Array inputs, std::string name, std::string tag, - Map attrs) { + Map attrs) { IterVar scan_axis = IterVarNode::make( Range::make_by_min_extent( @@ -157,7 +157,7 @@ Operation ScanOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); + auto n = make_object(*this); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index e59f90f4948e..cfd6e23a0db4 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -59,7 +59,7 @@ Operation TensorComputeOpNode::make(std::string name, Array tensors, Array regions, Array scalar_inputs) { - auto n = make_node(); + auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); n->axis = std::move(axis); @@ -80,8 +80,8 @@ Operation TensorComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); - auto n = make_node(*this); - auto intrin = make_node(*(this->intrin.operator->())); + auto n = make_object(*this); + auto intrin = make_object(*(this->intrin.operator->())); intrin->body = op::ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap); @@ -146,7 +146,7 @@ Stmt TensorComputeOpNode::BuildProvide( Tensor tensor = inputs[i]; Region region = this->input_regions[i]; Buffer buffer = this->intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; Array tuple; for (size_t i = 0; i < region.size(); ++i) { @@ -162,7 +162,7 @@ Stmt TensorComputeOpNode::BuildProvide( for (int i = 0; i < this->num_outputs(); ++i) { Tensor tensor = stage->op.output(i); Buffer buffer = this->intrin->buffers[num_inputs + i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; Array tuple; for (size_t i = 0; i < this->axis.size(); ++i) { diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index b7f32de8b5ad..7ab54e983028 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -379,7 +379,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, for (size_t i = 0; i < intrin->inputs.size(); ++i) { Tensor tensor = inputs[i]; Buffer buffer = intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; auto it = in_region.find(tensor); CHECK(it != in_region.end()); const Array& region = it->second; @@ -407,7 +407,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) { Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; - Array bind_spec{buffer, tensor}; + Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmt::make( bind_spec, ir::attr::buffer_bind_scope, Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); @@ -507,7 +507,7 @@ TVM_REGISTER_API("test.op.InferTensorizeRegion") stage, as_unordered_map(dmap), &out_dom, &in_region); - *ret = Array{Map(out_dom), + *ret = Array{Map(out_dom), Map >(in_region)}; }); diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index f1cb8fe10a4b..e050fee98e67 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -108,7 +108,7 @@ class ContextCallCombiner final : public IRMutator { }; LoweredFunc CombineContextCall(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ContextCallCombiner().Combine(n->body); return LoweredFunc(n); } diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 4aa8879f679b..a5b3285f7fa9 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -104,7 +104,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } // Write synchronization to be inserted before or after stmt. - std::unordered_map > sync_; + std::unordered_map > sync_; protected: bool Enabled(const Variable* buf, @@ -229,8 +229,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { PlanWriteBarrier(scope_.back(), nullptr); } - std::unordered_map > barrier_before_; - std::unordered_map > barrier_after_; + std::unordered_map > barrier_before_; + std::unordered_map > barrier_after_; protected: bool Enabled(const Variable* buf, @@ -458,14 +458,14 @@ class CoProcInstDepDetector : public IRVisitor { // insert before is stored in reverse order // the first element is closest to the node. - std::unordered_map > insert_before_; - std::unordered_map > insert_after_; + std::unordered_map > insert_before_; + std::unordered_map > insert_after_; private: // state in the sync entry struct SyncState { // The statement of the state. - const Node* node{nullptr}; + const Object* node{nullptr}; // Set of all possible contexts in the entering moment. std::unordered_set enter_ctx; // Set of all possible contexts in the exit moment. @@ -679,8 +679,8 @@ class CoProcSyncInserter : public IRMutator { private: // insert before is stored in reverse order // the first element is closest to the node. - std::unordered_map > insert_before_; - std::unordered_map > insert_after_; + std::unordered_map > insert_before_; + std::unordered_map > insert_after_; }; diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index a1c635e2692b..e3ffcc4f15f3 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -35,8 +35,8 @@ namespace tvm { namespace ir { -using HoistMap = std::unordered_map>; -using VarMap = std::unordered_map>; +using HoistMap = std::unordered_map>; +using VarMap = std::unordered_map>; /* * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. @@ -124,12 +124,12 @@ class IfThenElseHoist { // Check whether a given IfThenElse stmt is the first one appearing // in a For stmt. bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { - std::vector if_node_list; + std::vector if_node_list; const For* for_node = for_stmt.as(); CHECK(for_node); CHECK(if_stmt.as()); - PostOrderVisit(for_node->body, [&](const NodeRef& node) { + PostOrderVisit(for_node->body, [&](const ObjectRef& node) { if (node.as()) { if_node_list.push_back(node.get()); } @@ -141,12 +141,12 @@ bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { // With this function we only need to visit and mutate top level For node // in the main VisitAndMutate function. Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { - const Node* top_for_node; + const Object* top_for_node; const For* parent_for_node = parent_for_stmt.as(); CHECK(parent_for_node); CHECK(new_if_stmt.as()); - PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { + PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) { if (node.as()) { top_for_node = node.get(); } @@ -154,7 +154,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { PackedFunc replace_target_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& current_for = args[0]; + const ObjectRef& current_for = args[0]; if (current_for.get() == top_for_node) { *ret = new_if_stmt; } @@ -173,7 +173,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { PackedFunc replace_then_case = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& node = args[0]; + const ObjectRef& node = args[0]; if (node == if_stmt) { *ret = node.as()->then_case; } @@ -181,7 +181,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { PackedFunc replace_else_case = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& node = args[0]; + const ObjectRef& node = args[0]; if (node == if_stmt) { *ret = node.as()->else_case; } @@ -199,13 +199,13 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { // Locate all For nodes and capture child IfThenElse nodes. void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const NodeRef& node){ + PostOrderVisit(stmt, [&](const ObjectRef& node){ const For* for_node = node.as(); if (!for_node) return; std::queue tracker; tracker.push(for_node->body); - Stmt for_stmt = Downcast(node); + Stmt for_stmt = Downcast(node); for2if_map_.insert({for_stmt.get(), std::vector()}); while (!tracker.empty()) { Stmt head = tracker.front(); @@ -227,9 +227,9 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { // Record condition variables. if (!cond_var_map_.count(head.get())) { - std::unordered_set new_var_set; + std::unordered_set new_var_set; cond_var_map_.insert({head.get(), new_var_set}); - PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) { if (cond_node.as()) { cond_var_map_[head.get()].insert(cond_node.get()); } @@ -239,15 +239,15 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { continue; } } - ordered_for_list_.emplace_back(Downcast(node)); + ordered_for_list_.emplace_back(Downcast(node)); }); } // For each IfThenElse node, find the highest For node which // meets loop invariant condition. void IfThenElseHoist::LocateTopFor() { - std::unordered_map if_position_map; - std::unordered_set top_for_var_set; + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; // Create IfThenElse -> For map. for (const Stmt& for_stmt : ordered_for_list_) { @@ -256,7 +256,7 @@ void IfThenElseHoist::LocateTopFor() { CHECK(for_node); top_for_var_map_.insert({for_node->loop_var.get(), if_list}); for (const Stmt& if_stmt : if_list) { - const Node* if_node = if_stmt.get(); + const Object* if_node = if_stmt.get(); if2for_map_[if_node].push_back(for_stmt); } } @@ -264,7 +264,7 @@ void IfThenElseHoist::LocateTopFor() { // Locate the highest For node which is loop invariant. for (const auto& item : if2for_map_) { Stmt top_for; - const Node* if_stmt = item.first; + const Object* if_stmt = item.first; std::vector for_list = item.second; for (size_t i = 0; i < for_list.size(); ++i) { const Stmt& for_stmt = for_list.at(i); @@ -291,9 +291,9 @@ void IfThenElseHoist::LocateTopFor() { top_for_var_set.insert(item.second.as()->loop_var.get()); } - std::vector removed_for_var_list; + std::vector removed_for_var_list; for (const auto& item : top_for_var_map_) { - const Node* top_for_var = item.first; + const Object* top_for_var = item.first; std::vector if_list = item.second; if (!top_for_var_set.count(top_for_var)) { removed_for_var_list.push_back(top_for_var); @@ -307,7 +307,7 @@ void IfThenElseHoist::LocateTopFor() { top_for_var_map_[top_for_var] = actual_if_list; } } - for (const Node* top_for_var : removed_for_var_list) { + for (const Object* top_for_var : removed_for_var_list) { top_for_var_map_.erase(top_for_var); } } @@ -374,7 +374,7 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { PackedFunc replace_top_for = PackedFunc( [&](TVMArgs args, TVMRetValue *ret){ - const NodeRef& current_for = args[0]; + const ObjectRef& current_for = args[0]; const For* for_node = current_for.as(); if (!for_node) return; diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 71da645474b0..13f9ebade9b1 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -214,7 +214,7 @@ Stmt InferFragment(Stmt stmt) { LoweredFunc InferFragment(LoweredFunc f) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = InferFragment(f->body); return LoweredFunc(n); } diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index c80c7fcdaa8c..7e7af187dce1 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -37,7 +37,7 @@ class ExprTouched final : public IRVisitor { bool check_write) : touched_var_(touched), check_write_(check_write) {} - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { // early stopping if (expr_touched_ && !check_write_) return; IRVisitor::Visit(n); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index e399e7f2c54f..6a61d5e402f9 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -358,7 +358,7 @@ class IRDeepCompare : return order_; } - int CompareNodeRef(const NodeRef& lhs, const NodeRef& rhs) { + int CompareNodeRef(const ObjectRef& lhs, const ObjectRef& rhs) { if (order_ != 0) return order_; if (lhs.get() < rhs.get()) { order_ = -1; return order_; diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index 8b6e66135235..cdc708ce5faf 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,38 +31,38 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; if (const auto* for_ = s.as()) { - auto n = make_node(*for_); + auto n = make_object(*for_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* let = s.as()) { - auto n = make_node(*let); + auto n = make_object(*let); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* attr = s.as()) { - auto n = make_node(*attr); + auto n = make_object(*attr); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* ite = s.as()) { - auto n = make_node(*ite); + auto n = make_object(*ite); CHECK(is_no_op(n->then_case)); CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); } else if (const auto* block = s.as()) { - auto n = make_node(*block); + auto n = make_object(*block); CHECK(is_no_op(n->rest)); n->rest = body; body = Stmt(n); } else if (const auto* assert_ = s.as()) { - auto n = make_node(*assert_); + auto n = make_object(*assert_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { - auto n = make_node(*alloc); + auto n = make_object(*alloc); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index d6f163ccedc6..467cd5de2ef7 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -29,9 +29,9 @@ namespace ir { // visitor to implement apply class IRApplyVisit : public IRVisitor { public: - explicit IRApplyVisit(std::function f) : f_(f) {} + explicit IRApplyVisit(std::function f) : f_(f) {} - void Visit(const NodeRef& node) final { + void Visit(const ObjectRef& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); IRVisitor::Visit(node); @@ -39,11 +39,11 @@ class IRApplyVisit : public IRVisitor { } private: - std::function f_; - std::unordered_set visited_; + std::function f_; + std::unordered_set visited_; }; -void PostOrderVisit(const NodeRef& node, std::function fvisit) { +void PostOrderVisit(const ObjectRef& node, std::function fvisit) { IRApplyVisit(fvisit).Visit(node); } diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index cfc6e5a7fc68..7f5b4cca0bb4 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -54,7 +54,7 @@ class AttrScopeLifter : public IRMutator { Stmt body = AttrStmt::make( attr_node_, attr_key_, attr_value_, op->body); // undefine them - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); return Allocate::make( op->buffer_var, op->dtype, @@ -93,7 +93,7 @@ class AttrScopeLifter : public IRMutator { return IRMutator::Mutate_(op, s); } Stmt then_case = this->Mutate(op->then_case); - NodeRef first_node; + ObjectRef first_node; Expr first_value; std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); @@ -119,7 +119,7 @@ class AttrScopeLifter : public IRMutator { else_case = AttrStmt::make( attr_node_, attr_key_, attr_value_, else_case); // undefine them - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); } if (then_case.same_as(op->then_case) && @@ -149,11 +149,11 @@ class AttrScopeLifter : public IRMutator { std::vector MutateSeq(const std::vector& seq) { std::vector res_seq; - NodeRef curr_node; + ObjectRef curr_node; Expr curr_value; Stmt curr_stmt; for (const Stmt & stmt : seq) { - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); Stmt rest = this->Mutate(stmt); if (attr_node_.defined() && @@ -188,7 +188,7 @@ class AttrScopeLifter : public IRMutator { } res_seq.push_back(curr_stmt); // reset - attr_node_ = NodeRef(); + attr_node_ = ObjectRef(); attr_value_ = Expr(); } return res_seq; @@ -209,7 +209,7 @@ class AttrScopeLifter : public IRMutator { } std::string attr_key_; - NodeRef attr_node_; + ObjectRef attr_node_; Expr attr_value_; }; diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 1ac386767ae3..e68387f1baad 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -37,10 +37,10 @@ using arith::IntSet; using arith::DeduceBound; using arith::Intersect; -using PartitionKey = std::pair; +using PartitionKey = std::pair; struct PartitionKeyHash { std::size_t operator()(PartitionKey const& k) const noexcept { - std::size_t h1 = std::hash{}(k.first); + std::size_t h1 = std::hash{}(k.first); std::size_t h2 = std::hash{}(k.second); return h1 ^ h2; } @@ -53,7 +53,7 @@ using Partition = std::unordered_map; bool ExprUseVars(Expr expr, const std::unordered_set& vars) { bool success = false; - PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { + PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { if (const Variable* v = node.as()) { if (vars.count(v)) { success = true; @@ -138,7 +138,7 @@ class CandidateSelector final : public IRVisitor { } } - std::unordered_set candidates; + std::unordered_set candidates; private: bool in_likely_{false}; @@ -257,7 +257,7 @@ class PartitionFinder : public IRVisitor { // Replace the set of conditions given by ps with cond_value (true or false) class ConditionEliminator : public IRMutator { public: - explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) + explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} using IRMutator::Mutate; @@ -269,7 +269,7 @@ class ConditionEliminator : public IRMutator { } private: - std::unordered_set ps_; + std::unordered_set ps_; bool cond_value_; }; @@ -277,7 +277,7 @@ class ConditionEliminator : public IRMutator { // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public IRMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, + explicit ThreadPartitionInserter(const std::unordered_set& ps, Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { @@ -299,7 +299,7 @@ class ThreadPartitionInserter : public IRMutator { } private: - const std::unordered_set& ps_; + const std::unordered_set& ps_; Expr cond_; bool innermost_thread_scope_; }; @@ -364,15 +364,15 @@ class LoopPartitioner : public IRMutator { } private: - Stmt TryPartition(const Node* op, const Stmt& stmt, VarExpr var, + Stmt TryPartition(const Object* op, const Stmt& stmt, VarExpr var, Expr min, Expr max, Stmt body, bool partition_thread_scope); - std::pair> + std::pair> GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value); - inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); + inline Stmt MakeFor(const Object* op, Expr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; @@ -383,12 +383,12 @@ class LoopPartitioner : public IRMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> +std::pair> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, const arith::IntervalSet &for_interval, bool cond_value) { Array sets; - std::unordered_set cond_set; + std::unordered_set cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { @@ -461,7 +461,7 @@ Stmt AppendStmts(const Stmt& a, const Stmt& b) { * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Node* node, +Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, VarExpr var, Expr min, @@ -481,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; - std::unordered_set cond_set; + std::unordered_set cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, true); @@ -592,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { +inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index e24cddd97f25..c45019ab38b8 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -130,7 +130,7 @@ class CustomDatatypesLowerer : public IRMutator { }; LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = CustomDatatypesLowerer(target).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index f0b0b3c36d42..dd81826a5988 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -282,7 +282,7 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerIntrinStmt(n->body, target); return LoweredFunc(n); } diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 2a121180d695..03470271b029 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -338,7 +338,7 @@ class ThreadAllreduceBuilder final : public IRMutator { LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ThreadAllreduceBuilder(warp_size).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index c8c8fa9c62d0..9a33d647b683 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -360,7 +360,7 @@ class BuiltinLower : public IRMutator { }; LoweredFunc LowerTVMBuiltin(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = BuiltinLower().Build(n->body); return LoweredFunc(n); } diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 0ed2b6232fc1..0749127b905b 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -380,7 +380,7 @@ class WarpMemoryRewriter : private IRMutator { LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size) { CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); return LoweredFunc(n); } diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 74b8f891299a..b0f9482545d3 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -42,7 +42,7 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { LoweredFunc MakeAPI(Stmt body, std::string name, - Array api_args, + Array api_args, int num_unpacked_args, bool is_restricted) { const Stmt nop = Evaluate::make(0); @@ -168,7 +168,7 @@ LoweredFunc MakeAPI(Stmt body, buf_arg.second, buf_arg.second->name_hint); } - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name = name; n->args = args; n->handle_data_type = binder.def_handle_dtype(); @@ -266,7 +266,7 @@ class DeviceTypeBinder: public IRMutator { LoweredFunc BindDeviceType(LoweredFunc f, int device_type) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = DeviceTypeBinder(device_type).Mutate(n->body); return LoweredFunc(n); } diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc index f3f0d009573d..49d92d027193 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/pass/remap_thread_axis.cc @@ -85,7 +85,7 @@ RemapThreadAxis(LoweredFunc f, Map thread_map) { } CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); // replace the thread axis for (size_t i = 0; i < n->thread_axis.size(); ++i) { auto it = tmap.find(n->thread_axis[i]->thread_tag); diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 06579f31e17a..1159e568f519 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,7 +31,7 @@ namespace ir { class IRSideEffect : public IRVisitor { public: - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (has_side_effect_) return; IRVisitor::Visit(e); } @@ -103,7 +103,7 @@ Expr Substitute(Expr expr, const Map& value_map) { class VarTouchVisitor : public IRVisitor { public: - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (use_var_) return; IRVisitor::Visit(e); } diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc index 5f310a61dfe3..817416d9fd2c 100644 --- a/src/pass/skip_assert.cc +++ b/src/pass/skip_assert.cc @@ -38,7 +38,7 @@ Stmt SkipAssert(Stmt stmt) { } LoweredFunc SkipAssert(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = SkipAssert(f->body); return LoweredFunc(n); } diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 5076300d968a..f045c271456c 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -176,8 +176,8 @@ class HostDeviceSplitter : public IRMutator { handle_data_type_[kv.first.get()] = kv.second; } name_ = f->name; - NodePtr n = - make_node(*f.operator->()); + ObjectPtr n = + make_object(*f.operator->()); n->body = this->Mutate(f->body); n->func_type = kHostFunc; Array ret{LoweredFunc(n)}; @@ -191,7 +191,7 @@ class HostDeviceSplitter : public IRMutator { Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; os << name_ << "_kernel" << device_funcs_.size(); - NodePtr n = make_node(); + ObjectPtr n = make_object(); // isolate the device function. IRUseDefAnalysis m; m.visit_thread_extent_ = false; diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 0fff1e6e6774..37db29c58079 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,7 +38,7 @@ class IRVerifySSA final : public IRVisitor { public: bool is_ssa{true}; - void Visit(const NodeRef& n) final { + void Visit(const ObjectRef& n) final { if (!is_ssa) return; IRVisitor::Visit(n); } diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index c146a8709b1e..bf8d4e020521 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -341,7 +341,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { } LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); return LoweredFunc(n); } diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 028645b78640..302ca929581d 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -71,7 +71,7 @@ class StorageAccessVisitor : public IRVisitor { /*! \brief Access pattern about a single statement */ struct StmtEntry { /*! \brief The statement */ - const Node* stmt; + const Object* stmt; /*! \brief access patterns in the statement */ std::vector access; }; diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index d6dde29a519d..2df2672adcb1 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -402,7 +402,7 @@ class StorageFlattener : public IRMutator { // We do support a few relaxed case, such as bindingx // region with shape [1, 1, n, m] to buffer with shape [n, m] Stmt HandleBufferBindScope(const AttrStmt* op) { - Array arr = Downcast > (op->node); + Array arr = Downcast > (op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); const TensorNode* tensor = arr[1].as(); @@ -512,7 +512,7 @@ class StorageFlattener : public IRMutator { // Dimension alignment std::unordered_map > dim_align_; // Storage scope - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 12a06da8007f..01c6f983d692 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -59,7 +59,7 @@ class LinearAccessPatternFinder final : public IRVisitor { /*! \brief record the touch hist of statment. */ struct StmtEntry { // The statment - const Node* stmt; + const Object* stmt; // The index in the linear_seq_ to point to end of the nested scope. // This is only set to non-zero if stmt is a nested scope. // if offset > 0, means this is the begin, the end entry is current_index + offset @@ -236,7 +236,7 @@ class LinearAccessPatternFinder final : public IRVisitor { // class InplaceOpVerifier : public IRVisitor { public: - bool Check(const Node* stmt, + bool Check(const Object* stmt, const Variable* dst, const Variable* src) { dst_ = dst; @@ -258,7 +258,7 @@ class InplaceOpVerifier : public IRVisitor { using IRVisitor::Visit_; - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!result_) return; IRVisitor::Visit(e); } @@ -471,7 +471,7 @@ class StoragePlanRewriter : public IRMutator { // The scope that this alloc attaches after // For shared/local memory it is beginning of the thread extent. // for global memory it is nullptr, means beginning of everything. - const Node* attach_scope_{nullptr}; + const Object* attach_scope_{nullptr}; // The constant size of the buffer in bits, only used if it is constant uint64_t const_nbits{0}; // The storage scope. @@ -695,7 +695,7 @@ class StoragePlanRewriter : public IRMutator { } } } - void PlanNewScope(const Node* op) { + void PlanNewScope(const Object* op) { if (thread_scope_ != nullptr) { CHECK(thread_scope_ == op); // erase all memory atatched to this scope. @@ -808,7 +808,7 @@ class StoragePlanRewriter : public IRMutator { } // Allocate new storage entry. StorageEntry* NewAlloc(const Allocate* op, - const Node* attach_scope, + const Object* attach_scope, const StorageScope& scope, size_t const_nbits) { CHECK(op != nullptr); @@ -824,7 +824,7 @@ class StoragePlanRewriter : public IRMutator { } StorageEntry* FindAlloc(const Allocate* op, - const Node* attach_scope, + const Object* attach_scope, const StorageScope& scope) { CHECK(op != nullptr); // skip plan for local variable, @@ -908,17 +908,17 @@ class StoragePlanRewriter : public IRMutator { } } // thread scope. - const Node* thread_scope_{nullptr}; + const Object* thread_scope_{nullptr}; // whether enable inplace detection. bool detect_inplace_{false}; // Locations of free ops. - std::unordered_map event_map_; + std::unordered_map event_map_; // constant size free map. std::multimap const_free_map_; // symbolic free list, for non constant items. std::list sym_free_list_; // The allocation attach map - std::unordered_map > attach_map_; + std::unordered_map > attach_map_; // The allocation assign map std::unordered_map alloc_map_; // The allocations @@ -987,7 +987,7 @@ class VectorAllocRewriter : public IRMutator { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); VectorAllocRewriter rewriter; n->body = rewriter.Mutate(n->body); for (Var arg : f->args) { diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 018a6bb2e79e..0f8bef8383f2 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -39,7 +39,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { : sync_scope_(sync_scope) {} // The syncs inserted before each statement - std::unordered_set syncs_inserted_; + std::unordered_set syncs_inserted_; protected: bool Enabled(const Variable* buf, @@ -200,7 +200,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { class ThreadSyncInserter : public IRMutator { public: ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set& syncs) + const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} Stmt Mutate(Stmt stmt) final { @@ -346,11 +346,11 @@ class ThreadSyncInserter : public IRMutator { } // data structure. StorageScope sync_scope_; - const std::unordered_set& syncs_; + const std::unordered_set& syncs_; // The storage scope of each buffer std::unordered_map storage_scope_; // The read write statistics of storage - std::unordered_map rw_stats_; + std::unordered_map rw_stats_; // The statistics for global barrier bool in_thread_env_{false}; // memorized results @@ -369,7 +369,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { CHECK_NE(f->func_type, kHostFunc); - auto n = make_node(*f.operator->()); + auto n = make_object(*f.operator->()); n->body = ThreadSync(f->body, storage_scope); return LoweredFunc(n); } diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index 2ead2b934d7e..8dcc0e49e119 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -225,9 +225,9 @@ class MMAMatcher: public IRVisitor { } std::unordered_map buf_map_; - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; std::unordered_map> mma_sync_; - std::unordered_map buf_name_; + std::unordered_map buf_name_; std::unordered_set frag_reg_; bool matched_{false}; bool tensor_core_on_{false}; @@ -365,7 +365,7 @@ class ScheduleAnalyser { std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_map> mma_sync_; - std::unordered_map buf_name_; + std::unordered_map buf_name_; }; // IndexVisitor visits access index of fragment @@ -745,7 +745,7 @@ class BufferAnalyser : public IRVisitor { std::unordered_map buf_map_; std::unordered_map > dim_align_; - std::unordered_map storage_scope_; + std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_set frag_reg_; @@ -868,9 +868,9 @@ class TensorCoreIRMutator : public IRMutator { Expr c = operands[2]; auto cc = c.as(); - NodePtr buffer_node_a = make_node(); - NodePtr buffer_node_b = make_node(); - NodePtr buffer_node_c = make_node(); + ObjectPtr buffer_node_a = make_object(); + ObjectPtr buffer_node_b = make_object(); + ObjectPtr buffer_node_c = make_object(); auto mma_sync_call = [&buffer_node_a, &buffer_node_b] @@ -921,7 +921,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, fill_fragment_call, call->dtype); @@ -971,7 +971,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{op->func, op->value_index}, load_matrix_call, call->dtype); @@ -1011,7 +1011,7 @@ class TensorCoreIRMutator : public IRMutator { Call::Intrinsic)); }; - NodePtr buffer_node = make_node(); + ObjectPtr buffer_node = make_object(); return add_buffer_bind_scope_(call, buffer_node, TensorKey{call->func, call->value_index}, store_matrix_call, call->dtype); @@ -1073,7 +1073,7 @@ class TensorCoreIRMutator : public IRMutator { } Stmt add_buffer_bind_scope_(const Call* call, - const NodePtr &buffer_node, const TensorKey &key, + const ObjectPtr &buffer_node, const TensorKey &key, const std::function &call_back, DataType datatype) { auto it = bounds_.find(key); @@ -1124,7 +1124,7 @@ class TensorCoreIRMutator : public IRMutator { buffer_node->offset_factor = 1; Buffer buffer(buffer_node); - NodePtr tensor_node = make_node(); + ObjectPtr tensor_node = make_object(); tensor_node->value_index = key.value_index; tensor_node->op = Downcast(key.f); tensor_node->shape = shape; @@ -1140,7 +1140,7 @@ class TensorCoreIRMutator : public IRMutator { intrinsic::tvm_tuple, args, Call::Intrinsic); - Array node = {buffer, tensor}; + Array node = {buffer, tensor}; return AttrStmt::make(node, "buffer_bind_scope", tuple, diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 1d7bb3d8425b..4a5c8adeb8e7 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -64,7 +64,7 @@ class MemoryAccessVerifier final : protected IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const NodeRef &n) final { + void Visit(const ObjectRef &n) final { if (Failed()) return; IRVisitor::Visit(n); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 780e19bd017f..102e4c299774 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -256,7 +256,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; - std::unordered_set repeat_var; + std::unordered_set repeat_var; for (auto arg : func->params) { const auto &name = arg->name_hint(); if (name_dict.count(name)) { @@ -266,7 +266,7 @@ class RelayBuildModule : public runtime::ModuleNode { } } - std::unordered_map bind_dict; + std::unordered_map bind_dict; for (auto &kv : params) { if (name_dict.count(kv.first) == 0) { continue; diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7c33ac9ed61a..68a3bed3bc4b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -51,7 +51,7 @@ TVM_REGISTER_NODE_TYPE(CCacheValueNode); TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); CCacheKey CCacheKeyNode::make(Function source_func, Target target) { - auto n = make_node(); + auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); return CCacheKey(n); @@ -109,7 +109,7 @@ class ScheduleGetter : std::pair Create(const Function& prim_func) { static auto fschedule = Op::GetAttr("FTVMSchedule"); - auto cache_node = make_node(); + auto cache_node = make_object(); cache_node->target = target_; for (Var param : prim_func->params) { Array inputs; @@ -330,7 +330,7 @@ class ScheduleGetter : Attrs master_attrs_; int master_op_pattern_{0}; std::ostringstream readable_name_stream_; - std::unordered_map, NodeHash, NodeEqual> memo_; + std::unordered_map, ObjectHash, ObjectEqual> memo_; Array scalars_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. @@ -380,7 +380,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { param_shapes_[param] = shape_inputs; } readable_name_stream_ << "shape_func"; - auto cache_node = make_node(); + auto cache_node = make_object(); cache_node->outputs = VisitExpr(prim_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; @@ -574,13 +574,13 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { /*! \brief String stream for function name */ std::ostringstream readable_name_stream_; /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; + std::unordered_map param_states_; /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, NodeHash, NodeEqual> param_data_; + std::unordered_map, ObjectHash, ObjectEqual> param_data_; /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, NodeHash, NodeEqual> param_shapes_; + std::unordered_map, ObjectHash, ObjectEqual> param_shapes_; /*! \brief Memoized visit result */ - std::unordered_map, NodeHash, NodeEqual> memo_; + std::unordered_map, ObjectHash, ObjectEqual> memo_; /*! \brief Stack of data dependencies for shape function */ std::vector data_dependants_; /*! \brief Scalars used in the shape function */ @@ -656,9 +656,9 @@ class CompileEngineImpl : public CompileEngineNode { cache_.clear(); } // List all items in the cache. - Array ListItems() { + Array ListItems() { std::lock_guard lock(mutex_); - Array items; + Array items; for (auto& kv : cache_) { items.push_back(kv.first); items.push_back(kv.second); @@ -688,14 +688,14 @@ class CompileEngineImpl : public CompileEngineNode { if (it->second->cached_func.defined()) return it->second; value = it->second; } else { - value = CCacheValue(make_node()); + value = CCacheValue(make_object()); value->use_count = 0; cache_[key] = value; } // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. if (!key->source_func->UseDefaultCompiler()) { - auto cache_node = make_node(); + auto cache_node = make_object(); const auto name_node = FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); CHECK(name_node != nullptr) << "External function has not been attached a name yet."; @@ -709,7 +709,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); - auto cache_node = make_node( + auto cache_node = make_object( *(spair.second.operator->())); // Skip lowering for device copy node. @@ -749,7 +749,7 @@ class CompileEngineImpl : public CompileEngineNode { if (it->second->cached_func.defined()) return it->second; value = it->second; } else { - value = CCacheValue(make_node()); + value = CCacheValue(make_object()); value->use_count = 0; shape_func_cache_[key] = value; } @@ -758,7 +758,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_node( + auto cache_node = make_object( *(spair.second.operator->())); cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->target = key->target; @@ -811,7 +811,7 @@ const CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. static CompileEngine* inst = new CompileEngine( - make_node()); + make_object()); return *inst; } @@ -852,7 +852,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed(CompileEngine)>( +.set_body_typed(CompileEngine)>( [](CompileEngine self){ return static_cast(self.operator->())->ListItems(); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 596dfa7154f7..f6c38ba6b9a9 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -45,7 +45,7 @@ enum ShapeFuncParamState { }; /*! \brief Node container to represent a cached function. */ -struct CachedFuncNode : public Node { +struct CachedFuncNode : public Object { /* \brief compiled target */ tvm::Target target; /*! \brief Function name */ @@ -69,15 +69,17 @@ struct CachedFuncNode : public Node { } static constexpr const char* _type_key = "relay.CachedFunc"; - TVM_DECLARE_NODE_TYPE_INFO(CachedFuncNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object); }; -TVM_DEFINE_NODE_REF(CachedFunc, CachedFuncNode); - +class CachedFunc : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); +}; class CCacheKey; /*! \brief Compile cache key */ -class CCacheKeyNode : public Node { +class CCacheKeyNode : public Object { public: /*! \brief The source function to be lowered. */ Function source_func; @@ -106,7 +108,7 @@ class CCacheKeyNode : public Node { Target target); static constexpr const char* _type_key = "relay.CCacheKey"; - TVM_DECLARE_NODE_TYPE_INFO(CCacheKeyNode, tvm::Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object); private: /*! @@ -116,10 +118,10 @@ class CCacheKeyNode : public Node { }; /*! \brief cache entry used in compile engine */ -class CCacheKey : public NodeRef { +class CCacheKey : public ObjectRef { public: CCacheKey() {} - explicit CCacheKey(ObjectPtr n) : NodeRef(n) {} + explicit CCacheKey(ObjectPtr n) : ObjectRef(n) {} const CCacheKeyNode* operator->() const { return static_cast(get()); } @@ -132,7 +134,7 @@ class CCacheKey : public NodeRef { }; /*! \brief Node container for compile cache. */ -class CCacheValueNode : public Node { +class CCacheValueNode : public Object { public: /*! \brief The corresponding function */ CachedFunc cached_func; @@ -146,14 +148,14 @@ class CCacheValueNode : public Node { v->Visit("use_count", &use_count); } static constexpr const char* _type_key = "relay.CCacheValue"; - TVM_DECLARE_NODE_TYPE_INFO(CCacheValueNode, tvm::Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object); }; /*! \brief cache entry used in compile engine */ -class CCacheValue : public NodeRef { +class CCacheValue : public ObjectRef { public: CCacheValue() {} - explicit CCacheValue(ObjectPtr n) : NodeRef(n) {} + explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} CCacheValueNode* operator->() { return static_cast(get_mutable()); } @@ -167,7 +169,7 @@ class CCacheValue : public NodeRef { * \brief Backend compilation engine for * low level code generation. */ -class CompileEngineNode : public Node { +class CompileEngineNode : public Object { public: /*! * \brief Get lowered result. @@ -200,14 +202,14 @@ class CompileEngineNode : public Node { void VisitAttrs(AttrVisitor*) {} static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); }; /*! \brief cache entry used in compile engine */ -class CompileEngine : public NodeRef { +class CompileEngine : public ObjectRef { public: CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : NodeRef(n) {} + explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} CompileEngineNode* operator->() { return static_cast(get_mutable()); } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index cdaf813c44e4..84fada060744 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -152,7 +152,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << builder.JIT(); } - runtime::Module CreateCSourceModule(const NodeRef& ref) override { + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -170,7 +170,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { out[i] = a[i] p_OP_ b[i]; \ } \ } - + #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \ extern "C" void p_ID_(float* a, float* b, float* out) { \ for (int64_t i = 0; i < p_DIM1_; ++i) { \ @@ -214,7 +214,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { * CUDA, etc, under TVM, so the generated code could be packed in a runtime * module. This module simplifies code serialization and invocation. */ -runtime::Module CCompiler(const NodeRef& ref) { +runtime::Module CCompiler(const ObjectRef& ref) { CSourceCodegen csource; return csource.CreateCSourceModule(ref); } diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1319ca2ff787..d97f5dcd9103 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -49,7 +49,7 @@ class CSourceModuleCodegenBase { * * \return A runtime module. */ - virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0; + virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0; /*! * \brief Get the external symbol of the Relay function name. diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index e7f7bd6ff559..675198fcc9b3 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -254,7 +254,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { * * \return The runtime module that contains C source code. */ - runtime::Module CreateCSourceModule(const NodeRef& ref) override { + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -298,7 +298,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { * \brief The external compiler/codegen tool. It takes a Relay expression/module and * compile it into a runtime module. */ -runtime::Module DNNLCompiler(const NodeRef& ref) { +runtime::Module DNNLCompiler(const ObjectRef& ref) { DNNLModuleCodegen dnnl; return dnnl.CreateCSourceModule(ref); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index fc12cf66900f..5f210436f9b9 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -47,9 +47,9 @@ class GraphOpNode; using IntegerArray = Array; using ShapeVector = std::vector >; using GraphAttrs = std::unordered_map; -using GraphNodePtr = std::shared_ptr; -using GraphInputNodePtr = std::shared_ptr; -using GraphOpNodePtr = std::shared_ptr; +using GraphObjectPtr = std::shared_ptr; +using GraphInputObjectPtr = std::shared_ptr; +using GraphOpObjectPtr = std::shared_ptr; using TargetsMap = std::unordered_map; /*! \brief Lowered outputs */ @@ -255,7 +255,7 @@ class GraphRuntimeCodegen * \param expr * \return std::vector<_NodeRef> */ - std::vector AddNode(GraphNodePtr node, Expr expr) { + std::vector AddNode(GraphObjectPtr node, Expr expr) { auto checked_type = expr->checked_type(); size_t count = storage_device_map_.count(expr); CHECK_GT(count, 0) << "Expr is not existing in storage plan"; @@ -319,7 +319,7 @@ class GraphRuntimeCodegen } /*! \brief Visitors */ - std::unordered_map, NodeHash, NodeEqual> visitor_cache_; + std::unordered_map, ObjectHash, ObjectEqual> visitor_cache_; std::vector VisitExpr(const Expr& expr) override { if (visitor_cache_.count(expr)) return visitor_cache_.at(expr); @@ -587,13 +587,13 @@ class GraphRuntimeCodegen protected: /*! \brief nodes */ - std::vector nodes_; + std::vector nodes_; /*! \brief output of graph */ std::vector heads_; /*! \brief mod */ runtime::Module* mod_; /*! \brief variable map */ - std::unordered_map> var_map_; + std::unordered_map> var_map_; /*! \brief target device */ TargetsMap targets_; /*! \brief params */ @@ -601,7 +601,7 @@ class GraphRuntimeCodegen /*! \brief plan memory of device result */ Map> storage_device_map_; /*! \brief lowered funcs */ - std::unordered_map> + std::unordered_map> lowered_funcs_; /*! \brief name map */ std::unordered_map name_map_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index b5fd0c914b62..b4777845670a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -45,7 +45,7 @@ inline const PackedFunc& GetPackedFunc(const std::string& name) { /* Value Implementation */ Closure ClosureNode::make(tvm::Map env, Function func) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); return Closure(n); @@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) // TODO(@jroesch): this doesn't support mutual letrec /* Value Implementation */ RecClosure RecClosureNode::make(Closure clos, Var bind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->clos = std::move(clos); n->bind = std::move(bind); return RecClosure(n); @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleValue TupleValueNode::make(tvm::Array value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = value; return TupleValue(n); } @@ -95,7 +95,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TensorValue TensorValueNode::make(runtime::NDArray data) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); return TensorValue(n); } @@ -112,7 +112,7 @@ TVM_REGISTER_API("relay._make.TensorValue") .set_body_typed(TensorValueNode::make); RefValue RefValueNode::make(Value value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = value; return RefValue(n); } @@ -131,7 +131,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ConstructorValue ConstructorValueNode::make(int32_t tag, tvm::Array fields, Constructor constructor) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; n->constructor = constructor; @@ -204,7 +204,7 @@ struct Stack { class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ -class InterpreterStateNode : public Node { +class InterpreterStateNode : public Object { public: using Frame = tvm::Map; using Stack = tvm::Array; @@ -223,13 +223,16 @@ class InterpreterStateNode : public Node { static InterpreterState make(Expr current_expr, Stack stack); static constexpr const char* _type_key = "relay.InterpreterState"; - TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateNode, Object); }; -RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef); +class InterpreterState : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateNode); +}; InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); return InterpreterState(n); diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index 9bde3a0b4edd..e517fee3a4af 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") for (size_t i = 0; i < size; ++i) { tvm::runtime::NDArray temp; temp.Load(strm); - auto n = tvm::make_node(); + auto n = tvm::make_object(); n->name = std::move(names[i]); n->array = temp; ret.push_back(NamedNDArray(n)); diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index aa3c0244118f..e2d225aadd19 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -40,7 +40,7 @@ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; /*! * \brief Wrapper node for naming `NDArray`s. */ -struct NamedNDArrayNode : public ::tvm::Node { +struct NamedNDArrayNode : public ::tvm::Object { std::string name; tvm::runtime::NDArray array; @@ -50,11 +50,13 @@ struct NamedNDArrayNode : public ::tvm::Node { } static constexpr const char* _type_key = "NamedNDArray"; - TVM_DECLARE_NODE_TYPE_INFO(NamedNDArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(NamedNDArrayNode, Object); }; -TVM_DEFINE_NODE_REF(NamedNDArray, NamedNDArrayNode); - +class NamedNDArray : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(NamedNDArray, ObjectRef, NamedNDArrayNode); +}; } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0de47bda0bbc..af425a4966d0 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -112,7 +112,7 @@ struct ConditionNode { virtual ~ConditionNode() {} }; -using ConditionNodePtr = std::shared_ptr; +using ConditionObjectPtr = std::shared_ptr; /*! * \brief A var binding condition @@ -144,15 +144,15 @@ struct TagCompare : ConditionNode { ~TagCompare() {} }; -using TreeNodePtr = typename relay::TreeNode::pointer; -using TreeLeafNode = relay::TreeLeafNode; -using TreeLeafFatalNode = relay::TreeLeafFatalNode; -using TreeBranchNode = relay::TreeBranchNode; +using TreeObjectPtr = typename relay::TreeNode::pointer; +using TreeLeafNode = relay::TreeLeafNode; +using TreeLeafFatalNode = relay::TreeLeafFatalNode; +using TreeBranchNode = relay::TreeBranchNode; -TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, +TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, - TreeNodePtr then_branch, - TreeNodePtr else_branch) { + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { if (pattern.as()) { // We ignore wildcard binding since it's not producing new vars return then_branch; @@ -185,16 +185,16 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, } } -TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data, +TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, - TreeNodePtr else_branch) { + TreeObjectPtr else_branch) { return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), else_branch); } -TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { +TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { // When nothing matches, the VM throws fatal error - TreeNodePtr else_branch = TreeLeafFatalNode::Make(); + TreeObjectPtr else_branch = TreeLeafFatalNode::Make(); // Start from the last clause for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) { else_branch = BuildDecisionTreeFromClause(data, *it, else_branch); @@ -674,7 +674,7 @@ class VMFunctionCompiler : ExprFunctor { } } - void CompileTreeNode(TreeNodePtr tree) { + void CompileTreeNode(TreeObjectPtr tree) { if (std::dynamic_pointer_cast(tree)) { auto node = std::dynamic_pointer_cast(tree); VisitExpr(node->body); @@ -731,13 +731,13 @@ class VMFunctionCompiler : ExprFunctor { protected: /*! \brief Store the expression a variable points to. */ - std::unordered_map expr_map_; + std::unordered_map expr_map_; /*! \brief Instructions in the VMFunction. */ std::vector instructions_; /*! \brief Parameter names of the function. */ std::vector params_; /*! \brief Map from var to register number. */ - std::unordered_map var_register_map_; + std::unordered_map var_register_map_; /*! \brief Last used register number. */ size_t last_register_; /*! \brief Total number of virtual registers allocated. */ @@ -786,7 +786,7 @@ relay::Function VMCompiler::BindParamsByName( relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; - std::unordered_set repeat_var; + std::unordered_set repeat_var; for (auto arg : func->params) { const auto &name = arg->name_hint(); if (name_dict.count(name)) { @@ -795,7 +795,7 @@ relay::Function VMCompiler::BindParamsByName( name_dict[name] = arg; } } - std::unordered_map bind_dict; + std::unordered_map bind_dict; for (auto &kv : params) { if (name_dict.count(kv.first) == 0) { continue; diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 8cdb12e4dafa..2beab1536a18 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -52,7 +52,7 @@ using namespace tvm::runtime::vm; using namespace relay::transform; template -using NodeMap = std::unordered_map; +using NodeMap = std::unordered_map; using TagMap = NodeMap; using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; @@ -76,7 +76,7 @@ struct VMCompilerContext { // List of cached functions std::vector cached_funcs; // The functions that have been lowered. - std::unordered_map seen_funcs; + std::unordered_map seen_funcs; }; diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 3bb1458b0758..f94f837ef550 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -53,7 +53,7 @@ namespace vm { */ struct PrimitiveInliner : ExprMutator { Module module_; - std::unordered_map var_map; + std::unordered_map var_map; explicit PrimitiveInliner(const Module& module) : module_(module) {} diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ab9dc8cbec63..7298c50e6f1f 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -43,7 +43,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - NodeRef res = FunctionGetAttr(func, attr::kClosure); + ObjectRef res = FunctionGetAttr(func, attr::kClosure); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } @@ -200,7 +200,7 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map lambda_map_; + std::unordered_map lambda_map_; std::vector letrec_; Module module_; }; diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index ee44e26fdfa1..546f1d30cb41 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -46,7 +46,7 @@ struct CallTracer : ExprVisitor { std::unordered_set called_funcs_; // Record the expressions that are being visited - std::unordered_set visiting_; + std::unordered_set visiting_; explicit CallTracer(const Module& module) : module_{module}, @@ -96,7 +96,7 @@ struct CallTracer : ExprVisitor { * * \param module The Relay module. * \param entry_funcs The set of functions that can be entry function. - * + * * \return The module with dead functions removed. */ Module RemoveUnusedFunctions(const Module& module, diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 1f51ecc84fdc..73172879d393 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -28,7 +28,7 @@ namespace tvm { namespace relay { PatternWildcard PatternWildcardNode::make() { - NodePtr n = make_node(); + ObjectPtr n = make_object(); return PatternWildcard(n); } @@ -43,7 +43,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); PatternVar PatternVarNode::make(tvm::relay::Var var) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = std::move(var); return PatternVar(n); } @@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) PatternConstructor PatternConstructorNode::make(Constructor constructor, tvm::Array patterns) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); return PatternConstructor(n); @@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); PatternTuple PatternTupleNode::make(tvm::Array patterns) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->patterns = std::move(patterns); return PatternTuple(n); } @@ -99,7 +99,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Constructor ConstructorNode::make(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); n->belong_to = std::move(belong_to); @@ -121,7 +121,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TypeData TypeDataNode::make(GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->header = std::move(header); n->type_vars = std::move(type_vars); n->constructors = std::move(constructors); @@ -141,7 +141,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Clause ClauseNode::make(Pattern lhs, Expr rhs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->lhs = std::move(lhs); n->rhs = std::move(rhs); return Clause(n); @@ -160,7 +160,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Match MatchNode::make(Expr data, tvm::Array clauses, bool complete) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); n->clauses = std::move(clauses); n->complete = complete; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index df91f794f6d1..589de09b0b81 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -49,7 +49,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The comparison result. */ - bool Equal(const NodeRef& lhs, const NodeRef& rhs) { + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; if (lhs->IsInstance()) { @@ -88,7 +88,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The comparison result. */ - bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { + bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) { auto compute = [&]() { if (&lhs == &rhs) return true; if (auto lhsd = lhs.as()) { @@ -127,7 +127,7 @@ class AlphaEqualHandler: return Compare(compute(), lhs, rhs); } - bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) { + bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_) { CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true); } @@ -180,7 +180,7 @@ class AlphaEqualHandler: * \param rhs The right hand operand. * \return The compare result. */ - bool LeafNodeEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; auto it = equal_map_.find(lhs); if (it != equal_map_.end()) { @@ -197,7 +197,7 @@ class AlphaEqualHandler: } using AttrsEqualHandler::VisitAttr_; bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } // Type equality @@ -211,13 +211,13 @@ class AlphaEqualHandler: } bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } bool VisitType_(const TypeVarNode* lhs, const Type& other) final { if (const TypeVarNode* rhs = other.as()) { if (lhs->kind != rhs->kind) return false; - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } else { return false; } @@ -290,7 +290,7 @@ class AlphaEqualHandler: } bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final { - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } bool VisitType_(const TypeCallNode* lhs, const Type& other) final { @@ -366,7 +366,7 @@ class AlphaEqualHandler: if (const VarNode* rhs = other.as()) { if (lhs->name_hint() != rhs->name_hint()) return false; if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; - return LeafNodeEqual(GetRef(lhs), other); + return LeafObjectEqual(GetRef(lhs), other); } else { return false; } @@ -600,23 +600,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { // TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._make._alpha_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(false, false).Equal(a, b); }); TVM_REGISTER_API("relay._make._assert_alpha_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; }); TVM_REGISTER_API("relay._make._graph_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { return AlphaEqualHandler(true, false).Equal(a, b); }); TVM_REGISTER_API("relay._make._assert_graph_equal") -.set_body_typed([](NodeRef a, NodeRef b) { +.set_body_typed([](ObjectRef a, ObjectRef b) { bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; }); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 3bc916d9a406..ca8755730d80 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,11 +33,11 @@ using namespace tvm::runtime; ObjectPtr GetSourceNameNode(const std::string& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; + static std::unordered_map > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); source_map[name] = n; n->name = std::move(name); return n; @@ -66,7 +66,7 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) }); Span SpanNode::make(SourceName source, int lineno, int col_offset) { - auto n = make_node(); + auto n = make_object(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; @@ -88,7 +88,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_API("relay._base.set_span") -.set_body_typed([](NodeRef node_ref, Span sp) { +.set_body_typed([](ObjectRef node_ref, Span sp) { auto rn = node_ref.as(); CHECK(rn); rn->span = sp; diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index 33273f972ea8..7c47c7441dbb 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -37,7 +37,7 @@ void RelayErrorStream::Raise() const { } template -using NodeMap = std::unordered_map; +using NodeMap = std::unordered_map; void ErrorReporter::RenderErrors(const Module& module, bool use_color) { // First we pick an error reporting strategy for each error. @@ -46,7 +46,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; } - NodeMap> error_maps; + NodeMap> error_maps; // Set control mode in order to produce colors; if (use_color) { @@ -132,7 +132,7 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) { LOG(FATAL) << annotated_prog.str() << std::endl; } -void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) { +void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err) { size_t index_to_insert = this->errors_.size(); this->errors_.push_back(err); auto it = this->node_to_error_.find(node); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index cae35895dbbf..66e083d498cb 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -30,7 +30,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; Constant ConstantNode::make(runtime::NDArray data) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); return Constant(n); } @@ -63,7 +63,7 @@ TensorType ConstantNode::tensor_type() const { } Tuple TupleNode::make(tvm::Array fields) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = std::move(fields); return Tuple(n); } @@ -81,14 +81,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Var VarNode::make(Id vid, Type type_annotation) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); return Var(n); } Var VarNode::make(std::string name_hint, Type type_annotation) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); return VarNode::make(Id(n), type_annotation); } @@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); GlobalVar GlobalVarNode::make(std::string name_hint) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); return GlobalVar(n); } @@ -132,7 +132,7 @@ Function FunctionNode::make(tvm::Array params, Type ret_type, tvm::Array type_params, tvm::Attrs attrs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); n->params = std::move(params); @@ -157,7 +157,7 @@ FuncType FunctionNode::func_type_annotation() const { } bool FunctionNode::IsPrimitive() const { - NodeRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); + ObjectRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } @@ -183,13 +183,13 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") }); bool FunctionNode::UseDefaultCompiler() const { - NodeRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); + ObjectRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); const ir::StringImm* pval = res.as(); return pval == nullptr || pval->value == "default"; } -NodeRef FunctionGetAttr(const Function& func, const std::string& key) { - if (!func->attrs.defined()) { return NodeRef(); } +ObjectRef FunctionGetAttr(const Function& func, const std::string& key) { + if (!func->attrs.defined()) { return ObjectRef(); } const DictAttrsNode* dict_attrs = func->attrs.as(); CHECK(dict_attrs); @@ -197,19 +197,19 @@ NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (it != dict_attrs->dict.end()) { return (*it).second; } else { - return NodeRef(); + return ObjectRef(); } } -Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) { +Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) { const DictAttrsNode* dattrs = func->attrs.as(); Attrs func_attrs; if (dattrs) { - Map dict = dattrs->dict; + Map dict = dattrs->dict; dict.Set(key, data); func_attrs = DictAttrsNode::make(dict); } else { - Map dict = {{key, data}}; + Map dict = {{key, data}}; func_attrs = DictAttrsNode::make(dict); } @@ -236,7 +236,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) Call CallNode::make(Expr op, Array args, Attrs attrs, Array type_args) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->op = std::move(op); n->args = std::move(args); n->attrs = std::move(attrs); @@ -257,7 +257,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); Let LetNode::make(Var var, Expr value, Expr body) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); @@ -277,7 +277,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); @@ -297,7 +297,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; return TupleGetItem(n); @@ -315,7 +315,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefCreate RefCreateNode::make(Expr value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = std::move(value); return RefCreate(n); } @@ -332,7 +332,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefRead RefReadNode::make(Expr ref) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->ref = std::move(ref); return RefRead(n); } @@ -349,7 +349,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefWrite RefWriteNode::make(Expr ref, Expr value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); return RefWrite(n); @@ -372,8 +372,8 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") }); TVM_REGISTER_API("relay._expr.FunctionSetAttr") -.set_body_typed( - [](Function func, std::string name, NodeRef ref) { +.set_body_typed( + [](Function func, std::string name, ObjectRef ref) { return FunctionSetAttr(func, name, ref); }); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index ac45d61e873d..e3846c93d49a 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -340,7 +340,7 @@ class ExprApplyVisit : public ExprVisitor { private: std::function f_; - std::unordered_set visited_; + std::unordered_set visited_; }; void PostOrderVisit(const Expr& e, std::function fvisit) { @@ -422,7 +422,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { func->ret_type, func->type_params, func->attrs); - std::unordered_set set; + std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); } @@ -445,7 +445,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { TVM_REGISTER_API("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef input = args[0]; + ObjectRef input = args[0]; if (input->IsInstance()) { *ret = Bind(Downcast(input), args[1]); } else { diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index f37b1a4c10be..15f5105808aa 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -47,8 +47,8 @@ class RelayHashHandler: * \param ref The node to hash. * \return the hash value. */ - size_t Hash(const NodeRef& ref) { - if (!ref.defined()) return NodeHash()(ref); + size_t Hash(const ObjectRef& ref) { + if (!ref.defined()) return ObjectHash()(ref); if (ref->IsInstance()) { return TypeHash(Downcast(ref)); @@ -64,9 +64,9 @@ class RelayHashHandler: * \param ref The attributes. * \return the hash value */ - size_t AttrHash(const NodeRef& ref) { + size_t AttrHash(const ObjectRef& ref) { if (!ref.defined()) { - return NodeHash()(ref); + return ObjectHash()(ref); } return AttrsHashHandler::Hash(ref); } @@ -78,7 +78,7 @@ class RelayHashHandler: */ size_t TypeHash(const Type& type) { if (!type.defined()) { - return NodeHash()(type); + return ObjectHash()(type); } auto found = hash_map_.find(type); if (found != hash_map_.end()) { @@ -102,7 +102,7 @@ class RelayHashHandler: */ size_t ExprHash(const Expr& expr) { if (!expr.defined()) { - return NodeHash()(expr); + return ObjectHash()(expr); } auto found = hash_map_.find(expr); if (found != hash_map_.end()) { @@ -221,7 +221,7 @@ class RelayHashHandler: return hash; } - size_t BindVar(const NodeRef& var) { + size_t BindVar(const ObjectRef& var) { size_t hash = std::hash()(var_counter++); CHECK_EQ(hash_map_.count(var), 0); if (auto var_node = var.as()) { @@ -238,7 +238,7 @@ class RelayHashHandler: size_t VisitExpr_(const VarNode* var) final { // hash free variable - size_t name_hash = std::hash()(var->vid.get()); + size_t name_hash = std::hash()(var->vid.get()); return Combine(name_hash, TypeHash(var->type_annotation)); } @@ -308,7 +308,7 @@ class RelayHashHandler: } size_t VisitExpr_(const OpNode* op) final { - return NodeHash()(GetRef(op)); + return ObjectHash()(GetRef(op)); } size_t VisitExpr_(const ConstantNode* rconst) final { @@ -416,7 +416,7 @@ class RelayHashHandler: } private: // renaming of NodeRef to indicate two nodes equals to each other - std::unordered_map hash_map_; + std::unordered_map hash_map_; int var_counter = 0; }; @@ -429,7 +429,7 @@ size_t StructuralHash::operator()(const Expr& expr) const { } TVM_REGISTER_API("relay._analysis._expr_hash") -.set_body_typed([](NodeRef ref) { +.set_body_typed([](ObjectRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 3bd8d59aaf49..2fa79c7b6322 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -38,7 +38,7 @@ Module ModuleNode::make(tvm::Map global_funcs, tvm::Map global_type_defs, std::unordered_set imports ) { - auto n = make_node(); + auto n = make_object(); n->functions = std::move(global_funcs); n->type_definitions = std::move(global_type_defs); n->global_type_var_map_ = {}; @@ -327,14 +327,14 @@ TVM_REGISTER_API("relay._module.Module_Add") .set_body([](TVMArgs args, TVMRetValue* ret) { Module mod = args[0]; GlobalVar var = args[1]; - NodeRef val = args[2]; + ObjectRef val = args[2]; bool update = args[3]; CHECK(val->IsInstance()); if (val->IsInstance()) { mod->Add(var, Downcast(val), update); } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); - auto mod_copy = Module(make_node(*mod.operator->())); + auto mod_copy = Module(make_object(*mod.operator->())); mod_copy = transform::EtaExpand( /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); auto func = mod_copy->Lookup(gv->name_hint); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 7b5217d4c066..05788b1a78b5 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -67,7 +67,7 @@ const Op& Op::Get(const std::string& name) { OpRegistry::OpRegistry() { OpManager* mgr = OpManager::Global(); - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->index_ = mgr->op_counter++; op_ = Op(n); } @@ -205,17 +205,17 @@ TVM_REGISTER_API("relay.op._Register") }); // helper to get internal dev function in objectref. -struct Op2NodePtr : public ObjectRef { - static NodePtr Get(const Op& op) { - return GetDataPtr(op); +struct Op2ObjectPtr : public ObjectRef { + static ObjectPtr Get(const Op& op) { + return GetDataPtr(op); } }; -NodePtr CreateOp(const std::string& name) { +ObjectPtr CreateOp(const std::string& name) { // Hack use TVMRetValue as exchange auto op = Op::Get(name); CHECK(op.defined()) << "Cannot find op \'" << name << '\''; - return Op2NodePtr::Get(op); + return Op2ObjectPtr::Get(op); } TVM_REGISTER_NODE_TYPE(OpNode) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 597ef4abee4f..478469c586ef 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -116,14 +116,14 @@ class TextMetaDataContext { * \param node The node to be converted to meta node. * \return A string representation of the meta node. */ - Doc GetMetaNode(const NodeRef& node) { + Doc GetMetaNode(const ObjectRef& node) { auto it = meta_repr_.find(node); if (it != meta_repr_.end()) { return it->second; } std::string type_key = node->GetTypeKey(); CHECK(!type_key.empty()); - Array& mvector = + Array& mvector = meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); @@ -143,7 +143,7 @@ class TextMetaDataContext { */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); + return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } /*! \return whether the meta data context is empty. */ @@ -153,9 +153,9 @@ class TextMetaDataContext { private: /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; + std::unordered_map > meta_data_; /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; + std::unordered_map meta_repr_; }; class PrettyPrinter : @@ -191,7 +191,7 @@ class PrettyPrinter : } // indent a new body - Doc PrintBody(const NodeRef& node, int indent = 2) { + Doc PrintBody(const ObjectRef& node, int indent = 2) { Doc doc; Doc body; doc << "{"; @@ -202,7 +202,7 @@ class PrettyPrinter : // create a new scope by creating a new printer object. This allows temp var // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const NodeRef& node) { + Doc PrintScope(const ObjectRef& node) { // print in a new scope doc_stack_.push_back(Doc()); // must print first so doc_stack_.back() reference doesn't become stale @@ -212,7 +212,7 @@ class PrettyPrinter : return doc; } - Doc PrintFinal(const NodeRef& node) { + Doc PrintFinal(const ObjectRef& node) { if (node.as()) { Expr expr = Downcast(node); dg_ = DependencyGraph::Create(&arena_, expr); @@ -235,7 +235,7 @@ class PrettyPrinter : std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); std::vector PrintFuncAttrs(const Attrs& attrs); - Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) { + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { if (node.as()) { return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as()) { @@ -383,7 +383,7 @@ class PrettyPrinter : Doc printed_expr; if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + printed_expr = meta_.GetMetaNode(GetRef(expr.get())); } else if (!inline_expr && expr.as()) { // wrap GNFed let in brackets Doc body; @@ -440,7 +440,7 @@ class PrettyPrinter : } // default fall-back, record it as meta node. Doc doc; - return doc << Print(GetRef(op), true); + return doc << Print(GetRef(op), true); } Doc VisitExpr_(const TupleNode* op) final { @@ -624,7 +624,7 @@ class PrettyPrinter : if (it != memo_pattern_.end()) return it->second; Doc printed_pattern; if (meta) { - printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); + printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); } else { printed_pattern = VisitPattern(pattern); } @@ -687,7 +687,7 @@ class PrettyPrinter : if (it != memo_type_.end()) return it->second; Doc printed_type; if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); + printed_type = meta_.GetMetaNode(GetRef(type.get())); } else { printed_type = VisitType(type); } @@ -695,9 +695,9 @@ class PrettyPrinter : return printed_type; } - Doc VisitTypeDefault_(const Node* node) final { + Doc VisitTypeDefault_(const Object* node) final { // by default always print as meta data - return Print(GetRef(node), true); + return Print(GetRef(node), true); } Doc VisitType_(const TypeVarNode* node) final { @@ -728,7 +728,7 @@ class PrettyPrinter : Doc doc; doc << "Tensor[("; std::vector shapes; - for (NodeRef shape : node->shape) { + for (ObjectRef shape : node->shape) { shapes.push_back(PrintAttr(shape)); } doc << PrintSep(shapes); @@ -816,7 +816,7 @@ class PrettyPrinter : if (value.as()) { printed_attr << "?"; } else if (meta) { - printed_attr = meta_.GetMetaNode(Downcast(value)); + printed_attr = meta_.GetMetaNode(Downcast(value)); } else { printed_attr = VisitAttr(value); } @@ -866,11 +866,11 @@ class PrettyPrinter : /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; + std::unordered_map memo_; /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; + std::unordered_map memo_type_; /*! \brief Map from Type to Doc */ - std::unordered_map memo_pattern_; + std::unordered_map memo_pattern_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief meta data context */ @@ -969,7 +969,7 @@ std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { return docs; } -std::string PrettyPrint_(const NodeRef& node, +std::string PrettyPrint_(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; @@ -978,20 +978,20 @@ std::string PrettyPrint_(const NodeRef& node, return doc.str(); } -std::string PrettyPrint(const NodeRef& node) { +std::string PrettyPrint(const ObjectRef& node) { Doc doc; doc << PrettyPrinter(false, runtime::TypedPackedFunc()).PrintFinal(node); return doc.str(); } -std::string AsText(const NodeRef& node, +std::string AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { return PrettyPrint_(node, show_meta_data, annotate); } TVM_REGISTER_API("relay._expr.AsText") -.set_body_typed)>(AsText); diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 94e9883d4e41..70071d0445aa 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -30,7 +30,7 @@ using tvm::IRPrinter; using namespace tvm::runtime; TensorType TensorTypeNode::make(Array shape, DataType dtype) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->shape = std::move(shape); n->dtype = std::move(dtype); return TensorType(n); @@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TypeVar TypeVarNode::make(std::string name, Kind kind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = tvm::Var(name); n->kind = std::move(kind); return TypeVar(n); @@ -85,7 +85,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->var = tvm::Var(name); n->kind = std::move(kind); return GlobalTypeVar(n); @@ -106,7 +106,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TypeCall TypeCallNode::make(Type func, tvm::Array args) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); return TypeCall(n); @@ -125,7 +125,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); IncompleteType IncompleteTypeNode::make(Kind kind) { - auto n = make_node(); + auto n = make_object(); n->kind = std::move(kind); return IncompleteType(n); } @@ -147,7 +147,7 @@ FuncType FuncTypeNode::make(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); @@ -172,7 +172,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); n->num_inputs = num_inputs; @@ -194,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TupleType TupleTypeNode::make(Array fields) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->fields = std::move(fields); return TupleType(n); } @@ -211,7 +211,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); RefType RefTypeNode::make(Type value) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->value = std::move(value); return RefType(n); } diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 67c139185ebf..09049cf83f86 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -93,7 +93,7 @@ class TypeFunctor { virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; - virtual R VisitTypeDefault_(const Node* op, Args...) { + virtual R VisitTypeDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning } diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index b64d656b66a0..7a58cfd258a9 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -51,7 +51,7 @@ Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->is_ascend = is_ascend; attrs->dtype = dtype; diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index ecb3f7d3be05..055d65bf3252 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -72,7 +72,7 @@ Expr MakeTopK(Expr data, std::string ret_type, bool is_ascend, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->k = k; attrs->axis = axis; attrs->ret_type = ret_type; diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 6835525c3585..9234591659c5 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -41,7 +41,7 @@ TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_API("relay.op.annotation._make.on_device") .set_body_typed([](Expr data, int device_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->device_type = device_type; static const Op& op = Op::Get("on_device"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -87,7 +87,7 @@ TVM_ADD_FILELINE) TVM_REGISTER_NODE_TYPE(CastHintAttrs); Expr CastHint(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("annotation.cast_hint"); return CallNode::make(op, {data}, Attrs{attrs}, {}); diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index f7f800fccb10..f592d3ed3f74 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -55,7 +55,7 @@ RELAY_REGISTER_OP("debug") .set_attr("FTVMCompute", DebugCompute); Expr MakeDebug(Expr expr, std::string name) { - auto dattrs = make_node(); + auto dattrs = make_object(); if (name.size() > 0) { dattrs->debug_func = EnvFunc::Get(name); } else { diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 3b997a273fa5..290ccef06d99 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -44,7 +44,7 @@ TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_API("relay.op._make.device_copy") .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index a65312316076..f6329f7af709 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -73,7 +73,7 @@ Expr MakeResize(Expr data, std::string method, bool align_corners, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c535d76838c8..72edeac05399 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -43,7 +43,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); // being able to see the arguments as well? TVM_REGISTER_API("relay.op.memory._make.alloc_storage") .set_body_typed([](Expr size, Expr alignment, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("memory.alloc_storage"); return CallNode::make(op, {size, alignment}, Attrs(attrs), {}); @@ -90,7 +90,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") TVM_REGISTER_API("relay.op.memory._make.alloc_tensor") .set_body_typed assert_shape)>( [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; if (assert_shape.defined()) { attrs->assert_shape = assert_shape; @@ -260,7 +260,7 @@ TVM_REGISTER_API("relay.op.memory._make.shape_func") .set_body_typed)>( [](Expr func, Expr inputs, Expr outputs, Array is_input) { static const Op& op = Op::Get("memory.shape_func"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->is_input = is_input; return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {}); }); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d651baeccb4c..973ee0b3fe05 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -86,7 +86,7 @@ bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, std::string name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->bits = bits; attrs->pack_axis = pack_axis; attrs->bit_axis = bit_axis; @@ -151,7 +151,7 @@ Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array kernel_size, int activation_bits, int weight_bits, std::string data_layout, std::string kernel_layout, DataType pack_dtype, DataType out_dtype, bool unipolar) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->channels = std::move(channels); @@ -224,7 +224,7 @@ bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs // Positional relay function to create bitserial dense operator used by frontend FFI. Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits, DataType pack_dtype, DataType out_dtype, bool unipolar) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->data_bits = data_bits; attrs->weight_bits = weight_bits; diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 4a1fd466108d..40c24462c8f7 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -66,7 +66,7 @@ Expr MakeConv2D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -124,7 +124,7 @@ Expr MakeConv3D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -289,7 +289,7 @@ Expr MakeConv2DTranspose(Expr data, std::string out_layout, Array output_padding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->channels = std::move(channels); attrs->kernel_size = std::move(kernel_size); attrs->strides = std::move(strides); @@ -448,7 +448,7 @@ Expr MakeConv1DTranspose(Expr data, std::string out_layout, Array output_padding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->channels = std::move(channels); attrs->kernel_size = std::move(kernel_size); attrs->strides = std::move(strides); @@ -595,7 +595,7 @@ Expr MakeConv2DWinograd(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->tile_size = tile_size; attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -668,7 +668,7 @@ bool Conv2DWinogradWeightTransformRel(const Array& types, Expr MakeConv2DWinogradWeightTransform(Expr weight, int tile_size) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->tile_size = tile_size; static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform"); return CallNode::make(op, {weight}, Attrs(attrs), {}); @@ -708,7 +708,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -783,7 +783,7 @@ bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->convolution_algorithm = convolution_algorithm; attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform"); @@ -821,7 +821,7 @@ Expr MakeConv2DNCHWcInt8(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -870,7 +870,7 @@ Expr MakeConv2DNCHWc(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -920,7 +920,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -1079,7 +1079,7 @@ Expr MakeDeformableConv2D(Expr data, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = strides; attrs->padding = padding; attrs->dilation = dilation; diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index dfb360a2dec0..79c3e687db36 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -72,7 +72,7 @@ bool BiasAddRel(const Array& types, Expr MakeBiasAdd(Expr data, Expr bias, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return CallNode::make(op, {data, bias}, Attrs(attrs), {}); @@ -104,7 +104,7 @@ RELAY_REGISTER_OP("nn.bias_add") TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.fifo_buffer"); return CallNode::make(op, {input, buffer}, Attrs(attrs), {}); @@ -175,7 +175,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); @@ -208,7 +208,7 @@ TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. Expr MakeLeakyRelu(Expr data, double alpha) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -288,7 +288,7 @@ Array > PReluInferCorrectLayout( Expr MakePRelu(Expr data, Expr alpha, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); return CallNode::make(op, {data, alpha}, Attrs(attrs), {}); @@ -327,7 +327,7 @@ TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_API("relay.op.nn._make.softmax") .set_body_typed([](Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -362,7 +362,7 @@ RELAY_REGISTER_OP("nn.softmax") // relay.nn.log_softmax TVM_REGISTER_API("relay.op.nn._make.log_softmax") .set_body_typed([](Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -504,7 +504,7 @@ Expr MakeLRN(Expr data, double alpha, double beta, double bias) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->size = size; attrs->axis = axis; attrs->alpha = alpha; @@ -545,7 +545,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); Expr MakeL2Normalize(Expr data, double eps, Array axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->eps = eps; attrs->axis = std::move(axis); static const Op& op = Op::Get("nn.l2_normalize"); @@ -591,7 +591,7 @@ bool DropoutRel(const Array& types, } Expr MakeDropout(Expr data, double rate) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->rate = rate; static const Op& op = Op::Get("nn.dropout"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -680,7 +680,7 @@ bool BatchNormRel(const Array& types, Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -763,7 +763,7 @@ bool InstanceNormRel(const Array& types, Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -833,7 +833,7 @@ bool LayerNormRel(const Array& types, Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, bool scale) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; @@ -1024,7 +1024,7 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr // Positional relay function to create DepthToSpace operator // used by frontend FFI Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); attrs->mode = std::move(mode); @@ -1082,7 +1082,7 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr // Positional relay function to create SpaceToDepth operator // used by frontend FFI Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.space_to_depth"); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 519619f8812a..5cde41446fe6 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -192,7 +192,7 @@ Expr MakePad(Expr data, Array > pad_width, double pad_value, std::string pad_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); @@ -267,7 +267,7 @@ bool MirrorPadRel(const Array& types, // Handler to create a call to the padding op used by front-end FFI Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->mode = mode; attrs->pad_width = std::move(pad_width); static const Op& op = Op::Get("nn.mirror_pad"); diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index e7529a9d7bb9..00216900e2b5 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -63,7 +63,7 @@ Expr MakeMaxPool(Expr data, std::string layout, bool ceil_mode, std::string op_name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -82,7 +82,7 @@ Expr MakeAvgPool(Expr data, bool ceil_mode, bool count_include_pad, std::string op_name) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -359,7 +359,7 @@ Array GlobalPool2DCompute(const Attrs& attrs, Expr MakeGlobalAvgPool2D(Expr data, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -391,7 +391,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") // GlobalMaxPool Expr MakeGlobalMaxPool2D(Expr data, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -511,7 +511,7 @@ Array AdaptivePool2DCompute(const Attrs& attrs, Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("contrib.adaptive_avg_pool2d"); @@ -550,7 +550,7 @@ RELAY_REGISTER_OP("contrib.adaptive_avg_pool2d") Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); static const Op& op = Op::Get("contrib.adaptive_max_pool2d"); @@ -647,7 +647,7 @@ Array Pool2DGradCompute(const Attrs& attrs, const Array& inputs, // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -695,7 +695,7 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode, bool count_include_pad) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 7cf8a27f3b56..fc22725977f2 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -65,7 +65,7 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs // Positional relay function to create dense operator used by frontend FFI. Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) { - auto attrs = make_node(); + auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_dense"); return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); } @@ -114,7 +114,7 @@ bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& a } Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) { - auto attrs = make_node(); + auto attrs = make_object(); static const Op& op = Op::Get("nn.sparse_transpose"); return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 61b40588c3d7..2ba7b2f7bcf4 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -102,7 +102,7 @@ Expr MakeUpSampling(Expr data, std::string layout, std::string method, bool align_corners) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale_h = scale_h; @@ -182,7 +182,7 @@ Expr MakeUpSampling3D(Expr data, std::string layout, std::string method, std::string coordinate_transformation_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->scale_d = scale_d; diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 42c1fc485a63..53495ccff15d 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -145,7 +145,7 @@ class OpMatch { private: /*! \brief The match function map. */ - std::unordered_map match_map_; + std::unordered_map match_map_; /*! \brief An optional default case. */ MatchFunc default_; }; diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 7e499bae7683..4e9a900c7cd6 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -308,7 +308,7 @@ bool ReduceRel(const Array& types, Array axis, \ bool keepdims, \ bool exclude) { \ - auto attrs = make_node(); \ + auto attrs = make_object(); \ attrs->axis = std::move(axis); \ attrs->keepdims = keepdims; \ attrs->exclude = exclude; \ @@ -625,7 +625,7 @@ Expr MakeVariance(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ff018e43aea7..7407f21e8e9a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -77,7 +77,7 @@ Array CastCompute(const Attrs& attrs, Expr MakeCast(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -165,7 +165,7 @@ Array ReinterpretCompute(const Attrs& attrs, const Array& inputs } Expr MakeReinterpret(Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("reinterpret"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -242,7 +242,7 @@ Array ExpandDimsCompute(const Attrs& attrs, Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->num_newaxis = num_newaxis; static const Op& op = Op::Get("expand_dims"); @@ -328,7 +328,7 @@ Array> ConcatenateLayout( Expr MakeConcatenate(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -423,7 +423,7 @@ Array StackCompute(const Attrs& attrs, Expr MakeStack(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -515,7 +515,7 @@ Array TransposeCompute(const Attrs& attrs, Expr MakeTranspose(Expr data, Array axes) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -706,7 +706,7 @@ Array ReshapeCompute(const Attrs& attrs, Expr MakeReshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); @@ -860,7 +860,7 @@ bool ArgWhereRel(const Array& types, TVM_REGISTER_API("relay.op._make.argwhere") .set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); - auto attrs = make_node(); + auto attrs = make_object(); return CallNode::make(op, {data}, Attrs(attrs), {}); }); @@ -938,7 +938,7 @@ Expr MakeTake(Expr data, Expr indices, Integer axis, std::string mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); static const Op& op = Op::Get("take"); @@ -1019,7 +1019,7 @@ Array FullCompute(const Attrs& attrs, Expr MakeFull(Expr fill_value, Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); @@ -1054,7 +1054,7 @@ bool InitOpRel(const Array& types, Expr MakeZeros(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); @@ -1075,7 +1075,7 @@ RELAY_REGISTER_OP("zeros") Expr MakeOnes(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("ones"); @@ -1244,7 +1244,7 @@ Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->start = start; attrs->stop = stop; attrs->step = step; @@ -1335,7 +1335,7 @@ Array RepeatCompute(const Attrs& attrs, Expr MakeRepeat(Expr data, int repeats, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->repeats = repeats; attrs->axis = axis; static const Op& op = Op::Get("repeat"); @@ -1445,7 +1445,7 @@ Array TileCompute(const Attrs& attrs, Expr MakeTile(Expr data, Array reps) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1506,7 +1506,7 @@ Array ReverseCompute(const Attrs& attrs, Expr MakeReverse(Expr data, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1623,7 +1623,7 @@ TVM_REGISTER_NODE_TYPE(SqueezeAttrs); Expr MakeSqueeze(Expr data, Array axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -1764,7 +1764,7 @@ bool BroadCastToRel(const Array& types, Expr MakeBroadCastTo(Expr data, Array shape) { static const Op& op = Op::Get("broadcast_to"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); return CallNode::make(op, {data}, Attrs(attrs), {}); } @@ -2006,7 +2006,7 @@ Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); @@ -2189,9 +2189,9 @@ Array SplitCompute(const Attrs& attrs, } Expr MakeSplit(Expr data, - NodeRef indices_or_sections, + ObjectRef indices_or_sections, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); static const Op& op = Op::Get("split"); @@ -2294,7 +2294,7 @@ bool SliceLikeRel(const Array& types, Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); return CallNode::make(op, {data, shape_like}, Attrs(attrs), {}); @@ -2403,7 +2403,7 @@ bool LayoutTransformRel(const Array& types, Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); static const Op& op = Op::Get("layout_transform"); @@ -2431,7 +2431,7 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; static const Op& op = Op::Get("_contrib_reverse_reshape"); @@ -2566,7 +2566,7 @@ Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); static const Op& op = Op::Get("sequence_mask"); @@ -2687,7 +2687,7 @@ Expr MakeOneHot(Expr indices, int depth, int axis, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->depth = std::move(depth); attrs->axis = axis; attrs->dtype = dtype; diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 710c910794c8..d4cd7be807b1 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -159,7 +159,7 @@ TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") .set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; static const Op& op = Op::Get("clip"); @@ -302,7 +302,7 @@ Array ShapeOfCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op._make.shape_of") .set_body_typed([](Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -353,7 +353,7 @@ Array NdarraySizeCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op.contrib._make.ndarray_size") .set_body_typed([](Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("contrib.ndarray_size"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 28289e76810f..2dd09403f144 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -60,7 +60,7 @@ Expr MakeMultiBoxPrior(Expr data, Array steps, Array offsets, bool clip) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->sizes = std::move(sizes); attrs->ratios = std::move(ratios); attrs->steps = std::move(steps); @@ -135,7 +135,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, bool clip, double threshold, Array variances) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->clip = std::move(clip); attrs->threshold = std::move(threshold); attrs->variances = std::move(variances); diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index cba5b6bc7c50..6759e186eeda 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,7 +52,7 @@ Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->score_threshold = score_threshold; attrs->id_index = id_index; attrs->score_index = score_index; @@ -114,7 +114,7 @@ Expr MakeNMS(Expr data, int id_index, bool return_indices, bool invalid_to_bottom) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 52440969ae59..24f4b98b8ed0 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -51,7 +51,7 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spatial_scale, int sample_ratio, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; attrs->sample_ratio = sample_ratio; @@ -102,7 +102,7 @@ bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spatial_scale, std::string layout) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; attrs->layout = layout; @@ -163,7 +163,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array Array ratios, int feature_stride, double threshold, int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->scales = scales; attrs->ratios = ratios; attrs->feature_stride = feature_stride; diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index fe0684376c39..74b59f649ccd 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -62,7 +62,7 @@ bool YoloReorgRel(const Array& types, Expr MakeYoloReorg(Expr data, Integer stride) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->stride = stride; static const Op& op = Op::Get("vision.yolo_reorg"); return CallNode::make(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index bd89c5123bd7..b3b08c145101 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -107,8 +107,8 @@ class AlterTransformMemorizer : public TransformMemorizer { * 2. Do not support nested tuple arguments. */ Expr AlterOpLayout(const Expr& expr) { - AlterTransformMemorizer alterMemorizer(make_node()); - auto fcontext = [&](const Call& call) -> NodeRef { return alterMemorizer; }; + AlterTransformMemorizer alterMemorizer(make_object()); + auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; }; return ForwardRewrite(expr, LayoutRewriter, fcontext); } diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 6913eb2d80c5..c790659012ee 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,12 +92,11 @@ class CastCanonicalizer : public ExprMutator { } private: - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // cast op is frequently checked for equivalence. Therefore, we cache it to // reduce lookup overhead. const Op& cast_op_; - Expr GetNewCallArg(const Expr& e) { // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor Expr new_expr = this->VisitExpr(e); diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 109d86e806f6..e5c253e4ac56 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -91,7 +91,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { const CallNode* group_root = branches[0][0]; const auto* attrs = group_root->attrs.as(); CHECK(attrs); - const auto new_attrs = make_node(); + const auto new_attrs = make_object(); new_attrs->strides = attrs->strides; new_attrs->padding = attrs->padding; new_attrs->dilation = attrs->dilation; diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h index 858926e662e6..619a153595b7 100644 --- a/src/relay/pass/combine_parallel_op.h +++ b/src/relay/pass/combine_parallel_op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -46,10 +46,10 @@ using Branch = std::vector; using Group = std::vector; using FIsSupportedOp = std::function; using FAreCompatibleOps = std::function; -using ExprSubstMap = std::unordered_map; +using ExprSubstMap = std::unordered_map; /* - * Class to find parallel branches starting with op that are + * Class to find parallel branches starting with op that are * grouped if they are able to be combined. They are eligible to * be combined if they have the same input data. * Op can be followed by zero or more elemwise or broadcast ops, @@ -91,22 +91,22 @@ class BranchGroupFinder : private ExprVisitor { const Op& cached_op_; /* \brief function to return true if op is eligible to be combined, - * false otherwise + * false otherwise */ FIsSupportedOp fis_supported_op_; /* \brief function to return true if two parallel ops are eligible - * to be combined, false otherwise + * to be combined, false otherwise */ FAreCompatibleOps fare_compatible_ops_; /* \brief ops that are on the first (logically, leftmost) branch * of parallel ops and are eligible to be combined */ - std::unordered_set op_roots_; + std::unordered_set op_roots_; /* \brief map of Expr to CallNodes that follow it */ - std::unordered_map, NodeHash, NodeEqual> children_map_; + std::unordered_map, ObjectHash, ObjectEqual> children_map_; /* * \brief Creates new branch from op and its children that have diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc index fa8b8722f814..8b223ee100d1 100644 --- a/src/relay/pass/convert_layout.cc +++ b/src/relay/pass/convert_layout.cc @@ -117,8 +117,8 @@ class ConvertTransformMemorizer : public TransformMemorizer { */ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { ConvertTransformMemorizer transformMemorizer( - make_node(desired_layout)); - auto fcontext = [&](const Call& call) -> NodeRef { return transformMemorizer; }; + make_object(desired_layout)); + auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; }; return ForwardRewrite(expr, LayoutRewriter, fcontext); } diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index af25e9fbac5d..6816cc7d2d83 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -104,8 +104,8 @@ Expr DeDup(const Expr& e) { } private: - std::unordered_map rename_; - std::unordered_map type_rename_; + std::unordered_map rename_; + std::unordered_map type_rename_; }; CHECK(WellFormed(e)) << AsText(e, false); Expr ret = DeDupMutator().VisitExpr(e); diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index df16baeeed7b..14bca58cf328 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -36,8 +36,8 @@ namespace tvm { namespace relay { template -using VarMap = std::unordered_map; -using VarSet = std::unordered_set; +using VarMap = std::unordered_map; +using VarSet = std::unordered_set; class CalcDep; class FindDef : private ExprVisitor { diff --git a/src/relay/pass/dependency_graph.cc b/src/relay/pass/dependency_graph.cc index 42b829fc3c73..81c205a33c2f 100644 --- a/src/relay/pass/dependency_graph.cc +++ b/src/relay/pass/dependency_graph.cc @@ -64,7 +64,7 @@ class DependencyGraph::Creator : private ExprFunctor { parent->children.Push(child_link); } - std::unordered_set visited_; + std::unordered_set visited_; DependencyGraph::Node* NewNode(bool new_scope) { auto* ret = arena_->make(); diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h index 6b2af7e156a8..d6a4e9588df9 100644 --- a/src/relay/pass/dependency_graph.h +++ b/src/relay/pass/dependency_graph.h @@ -54,7 +54,7 @@ class DependencyGraph { }; /*! \brief Maps a Relay Expr to its node in the dependency graph. */ - std::unordered_map expr_node; + std::unordered_map expr_node; /*! \brief The dependency graph in post DFS order. */ std::vector post_dfs_order; diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 6ad04b0e15e4..91a7fa315f5d 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -280,7 +280,7 @@ class RewriteAnnotation : public ExprMutator { * \return The created call node. */ Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->src_dev_type = src_dev_type; attrs->dst_dev_type = dst_dev_type; static const Op& op = Op::Get("device_copy"); diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 07827d2c8e14..d180fcc150be 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -76,7 +76,7 @@ class CommonSubexprEliminator : public ExprMutator { return new_expr; } - std::unordered_map, NodeHash, NodeEqual> expr_map_; + std::unordered_map, ObjectHash, ObjectEqual> expr_map_; runtime::TypedPackedFunc fskip_; }; diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index dca08cc834d1..888874cf0f75 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -49,7 +49,7 @@ class TypeVarReplacer : public TypeMutator { private: /*! \brief variable replacement map to remap old type vars to fresh ones */ - std::unordered_map replace_map_; + std::unordered_map replace_map_; }; /*! diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index baca63233338..d3e6aa8dbfe6 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ namespace relay { class ExprSubstituter : public ExprMutator { public: - explicit ExprSubstituter(std::unordered_map subst_map) + explicit ExprSubstituter(std::unordered_map subst_map) : subst_map_(subst_map) {} Expr VisitExpr(const Expr& expr) final { @@ -45,7 +45,8 @@ class ExprSubstituter : public ExprMutator { tvm::Map subst_map_; }; -Expr ExprSubst(const Expr& expr, std::unordered_map subst_map) { +Expr ExprSubst(const Expr& expr, + std::unordered_map subst_map) { return ExprSubstituter(std::move(subst_map)).Mutate(expr); } diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index bc53d3f51be0..2ffefa25657d 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,7 +29,8 @@ namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); +Expr ExprSubst(const Expr& expr, + std::unordered_map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index d610f9523f8a..79830a709fe0 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -36,7 +36,7 @@ FeatureSet DetectFeature(const Expr& expr) { return FeatureSet::No(); } struct FeatureDetector : ExprVisitor { - std::unordered_set visited_; + std::unordered_set visited_; FeatureSet fs = FeatureSet::No(); void VisitExpr(const Expr& expr) final { diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 1e22571f6b43..4a6417b174e3 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -52,7 +52,7 @@ class ConstantChecker : private ExprVisitor { } private: - std::unordered_map memo_; + std::unordered_map memo_; void VisitExpr_(const TupleNode* n) final { bool result = true; @@ -266,7 +266,7 @@ class ConstantFolder : public ExprMutator { } // Cast the constant into correct dtype - auto cast_attrs = make_node(); + auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {}); return ConstEvaluate(ret); diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index e13a50a99c58..711297ca1883 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -95,13 +95,16 @@ class MessageNode : public RelayNode { static Message make(const AxesSet& axes, bool require_positive); static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message"; - TVM_DECLARE_NODE_TYPE_INFO(MessageNode, RelayNode); + TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode); }; -RELAY_DEFINE_NODE_REF(Message, MessageNode, NodeRef); +class Message : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); +}; Message MessageNode::make(const AxesSet& axes, bool require_positive) { - auto n = make_node(); + auto n = make_object(); n->axes = axes; n->require_positive = require_positive; return Message(n); @@ -183,7 +186,7 @@ class ScaledExprNode : public TempExprNode { } static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr"; - TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode); }; using FForwardRewrite = TypedPackedFunc< @@ -196,7 +199,7 @@ using FForwardRewrite = TypedPackedFunc< //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map + std::unordered_map Prepare(const Expr& body) { this->Update(body, NullValue()); this->VisitExpr(body); @@ -215,7 +218,7 @@ class ForwardPrep : private ExprVisitor { // The invoke list std::vector > flist_; // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // Update the message stored at node. void Update(const Expr& node, const Message& message) { // We run intersection of messages: @@ -228,7 +231,7 @@ class ForwardPrep : private ExprVisitor { // because %z2 will propagate null to %y, // the AxesSet on %y is also null, // and the forward folding won't be triggered. - const Node* key = node.get(); + const Object* key = node.get(); if (message_.count(key)) { message_[key] = Intersect(message_[key], message); } else { @@ -323,7 +326,7 @@ Expr ReluForwardRewrite(const Call& ref_call, const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d - auto rnode = make_node(); + auto rnode = make_object(); rnode->value = CallNode::make( ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; @@ -366,7 +369,7 @@ Expr AddSubForwardRewrite(const Call& ref_call, if (!slhs && !srhs) return Expr(); const auto* tlhs = ref_call->args[0]->type_as(); const auto* trhs = ref_call->args[1]->type_as(); - auto rnode = make_node(); + auto rnode = make_object(); if (slhs != nullptr) { CHECK(srhs == nullptr); @@ -422,7 +425,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, const auto* trhs = ref_call->args[1]->type_as(); Expr lhs = new_args[0]; Expr rhs = new_args[1]; - auto rnode = make_node(); + auto rnode = make_object(); if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs) && (!message->require_positive || IsAllPositiveConstant(rhs))) { @@ -531,12 +534,12 @@ RELAY_REGISTER_OP("nn.conv2d") Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> NodeRef{ + auto fcontext = [&](const Call& call) -> ObjectRef{ auto it = message.find(call.get()); if (it != message.end()) { return it->second; } else { - return NodeRef(nullptr); + return ObjectRef(nullptr); } }; return ForwardRewrite( @@ -571,7 +574,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); @@ -580,9 +583,9 @@ class BackwardPrep : private ExprVisitor { private: // The message on each node. - std::unordered_map message_; + std::unordered_map message_; // reference counter of an internal expr - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); @@ -612,7 +615,7 @@ class BackwardPrep : private ExprVisitor { }; class BackwardTransformerNode : - public Node, + public Object, private ExprMutator { public: // Run forward transform. @@ -667,11 +670,11 @@ class BackwardTransformerNode : void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer"; - TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(BackwardTransformerNode, Object); private: // Valid axes on each node. - std::unordered_map message_; + std::unordered_map message_; // Override mutation of call. Expr VisitExpr_(const CallNode* call_node) final { return Transform(call_node, NullValue(), NullValue()); @@ -680,11 +683,11 @@ class BackwardTransformerNode : Expr Transform(const CallNode* call_node, Message message, Expr scale); }; -class BackwardTransformer : public NodeRef { +class BackwardTransformer : public ObjectRef { public: BackwardTransformer() {} explicit BackwardTransformer( - ::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { + ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { } BackwardTransformerNode* operator->() const { return static_cast(get_mutable()); @@ -938,7 +941,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(const Expr& data) { - return make_node()->Fold(data); + return make_object()->Fold(data); } } // namespace fold_scale_axis diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index fe5cc36cba95..fe0df010b626 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -61,14 +61,14 @@ class TempRealizer : private ExprMutator { class ForwardRewriter : private ExprMutator { public: ForwardRewriter(const OpMap* rewrite_map, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) : rewrite_func_(rewrite_func), fcontext_(fcontext), @@ -88,11 +88,11 @@ class ForwardRewriter : private ExprMutator { const OpMap* rewrite_map_{nullptr}; const FForwardRewrite* rewrite_func_{nullptr}; // The context.const - std::function fcontext_{nullptr}; + std::function fcontext_{nullptr}; // The multiple reference trigger std::function fmulti_ref_trigger_{nullptr}; // Internal ref counter - std::unordered_map ref_counter_; + std::unordered_map ref_counter_; // internal realizer TempRealizer realizer_; @@ -172,7 +172,7 @@ class ForwardRewriter : private ExprMutator { if (frewrite != nullptr) { Expr res = frewrite( ref_call, call_args, - fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr)); + fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { @@ -192,7 +192,7 @@ class ForwardRewriter : private ExprMutator { Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); @@ -200,7 +200,7 @@ Expr ForwardRewrite(const Expr& expr, Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, - std::function fcontext, + std::function fcontext, std::function fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 8209a8010b98..7b8f6de74382 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -103,7 +103,7 @@ class IndexedForwardGraph { /*! \brief A node in the graph. */ struct Node { /*! \brief weak reference to the corresponding edge. */ - const tvm::Node* ref{nullptr}; + const tvm::Object* ref{nullptr}; /*! \brief The index of the node in topological order. */ size_t index{0}; /*! \brief Whether this node is referenced by external source */ @@ -114,7 +114,7 @@ class IndexedForwardGraph { LinkedList outputs; }; /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; + std::unordered_map node_map; /*! \brief All the nodes in post DFS order */ std::vector post_dfs_order; @@ -124,7 +124,7 @@ class IndexedForwardGraph { for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; os << "node[" << i << "], " - << GetRef(node->ref) + << GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; @@ -167,7 +167,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { - const tvm::Node* key = node.get(); + const tvm::Object* key = node.get(); IndexedForwardGraph::Node* current; auto it = graph_.node_map.find(key); if (it != graph_.node_map.end()) { @@ -186,10 +186,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } } - void AddNode(const tvm::Node* key) { + void AddNode(const tvm::Object* key) { auto it = graph_.node_map.find(key); CHECK(it != graph_.node_map.end()) - << "Cannot find node " << GetRef(key); + << "Cannot find node " << GetRef(key); IndexedForwardGraph::Node* node = it->second; CHECK(node->ref == nullptr); node->ref = key; @@ -523,12 +523,12 @@ class GraphPartitioner { /*! \brief The pattern of the group */ OpPatternKind pattern; /*! \brief reference to the root node. */ - const tvm::Node* root_ref{nullptr}; + const tvm::Object* root_ref{nullptr}; /*! * \brief Reference to the master node, * this field is not nullptr only if pattern is kOutEWiseFusable. */ - const tvm::Node* master_ref{nullptr}; + const tvm::Object* master_ref{nullptr}; /*! * \brief Find the group root, perform path compression * \return The root type node. @@ -847,7 +847,7 @@ class FuseMutator : private ExprMutator { /*! \brief Internal arena. */ common::Arena arena_; /*! \brief The group assignment map. */ - std::unordered_map gmap_; + std::unordered_map gmap_; /* \brief Internal group information map. */ std::unordered_map ginfo_; diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 69d12c26f103..61f7e2d8979d 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -133,7 +133,7 @@ struct FirstOrderReverseAD : ExprFunctor { const OpMap rev_map = Op::GetAttr("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping - std::unordered_map env; + std::unordered_map env; LetList* ll; FirstOrderReverseAD(LetList* ll) : ll(ll) { } @@ -385,7 +385,7 @@ Expr BPEmpty() { } struct ReverseAD : ExprMutator { - using ADVarMap = std::unordered_map; + using ADVarMap = std::unordered_map; Var bp; std::shared_ptr ad_vars; diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index afcc4935fa41..7a524ee23318 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -63,7 +63,7 @@ * so we have to deduplicate them. * * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id. - * While it is permitted, most pass use NodeHash for Var, + * While it is permitted, most pass use ObjectHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. * @@ -110,7 +110,7 @@ using namespace runtime; */ struct VarHash { size_t operator()(const Var& v) const { - return NodeHash()(v->vid); + return ObjectHash()(v->vid); } }; @@ -130,13 +130,13 @@ Expr PostProcess(const Expr&); class StaticNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Static"; - TVM_DECLARE_BASE_NODE_INFO(StaticNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(StaticNode, RelayNode); }; -class Static : public NodeRef { +class Static : public ObjectRef { public: Static() {} - explicit Static(ObjectPtr n) : NodeRef(n) {} + explicit Static(ObjectPtr n) : ObjectRef(n) {} const StaticNode* operator->() const { return static_cast(get()); } @@ -146,7 +146,7 @@ class Static : public NodeRef { using Time = size_t; -struct PStaticNode : Node { +struct PStaticNode : Object { static Time time() { static Time time_ = 0; Time ret = time_; @@ -160,35 +160,44 @@ struct PStaticNode : Node { pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } static constexpr const char* _type_key = "relay.PStatic"; - TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object); }; -RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); +class PStatic : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(PStatic, ObjectRef, PStaticNode); +}; struct STupleNode : StaticNode { std::vector fields; explicit STupleNode(const std::vector& fields) : fields(fields) { } static constexpr const char* _type_key = "relay.STuple"; - TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STuple, STupleNode, Static); +class STuple : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(STuple, Static, STupleNode); +}; Static MkSTuple(const std::vector& fields) { - return Static(make_node(fields)); + return Static(make_object(fields)); } struct STensorNode : StaticNode { runtime::NDArray data; explicit STensorNode(const NDArray& data) : data(data) { } static constexpr const char* _type_key = "relay.STensor"; - TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STensor, STensorNode, Static); +class STensor : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode); +}; Static MkSTensor(const NDArray& data) { - return Static(make_node(data)); + return Static(make_object(data)); } struct SConstructorNode : StaticNode { @@ -197,25 +206,31 @@ struct SConstructorNode : StaticNode { SConstructorNode(const Constructor& constructor, const std::vector& fields) : constructor(constructor), fields(fields) { } static constexpr const char* _type_key = "relay.SConstructor"; - TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Static); +class SConstructor : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SConstructor, Static, SConstructorNode); +}; Static MkSConstructor(const Constructor& constructor, const std::vector& fields) { - return Static(make_node(constructor, fields)); + return Static(make_object(constructor, fields)); } struct SRefNode : StaticNode { static constexpr const char* _type_key = "relay.SRef"; // we will use the address as the guid for hashing - TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SRefNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SRef, SRefNode, Static); +class SRef : public Static { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode); +}; Static MkSRef() { - return Static(make_node()); + return Static(make_object()); } using Func = std::function(func)); + return Static(make_object(func)); } @@ -246,10 +264,10 @@ class FuelNode; * Every time we recurse, we do a meet and require that progress must be made. * This ensures we do not recurse infinitely in the Partial Evaluator. */ -class Fuel : public NodeRef { +class Fuel : public ObjectRef { public: Fuel() {} - explicit Fuel(ObjectPtr n) : NodeRef(n) {} + explicit Fuel(ObjectPtr n) : ObjectRef(n) {} const FuelNode* operator->() const; using ContainerType = FuelNode; @@ -279,7 +297,7 @@ class FuelNode : public RelayNode { return std::get<0>(ret); } static constexpr const char* _type_key = "relay.Fuel"; - TVM_DECLARE_BASE_NODE_INFO(FuelNode, RelayNode); + TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode); }; const FuelNode* Fuel::operator->() const { @@ -301,13 +319,16 @@ struct FSeqNode : FuelNode { } explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } static constexpr const char* _type_key = "relay.FSeq"; - TVM_DECLARE_NODE_TYPE_INFO(FSeqNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FSeq, FSeqNode, Fuel); +class FSeq : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode); +}; Fuel MkFSeq(const std::vector& fuels) { - return Fuel(make_node(fuels)); + return Fuel(make_object(fuels)); } Fuel MkFTime(Time time); @@ -321,13 +342,16 @@ struct FTimeNode : FuelNode { } explicit FTimeNode(Time time) : time(time) { } static constexpr const char* _type_key = "relay.FTime"; - TVM_DECLARE_NODE_TYPE_INFO(FTimeNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTime, FTimeNode, Fuel); +class FTime : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode); +}; Fuel MkFTime(Time time) { - return Fuel(make_node(time)); + return Fuel(make_object(time)); } Fuel MkFTValue(size_t tvalue); @@ -342,13 +366,16 @@ struct FTValueNode : FuelNode { } explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } static constexpr const char* _type_key = "relay.FTValue"; - TVM_DECLARE_NODE_TYPE_INFO(FTValueNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTValue, FTValueNode, Fuel); +class FTValue : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode); +}; Fuel MkFTValue(size_t tvalue) { - return Fuel(make_node(tvalue)); + return Fuel(make_object(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. @@ -361,13 +388,16 @@ struct FTopNode : FuelNode { return std::make_tuple(f, !f.as()); } static constexpr const char* _type_key = "relay.FTop"; - TVM_DECLARE_NODE_TYPE_INFO(FTopNode, FuelNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FTopNode, FuelNode); }; -RELAY_DEFINE_NODE_REF(FTop, FTopNode, Fuel); +class FTop : public Fuel { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode); +}; Fuel MkFTop() { - return Fuel(make_node()); + return Fuel(make_object()); } /*! @@ -500,11 +530,11 @@ class Store { PStatic HasStatic(const Static& stat, const Expr& dynamic) { CHECK(stat.defined()); - return PStatic(make_node(stat, dynamic)); + return PStatic(make_object(stat, dynamic)); } PStatic NoStatic(const Expr& dynamic) { - return PStatic(make_node(dynamic)); + return PStatic(make_object(dynamic)); } enum struct MatchStatus { @@ -559,6 +589,7 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); + RELAY_REGISTER_OP("annotation.with_funcid") .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) @@ -569,7 +600,7 @@ TVM_ADD_FILELINE) static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); Expr MkWithFuncId(const Expr& expr, FuncId fid) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->fid = fid; return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {}); } @@ -1147,7 +1178,7 @@ class PartialEvaluator : public ExprFunctor private: Environment env_; Module mod_; - std::unordered_map gv_map_; + std::unordered_map gv_map_; /*! Termination checking is done as follows: * We have finitely many FunctionIds. * Each FunctionId maps to a class of semantically equivalent function (ignoring type), @@ -1161,7 +1192,7 @@ class PartialEvaluator : public ExprFunctor * when we PE inside the Function body. * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. */ - std::unordered_map func_map_; + std::unordered_map func_map_; std::unordered_map fuel_map_; Store store_; DLContext context_ = CPUContext(); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 97b8fd681cb8..909ba0b8d712 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -44,7 +44,7 @@ struct RelayPassContextThreadLocalEntry { std::stack context_stack; RelayPassContextThreadLocalEntry() { - default_context = PassContext(make_node()); + default_context = PassContext(make_object()); } }; @@ -77,7 +77,7 @@ PassContext PassContext::Current() { } PassContext PassContext::Create() { - return PassContext(make_node()); + return PassContext(make_object()); } class ModulePass; @@ -126,10 +126,13 @@ class ModulePassNode : public PassNode { PassInfo pass_info); static constexpr const char* _type_key = "relay.ModulePass"; - TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); }; -RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass); +class ModulePass : public Pass { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); +}; class FunctionPass; @@ -180,7 +183,7 @@ class FunctionPassNode : public PassNode { PassInfo pass_info); static constexpr const char* _type_key = "relay.FunctionPass"; - TVM_DECLARE_NODE_TYPE_INFO(FunctionPassNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); private: /* @@ -193,7 +196,10 @@ class FunctionPassNode : public PassNode { bool SkipFunction(const Function& func) const; }; -RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); +class FunctionPass : public Pass { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; /*! * \brief The SequentialNode contains a set of passes that transform Relay @@ -258,13 +264,13 @@ class SequentialNode : public PassNode { Module operator()(const Module& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; - TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; PassInfo PassInfoNode::make(int opt_level, std::string name, tvm::Array required) { - auto pass_info = make_node(); + auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); @@ -274,7 +280,7 @@ PassInfo PassInfoNode::make(int opt_level, ModulePass ModulePassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); return ModulePass(n); @@ -297,7 +303,7 @@ Module ModulePassNode::operator()(const Module& mod, FunctionPass FunctionPassNode::make( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); return FunctionPass(n); @@ -330,20 +336,20 @@ Module FunctionPassNode::operator()(const Module& mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); + ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); const ir::IntImm* pval = skip_opt.as(); return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { - auto n = make_node(); + auto n = make_object(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } Sequential::Sequential(tvm::Array passes, std::string name) { - auto n = make_node(); + auto n = make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); n->pass_info = std::move(pass_info); diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index 225ce610d909..2d4722bc6759 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -39,7 +39,7 @@ namespace relay { * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map +std::unordered_map GetExprRefCount(const Expr& body); /*! @@ -108,57 +108,57 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } -template +template struct TreeNode { - typedef std::shared_ptr> pointer; + typedef std::shared_ptr> pointer; virtual ~TreeNode() {} }; -template -struct TreeLeafNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeLeafNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; Expr body; explicit TreeLeafNode(Expr body): body(body) {} - static TreeNodePtr Make(Expr body) { + static TreeObjectPtr Make(Expr body) { return std::make_shared(body); } ~TreeLeafNode() {} }; -template -struct TreeLeafFatalNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeLeafFatalNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; TreeLeafFatalNode() = default; - static TreeNodePtr Make() { + static TreeObjectPtr Make() { return std::make_shared(); } ~TreeLeafFatalNode() {} }; -template -struct TreeBranchNode : TreeNode { - using TreeNodePtr = typename TreeNode::pointer; +template +struct TreeBranchNode : TreeNode { + using TreeObjectPtr = typename TreeNode::pointer; - ConditionNodePtr cond; - TreeNodePtr then_branch; - TreeNodePtr else_branch; + ConditionObjectPtr cond; + TreeObjectPtr then_branch; + TreeObjectPtr else_branch; - TreeBranchNode(ConditionNodePtr cond, - TreeNodePtr then_branch, - TreeNodePtr else_branch) + TreeBranchNode(ConditionObjectPtr cond, + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - static TreeNodePtr Make(ConditionNodePtr cond, - TreeNodePtr then_branch, - TreeNodePtr else_branch) { + static TreeObjectPtr Make(ConditionObjectPtr cond, + TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { return std::make_shared(cond, then_branch, else_branch); } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5e93ea1ff0aa..d3ec342b883d 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -104,9 +104,9 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, size_t base = tlhs->shape.size() - trhs->shape.size(); size_t j = 0; - NodePtr squeeze_attrs; + ObjectPtr squeeze_attrs; if (rhs_value != nullptr) { - squeeze_attrs = make_node(); + squeeze_attrs = make_object(); } for (size_t i = 0; i < tlhs->shape.size(); ++i) { @@ -149,7 +149,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, if (i == axes.size()) { int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1; if (num_pad_axis > 0) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(num_pad_axis); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); @@ -158,7 +158,7 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, int64_t diff = axes[i]->value - axes[i - 1]->value; CHECK_GE(diff, 0L); if (diff > 0) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = i; attrs->num_newaxis = static_cast(diff); bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {}); @@ -291,7 +291,7 @@ T GetScalarFromConstant(Expr expr) { inline Expr Cast(Expr x, DataType dtype) { static const Op& op = Op::Get("cast"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->dtype = dtype; return CallNode::make(op, {x}, Attrs(attrs), {}); } @@ -322,7 +322,7 @@ inline Expr Round(Expr x) { inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); - auto attrs = make_node(); + auto attrs = make_object(); attrs->a_min = a_min; attrs->a_max = a_max; return CallNode::make(op, {x}, Attrs(attrs), {}); @@ -358,7 +358,7 @@ inline Expr ZerosLike(Expr e) { } inline Expr Zeros(Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); @@ -406,7 +406,7 @@ inline Expr Copy(Expr data) { inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -415,7 +415,7 @@ inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { } inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -437,7 +437,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->shape = std::move(shape); attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); @@ -448,7 +448,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); @@ -467,7 +467,7 @@ static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("nn.dense"); @@ -475,7 +475,7 @@ static inline Expr Dense(Expr data, } static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclude) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; attrs->exclude = exclude; @@ -484,7 +484,7 @@ static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclu } static inline Expr Reshape(Expr data, Array newshape) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = false; static const Op& op = Op::Get("reshape"); @@ -494,7 +494,7 @@ static inline Expr Reshape(Expr data, Array newshape) { static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, Array padding, std::string layout, bool ceil_mode, bool count_include_pad) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -507,7 +507,7 @@ static inline Expr AvgPool2D(Expr data, Array pool_size, Array> pad_width, double pad_value, std::string pad_mode) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); @@ -516,7 +516,7 @@ static inline Expr Pad(Expr data, Array> pad_width, double pad_ } static inline Expr Tile(Expr data, Array reps) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -530,7 +530,7 @@ Expr MakeStridedSlice(Expr data, Array begin, Array end, Array Expr MakeStack(Expr data, int axis); -Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis); +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index c834d8e868fa..c3d01071caf7 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -50,10 +50,13 @@ class QAnnotateExprNode : public TempExprNode { Expr Realize() const final; static constexpr const char* _type_key = "relay.QAnnotateExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); +class QAnnotateExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); +}; Expr QAnnotateExprNode::Realize() const { @@ -61,7 +64,7 @@ Expr QAnnotateExprNode::Realize() const { } QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { - auto rnode = make_node(); + auto rnode = make_object(); rnode->expr = expr; rnode->kind = kind; return QAnnotateExpr(rnode); diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index e78abbf6aee0..f9893f57a2c1 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -56,7 +56,7 @@ class StatsCollector : private ExprMutator { if (new_call->op == simulated_quantize_op_) { auto attrs = new_call->attrs.as(); // rewrite the annotation - auto new_attrs = make_node(); + auto new_attrs = make_object(); const Expr& quantize_input = new_call->args[0]; // expression being quantized auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 691483102937..710684caa1b1 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -50,10 +50,13 @@ class QPartitionExprNode : public TempExprNode { Expr Realize() const final; static constexpr const char* _type_key = "relay.QPartitionExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QPartitionExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr); +class QPartitionExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); +}; Expr QPartitionExprNode::Realize() const { @@ -64,7 +67,7 @@ Expr QPartitionExprNode::Realize() const { } QPartitionExpr QPartitionExprNode::make(Expr expr) { - auto rnode = make_node(); + auto rnode = make_object(); rnode->expr = expr; return QPartitionExpr(rnode); } diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index c022d4236b05..ef78bf2503d8 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -70,7 +70,7 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize") .set_body_typed( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, std::string rounding) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->kind = kind; attrs->sign = sign; attrs->rounding = rounding; @@ -88,7 +88,7 @@ struct TVMQConfigThreadLocalEntry { std::stack context_stack; TVMQConfigThreadLocalEntry() : - default_config(make_node()) { + default_config(make_object()) { } }; diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 77900ab33e7b..bfb7653686b6 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -62,7 +62,7 @@ class QConfig; /*! * \brief Container for build configuration options */ -class QConfigNode : public Node { +class QConfigNode : public Object { public: int nbit_input = 8; int nbit_weight = 8; @@ -73,10 +73,10 @@ class QConfigNode : public Node { std::string calibrate_mode = "global_scale"; double global_scale = 8.0; std::string weight_scale = "power2"; - Array skip_conv_layers = Array(NodePtr(nullptr)); + Array skip_conv_layers = Array(ObjectPtr(nullptr)); bool do_simulation = false; bool round_for_shift = true; - Array debug_enabled_ops = Array(NodePtr(nullptr)); + Array debug_enabled_ops = Array(ObjectPtr(nullptr)); std::string rounding = "UPWARD"; void VisitAttrs(AttrVisitor* v) { @@ -97,16 +97,16 @@ class QConfigNode : public Node { } static constexpr const char* _type_key = "relay.quantize.QConfig"; - TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object); }; /*! * \brief Container for build configuration options */ -class QConfig : public NodeRef { +class QConfig : public ObjectRef { public: QConfig() {} - explicit QConfig(ObjectPtr n) : NodeRef(n) {} + explicit QConfig(ObjectPtr n) : ObjectRef(n) {} const QConfigNode* operator->() const { return static_cast(get()); diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index 7a7e218ced05..bb8edf1edda7 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -45,10 +45,13 @@ class QRealizeExprNode : public TempExprNode { public: Expr data; static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; - TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); + TVM_DECLARE_BASE_OBJECT_INFO(QRealizeExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); +class QRealizeExpr : public TempExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode); +}; class QRealizeIntExprNode : public QRealizeExprNode { @@ -67,10 +70,13 @@ class QRealizeIntExprNode : public QRealizeExprNode { TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; - TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; -RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); +class QRealizeIntExpr : public QRealizeExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); +}; Expr QRealizeIntExprNode::Realize() const { @@ -82,7 +88,7 @@ Expr QRealizeIntExprNode::Realize() const { } QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { - NodePtr n = make_node(); + ObjectPtr n = make_object(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); n->dtype = std::move(dtype); @@ -120,7 +126,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, Expr QuantizeRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -196,7 +202,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") Expr Conv2dRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() && !new_args[1]->IsInstance()) { @@ -214,7 +220,7 @@ Expr Conv2dRealize(const Call& ref_call, Expr rdata = Cast(rhs->data, cfg->dtype_weight); const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); *attrs = *ref_attrs; DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; @@ -232,7 +238,7 @@ RELAY_REGISTER_OP("nn.conv2d") Expr DenseRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { @@ -248,7 +254,7 @@ Expr DenseRealize(const Call& ref_call, Expr rdata = Cast(rhs->data, cfg->dtype_weight); const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); *attrs = *ref_attrs; DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; @@ -266,7 +272,7 @@ RELAY_REGISTER_OP("nn.dense") Expr MulRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { @@ -364,7 +370,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args Expr AddRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { DataType dtype; @@ -383,11 +389,11 @@ RELAY_REGISTER_OP("add") Expr ClipRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { const auto ref_attrs = ref_call->attrs.as(); - auto attrs = make_node(); + auto attrs = make_object(); double dom_scale = GetScalarFromConstant(n->dom_scale); attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; @@ -406,7 +412,7 @@ RELAY_REGISTER_OP("clip") Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); CHECK_EQ(ref_call->args.size(), 1); @@ -438,7 +444,7 @@ RELAY_REGISTER_OP("concatenate") /* \brief forward the original operator */ Expr IdentityRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); @@ -460,7 +466,7 @@ RELAY_REGISTER_OP("annotation.stop_fusion") /* \brief for unary operators which requantize its input to dtype_nbit */ Expr CastDtypeInputRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -478,7 +484,7 @@ RELAY_REGISTER_OP("nn.max_pool2d") Expr AvgPoolRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -501,7 +507,7 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") Expr CastHintRealize(const Call& ref_call, const Array& new_args, - const NodeRef& ctx) { + const ObjectRef& ctx) { const auto param = ref_call->attrs.as(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index acd5163d1335..6d6171c9e461 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -173,7 +173,7 @@ class InferenceSimplifier : public ExprMutator { const Op& dropout_op_; const Op& instance_norm_op_; const Op& layer_norm_op_; - std::unordered_map ty_map_; + std::unordered_map ty_map_; }; Expr SimplifyInference(const Expr& e) { diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 3cce4b6b81a5..57894e015f0b 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -110,7 +110,7 @@ class Fill : ExprFunctor { private: const DependencyGraph& dg_; std::unordered_map* node_scope_; - std::unordered_map memo; + std::unordered_map memo; Fill(const DependencyGraph& dg, std::unordered_map* node_scope) : diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index c20695becd6c..1dfa327d8b0e 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -89,10 +89,10 @@ Type CPSType(const Type& t, const TypeVar& answer) { } // transform global functions into cps form. -using CPSMap = std::unordered_map; +using CPSMap = std::unordered_map; // transform vars from the original program into new vars, so their type will be correct. -using VarMap = std::unordered_map; +using VarMap = std::unordered_map; /* * The meta continuation. diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 5060c13fc75f..b00e0d420641 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,7 +52,7 @@ class UseVarVisitor : public ExprVisitor { class GNF : public ExprMutator { private: - std::unordered_map var_map_; + std::unordered_map var_map_; Expr VisitExpr_(const VarNode* vn) override { Var v = GetRef(vn); return var_map_.count(v) == 0 ? v : var_map_.at(v); diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h index f6c5e9af6d62..d283a239f2f6 100644 --- a/src/relay/pass/transform_layout.h +++ b/src/relay/pass/transform_layout.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,15 +41,16 @@ namespace relay { /*! * \brief Memorizes layout transformations to reuse. */ -class TransformMemorizerNode : public Node { +class TransformMemorizerNode : public Object { public: /*! \brief The key for the memorizer map is (Expr, src_layout, dst_layout). */ - using TransformKey = std::tuple; + using TransformKey = std::tuple; struct key_hash : public std::function { std::size_t operator()(const TransformKey& k) const { return dmlc::HashCombine( - dmlc::HashCombine(std::hash()(std::get<0>(k)), std::get<1>(k)), + dmlc::HashCombine( + std::hash()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k))); } }; @@ -58,16 +59,16 @@ class TransformMemorizerNode : public Node { std::unordered_map memo; static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode"; - TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(TransformMemorizerNode, Object); }; /*! * \brief Container that transforms the layouts and memorizes them. */ -class TransformMemorizer : public NodeRef { +class TransformMemorizer : public ObjectRef { public: TransformMemorizer() {} - explicit TransformMemorizer(ObjectPtr n) : NodeRef(n) {} + explicit TransformMemorizer(ObjectPtr n) : ObjectRef(n) {} TransformMemorizerNode* operator->() { return static_cast(get_mutable()); @@ -85,7 +86,7 @@ class TransformMemorizer : public NodeRef { return raw; } - std::tuple key = + std::tuple key = std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); auto& memo = operator->()->memo; @@ -179,7 +180,7 @@ class LayoutAlternatedExprNode : public TempExprNode { } static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; - TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LayoutAlternatedExprNode, TempExprNode); }; /*! @@ -187,10 +188,10 @@ class LayoutAlternatedExprNode : public TempExprNode { * \tparam TransformMemorizerT The derived TransformMemorizer type. */ template -class LayoutAlternatedExpr : public NodeRef { +class LayoutAlternatedExpr : public ObjectRef { public: LayoutAlternatedExpr() {} - explicit LayoutAlternatedExpr(ObjectPtr n) : NodeRef(n) {} + explicit LayoutAlternatedExpr(ObjectPtr n) : ObjectRef(n) {} LayoutAlternatedExprNode* operator->() { return static_cast*>(get_mutable()); @@ -219,7 +220,7 @@ class LayoutAlternatedExpr : public NodeRef { * - Transform the original call to reuse the new layouts using TransformMemorizer. */ template -Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { +Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { std::vector> inputs; std::vector normal_new_args; Array> input_shapes; @@ -239,7 +240,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod inputs.push_back(GetRef>(inp)); return inp->value; } else { - auto inode = make_node>(); + auto inode = make_object>(); inode->value = arg; inode->memorizer = memorizer; inputs.push_back(LayoutAlternatedExpr(inode)); @@ -342,7 +343,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs); Array fields; for (size_t i = 0; i < new_out.size(); ++i) { - auto rnode = make_node>(); + auto rnode = make_object>(); rnode->value = TupleGetItemNode::make(tuple_output, i); rnode->old_layout = old_out[i]; rnode->new_layout = new_out[i]; @@ -351,7 +352,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod } return TupleNode::make(fields); } else { - auto rnode = make_node>(); + auto rnode = make_object>(); CHECK_EQ(new_out.size(), 1); rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs); rnode->old_layout = old_out[0]; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 2c4cff4983a6..6e992bbeea1a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -90,7 +90,7 @@ struct ResolvedTypeInfo { Type checked_type; // Only allocated when the expression is a call. - Array type_args = Array(NodePtr(nullptr)); + Array type_args = Array(ObjectPtr(nullptr)); }; // @@ -128,7 +128,7 @@ class TypeInferencer : private ExprFunctor, // map from expression to checked type // type inferencer will populate it up - std::unordered_map type_map_; + std::unordered_map type_map_; // The solver used by the inferencer. TypeSolver solver_; @@ -138,7 +138,7 @@ class TypeInferencer : private ExprFunctor, // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { + Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { try { return solver_.Unify(t1, t2, expr); } catch (const dmlc::Error &e) { @@ -168,7 +168,7 @@ class TypeInferencer : private ExprFunctor, return ret; } - void ReportFatalError(const NodeRef& expr, const Error& err) { + void ReportFatalError(const ObjectRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); @@ -215,7 +215,7 @@ class TypeInferencer : private ExprFunctor, } Type tuple_type = GetType(op->tuple); Type rtype = IncompleteTypeNode::make(Kind::kType); - auto attrs = make_node(); + auto attrs = make_object(); attrs->index = op->index; solver_.AddConstraint(TypeRelationNode::make( tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); @@ -235,7 +235,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->constructor->belong_to, unknown_args); - Type unified = Unify(t, expected, GetRef(con)); + Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); if (!tc) { @@ -250,7 +250,7 @@ class TypeInferencer : private ExprFunctor, << "the number of type vars in the type data: " << td->type_vars.size() << " != " << tc->args.size())); } - std::unordered_map type_var_map_; + std::unordered_map type_var_map_; for (size_t i = 0; i < td->type_vars.size(); ++i) { type_var_map_[td->type_vars[i]] = tc->args[i]; } @@ -274,7 +274,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TupleTypeNode::make(unknown_args); - Type unified = Unify(t, expected, GetRef(tup)); + Type unified = Unify(t, expected, GetRef(tup)); auto* tt = unified.as(); if (!tt) { @@ -372,7 +372,7 @@ class TypeInferencer : private ExprFunctor, Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, - const NodeRef& loc) { + const ObjectRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); const TypeRelationNode* rel = op->type_constraints[0].as(); @@ -594,7 +594,7 @@ class TypeInferencer : private ExprFunctor, class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: - Resolver(const std::unordered_map& tmap, + Resolver(const std::unordered_map& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { } @@ -723,7 +723,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // Copy on write optimization // If new_e is an old expression, // we make a copy mutating an existing reference. - NodePtr ptr = make_node(*new_e.as()); + ObjectPtr ptr = make_object(*new_e.as()); new_e = Expr(ptr); new_call = ( std::is_base_of::value ? @@ -763,8 +763,8 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } private: - std::unordered_map vmap_; - const std::unordered_map& tmap_; + std::unordered_map vmap_; + const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation // if original type anntation is missing. @@ -814,7 +814,7 @@ Function InferType(const Function& func, const Module& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; - Function func_copy = Function(make_node(*func.operator->())); + Function func_copy = Function(make_object(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); Expr func_ret = TypeInferencer(mod, var).Infer(func_copy); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8376d3669899..86ebe0f22c8d 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -56,7 +56,7 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const NodeRef& ref) final { + TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } @@ -66,7 +66,7 @@ class TypeSolver::Reporter : public TypeReporterNode { private: /*! \brief The location to report unification errors at. */ - mutable NodeRef location; + mutable ObjectRef location; TypeSolver* solver_; }; @@ -95,7 +95,7 @@ class TypeSolver::OccursChecker : public TypeVisitor { class TypeSolver::Unifier : public TypeFunctor { public: - explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {} + explicit Unifier(TypeSolver* solver, const ObjectRef& loc) : solver_(solver), loc(loc) {} Type Unify(const Type& src, const Type& dst) { // Known limitation @@ -150,8 +150,8 @@ class TypeSolver::Unifier : public TypeFunctor { } // default: unify only if alpha-equal - Type VisitTypeDefault_(const Node* op, const Type& tn) final { - NodeRef nr = GetRef(op); + Type VisitTypeDefault_(const Object* op, const Type& tn) final { + ObjectRef nr = GetRef(op); Type t1 = GetRef(nr.as()); if (!AlphaEqual(t1, tn)) { return Type(nullptr); @@ -365,7 +365,7 @@ class TypeSolver::Unifier : public TypeFunctor { private: TypeSolver* solver_; - NodeRef loc; + ObjectRef loc; }; class TypeSolver::Resolver : public TypeMutator { @@ -408,8 +408,8 @@ class TypeSolver::Propagator : public TypeFunctor { } } - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); + void VisitTypeDefault_(const Object* op) override { + ObjectRef nr = GetRef(op); Type t = GetRef(nr.as()); UpdateRelSet(t); } @@ -492,8 +492,8 @@ class TypeSolver::Merger : public TypeFunctor { } } - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); + void VisitTypeDefault_(const Object* op) override { + ObjectRef nr = GetRef(op); Type t = GetRef(nr.as()); TransferLinks(t); } @@ -533,7 +533,7 @@ TypeSolver::TypeSolver( const GlobalVar& current_func, const Module& module, ErrorReporter* err_reporter) - : reporter_(make_node(this)), + : reporter_(make_object(this)), current_func(current_func), err_reporter_(err_reporter), module_(module) { @@ -558,19 +558,19 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { } // Add equality constraint -Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef& loc) { +Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { Unifier unifier(this, loc); return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const NodeRef& location) { +void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { CHECK(location.defined()); CHECK(current_func.defined()); err_reporter_->ReportAt(current_func, location, err); } // Add type constraint to the solver. -void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) { +void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef& loc) { if (const auto* op = constraint.as()) { // create a new relation node. RelationNode* rnode = arena_.make(); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index fa9ef7a15646..bf1ac716cfc5 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -69,7 +69,7 @@ class TypeSolver { * \param constraint The constraint to be added. * \param location The location at which the constraint was incurred. */ - void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation); + void AddConstraint(const TypeConstraint& constraint, const ObjectRef& lcoation); /*! * \brief Resolve type to the solution type in the solver. * \param type The type to be resolved. @@ -87,13 +87,13 @@ class TypeSolver { * \param rhs The right operand * \param location The location at which the unification problem arose. */ - Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location); + Type Unify(const Type& lhs, const Type& rhs, const ObjectRef& location); /*! * \brief Report an error at the provided location. * \param err The error to report. * \param loc The location at which to report the error. */ - void ReportError(const Error& err, const NodeRef& location); + void ReportError(const Error& err, const ObjectRef& location); private: class OccursChecker; @@ -155,7 +155,7 @@ class TypeSolver { /*! \brief list types to this relation */ LinkedList type_list; /*! \brief The location this type relation originated from. */ - NodeRef location; + ObjectRef location; }; /*! \brief A simple union find between shapes. */ @@ -167,7 +167,7 @@ class TypeSolver { /*! \brief Number of resolved relations */ size_t num_resolved_rels_{0}; /*! \brief map from types to type nodes. */ - std::unordered_map tmap_; + std::unordered_map tmap_; /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 17c527b39237..2efb479c3156 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -35,7 +35,7 @@ namespace relay { template struct InsertionSet { - std::unordered_set set; + std::unordered_set set; std::vector data; void Insert(const T& t) { if (set.count(t) == 0) { @@ -279,7 +279,7 @@ TVM_REGISTER_API("relay._analysis.free_vars") TVM_REGISTER_API("relay._analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; if (x.as()) { *ret = BoundVars(Downcast(x)); } else { @@ -292,7 +292,7 @@ TVM_REGISTER_API("relay._analysis.all_vars") TVM_REGISTER_API("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = FreeTypeVars(Downcast(x), mod); @@ -303,7 +303,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars") TVM_REGISTER_API("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = BoundTypeVars(Downcast(x), mod); @@ -314,7 +314,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars") TVM_REGISTER_API("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; + ObjectRef x = args[0]; Module mod = args[1]; if (x.as()) { *ret = AllTypeVars(Downcast(x), mod); @@ -328,11 +328,11 @@ TVM_REGISTER_API("relay._analysis.all_type_vars") * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map +std::unordered_map GetExprRefCount(const Expr& body) { class ExprRefCounter : private ExprVisitor { public: - std::unordered_map + std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter_); diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index abcedd2ab483..2bbf9792dd1d 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -34,10 +34,10 @@ namespace relay { class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; - std::vector> scope; - std::unordered_set current_bound; - std::unordered_set total_bound; - std::unordered_set free; + std::vector> scope; + std::unordered_set current_bound; + std::unordered_set total_bound; + std::unordered_set free; struct Scope { WellFormedChecker* wfc; diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 8c27d47632a1..43d47e21822e 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -39,7 +39,7 @@ TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs); Expr MakeQnnConcatenate(Expr data, Array input_scales, Array input_zero_points, double output_scale, int32_t output_zero_point, int axis) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scales = std::move(input_scales); attrs->input_zero_points = std::move(input_zero_points); attrs->output_scale = output_scale; diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index a629bf2b462e..669b04fdda48 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -607,7 +607,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, int32_t input_zero_point, int32_t ker int groups, IndexExpr channels, Array kernel_size, std::string data_layout, std::string kernel_layout, std::string out_layout, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); attrs->dilation = std::move(dilation); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index ad0da52ec120..2353e5a89096 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -60,7 +60,7 @@ Expr MakeQuantizedDense(Expr data, Expr weight, int32_t input_zero_point, int32_t kernel_zero_point, double input_scale, double kernel_scale, IndexExpr units, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->units = std::move(units); attrs->out_dtype = out_dtype; attrs->input_zero_point = input_zero_point; diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7daee4664ac5..a1e23808d437 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -56,7 +56,7 @@ bool DequantizeRel(const Array& types, Expr MakeDequantize(Expr data, double input_scale, int32_t input_zero_point) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = input_scale; attrs->input_zero_point = input_zero_point; // real_value = scale * (quantized_value - zero_point) diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index be8e197b78b0..2c116fedeaee 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -50,7 +50,7 @@ namespace qnn { .set_body_typed( \ [](Expr lhs, Expr rhs, double lhs_scale, int32_t lhs_zero_point, double rhs_scale, \ int32_t rhs_zero_point, double output_scale, int32_t output_zero_point) { \ - auto attrs = make_node(); \ + auto attrs = make_object(); \ attrs->lhs_scale = lhs_scale; \ attrs->lhs_zero_point = lhs_zero_point; \ attrs->rhs_scale = rhs_scale; \ diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 6b7fecd191fc..18dd9aa01af5 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -60,7 +60,7 @@ Expr MakeQuantize(Expr data, double output_scale, int32_t output_zero_point, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->output_scale = output_scale; attrs->output_zero_point = output_zero_point; attrs->out_dtype = std::move(out_dtype); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index ec8c845dc8c6..93284cb38e87 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -164,7 +164,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // used by frontend FFI. Expr MakeRequantize(Expr data, double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, std::string rounding, DataType out_dtype) { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point); attrs->output_scale = std::move(output_scale); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 6659a44e63f6..e359296c1d1a 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -84,7 +84,7 @@ static inline Expr Requantize(const Expr& data, const Array& input_sh double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, const DataType& out_dtype, const std::string& rounding = "UPWARD") { - auto attrs = make_node(); + auto attrs = make_object(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point); attrs->output_scale = std::move(output_scale); diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index 292fb55e5995..95a154ce6bee 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -137,7 +137,7 @@ class Storage : public ObjectRef { public: explicit Storage(Buffer buffer); - TVM_DEFINE_OBJECT_REF_METHODS_MUT(Storage, ObjectRef, StorageObj); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj); }; } // namespace vm diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 62739bb22004..e587f385734f 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ class ElemWiseDetector : public ir::IRVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} - void Visit(const NodeRef& e) final { + void Visit(const ObjectRef& e) final { if (!is_elem_wise_) return; IRVisitor::Visit(e); } diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index e213df5e659d..d4baded91f7c 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -47,7 +47,7 @@ struct GraphContext { /*! \brief The bind map */ std::unordered_map bind_map; /*! \brief map from op to stage */ - std::unordered_map op2stage_; + std::unordered_map op2stage_; }; bool NeedRelax(const IterVar& iv, diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index 518f05a03250..c3024a71977f 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -62,7 +62,7 @@ namespace std { template <> struct hash<::tvm::schedule::TensorDimKey> { std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const { - size_t lhs = ::tvm::NodeHash()(k.f); + size_t lhs = ::tvm::ObjectHash()(k.f); size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); @@ -80,7 +80,7 @@ namespace schedule { ReadGraph CreateReadGraph(const Array& roots) { ReadGraph rmap; std::vector stack; - std::unordered_set visited; + std::unordered_set visited; // initialize the roots for (Operation op : roots) { stack.push_back(op); @@ -106,9 +106,9 @@ ReadGraph CreateReadGraph(const Array& roots) { // Return if op is inside the subgraph. bool GetSubGraphByPostDFS_( const Operation& op, - const std::unordered_set& boundary, + const std::unordered_set& boundary, bool include_bounary, - std::unordered_map* visited, + std::unordered_map* visited, Array* result) { if (visited->count(op.get())) { return visited->at(op.get()); @@ -143,11 +143,11 @@ Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs) { Array result; - std::unordered_set boundary; + std::unordered_set boundary; for (Tensor t : inputs) { boundary.insert(t->op.get()); } - std::unordered_map visited; + std::unordered_map visited; for (Tensor t : outputs) { GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); @@ -192,7 +192,7 @@ FeedGraph CreateFeedGraph(const ReadGraph& g) { AttachPath CreateAttachPath(Schedule sch) { AttachPath ret; for (Stage stage : sch->stages) { - std::unordered_set visited; + std::unordered_set visited; Array path; for (Stage s = stage; s.defined();) { CHECK(!visited.count(s.get())) @@ -236,7 +236,7 @@ using ReachGraph = std::unordered_map >; ReachGraph GetReachGraph(const Array& ops) { ReachGraph reach; - std::unordered_set bset; + std::unordered_set bset; for (size_t i = 0; i < ops.size(); ++i) { bset.insert(ops[i].get()); } @@ -255,20 +255,20 @@ ReachGraph GetReachGraph(const Array& ops) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map vmap; + std::unordered_map vmap; const auto& axis = compute_op->axis; Tensor t = op.output(0); for (size_t i = 0; i < axis.size(); ++i) { vmap[axis[i]->var.get()] = TensorDimKey(t, i); reach[TensorDimKey(t, i)] = {}; } - auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) { + auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { if (!bset.count(call->func.get())) return; for (size_t i = 0; i < call->args.size(); ++i) { TensorDimKey dkey(call, static_cast(i)); - auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) { + auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { const Variable *v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { @@ -304,8 +304,8 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { const ScanOpNode* scan = scan_op.as(); Array body = ScanGetBody(scan_op); - std::unordered_map exact_reach; - std::unordered_set fail_set; + std::unordered_map exact_reach; + std::unordered_set fail_set; for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { @@ -342,7 +342,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map > vmap; + std::unordered_map > vmap; const auto& axis = compute_op->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; @@ -352,7 +352,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { vmap[axis[i]->var.get()] = std::move(keys); } auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( - const NodeRef& n) { + const ObjectRef& n) { const ir::Call *call = n.as(); if (call != nullptr && call->func.defined()) { for (size_t i = 0; i < call->args.size(); ++i) { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index c9afcf45a1f2..70a73abc4698 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -34,7 +34,7 @@ namespace tvm { // find first occurance location in leaf template size_t FindNodeRef(ArrayNode* array_node, const T& v) { - const Node* n = v.get(); + const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { if (array_node->data[i].get() == n) return i; } @@ -98,7 +98,7 @@ Expr InjectPredicate(const Array& predicates, if (predicates.size() == 0) return body; const Reduce* reduce = body.as(); if (reduce) { - auto n = make_node(*reduce); + auto n = make_object(*reduce); n->condition = n->condition && arith::ComputeReduce(predicates, Expr()); return Expr(n); } @@ -591,7 +591,7 @@ void InjectInline(ScheduleNode* sch) { CHECK_EQ(new_body[j].size(), r->source.size()); CHECK(r != nullptr); for (size_t k = 0; k < new_body[j].size(); ++k) { - auto n = make_node(*r); + auto n = make_object(*r); n->value_index = static_cast(k); n->dtype = r->source[k].dtype(); new_body[j].Set(k, Expr(n)); @@ -734,11 +734,11 @@ Array Schedule::rfactor(const Tensor& tensor, const int factor_axis_pos = \ factor_axis >= 0 ? factor_axis : static_cast(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); - auto n = make_node(); + auto n = make_object(); n->name = compute_op->name + ".rf"; { // axis relacement. - auto iv_node = make_node(); + auto iv_node = make_object(); iv_node->dom = dom_map.at(axis); CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; @@ -779,7 +779,7 @@ Array Schedule::rfactor(const Tensor& tensor, for (IterVar iv : reduce_stage->leaf_iter_vars) { if (touch_map.count(iv) && !iv.same_as(axis)) { CHECK_EQ(iv->iter_type, kCommReduce); - auto ncpy = make_node(*iv.operator->()); + auto ncpy = make_object(*iv.operator->()); ncpy->dom = dom_map.at(iv); n->reduce_axis.push_back(IterVar(ncpy)); } diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 7a2ab5a4d8b9..ec73c67bedff 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -33,7 +33,7 @@ namespace { // find first occurance location in leaf template size_t FindNodeRef(ArrayNode* array_node, const T& v) { - const Node* n = v.get(); + const Object* n = v.get(); for (size_t i = 0; i < array_node->data.size(); ++i) { if (array_node->data[i].get() == n) return i; } @@ -88,7 +88,7 @@ void Split(StageNode* self, } // namespace Stage::Stage(Operation op) { - auto n = make_node(); + auto n = make_object(); n->op = op; n->origin_op = op; n->all_iter_vars = op->root_iter_vars(); @@ -182,16 +182,16 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) FindLeafVar(all_vars, leaf_vars, ivar); auto it = self->iter_var_attrs.find(ivar); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { - n = make_node(); + n = make_object(); } n->bind_thread = thread_ivar; self->iter_var_attrs.Set(ivar, IterVarAttr(n)); @@ -353,11 +353,11 @@ inline void UpdateIterVarAttr(StageNode* self, FindLeafVar(all_vars, leaf_vars, var); } auto it = self->iter_var_attrs.find(var); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); } else { - n = make_node(); + n = make_object(); } fupdate(n.get()); self->iter_var_attrs.Set(var, IterVarAttr(n)); @@ -422,11 +422,11 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) { ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); auto it = self->iter_var_attrs.find(var); - NodePtr n; + ObjectPtr n; if (it != self->iter_var_attrs.end()) { - n = make_node(*(*it).second.operator->()); + n = make_object(*(*it).second.operator->()); } else { - n = make_node(); + n = make_object(); } n->prefetch_data.push_back(tensor); n->prefetch_offset.push_back(offset); @@ -493,16 +493,16 @@ Stage& Stage::opengl() { } Stage CopyStage(const Stage& s) { - NodePtr n = - make_node(*s.operator->()); + ObjectPtr n = + make_object(*s.operator->()); return Stage(n); } Schedule Schedule::copy() const { // map of stages. const ScheduleNode* self = operator->(); - std::unordered_map smap; - NodePtr n = make_node(); + std::unordered_map smap; + ObjectPtr n = make_object(); n->outputs = self->outputs; // Copy the stages. for (Stage s : self->stages) { @@ -605,7 +605,7 @@ Stage Schedule::create_group(const Array& outputs, int count{0}; }; // Map of group->touched counter - std::unordered_map counter; + std::unordered_map counter; // The parent group; Stage parent_group; // Detect common parent and child. @@ -624,7 +624,7 @@ Stage Schedule::create_group(const Array& outputs, } } // Create the new group stage. - Stage gstage(make_node()); + Stage gstage(make_object()); gstage->group = parent_group; if (parent_group.defined()) { ++parent_group->num_child_stages; @@ -716,7 +716,7 @@ bool ScheduleNode::Contain(const Operation& op) const { } Schedule ScheduleNode::make(Array ops) { - auto n = make_node(); + auto n = make_object(); Schedule sch(n); n->outputs = ops; auto g = schedule::CreateReadGraph(n->outputs); @@ -759,7 +759,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVar inner, Expr factor, Expr nparts) { - auto n = make_node(); + auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; @@ -770,7 +770,7 @@ IterVarRelation SplitNode::make(IterVar parent, IterVarRelation FuseNode::make( IterVar outer, IterVar inner, IterVar fused) { - auto n = make_node(); + auto n = make_object(); n->outer = outer; n->inner = inner; n->fused = fused; @@ -778,14 +778,14 @@ IterVarRelation FuseNode::make( } IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { - auto n = make_node(); + auto n = make_object(); n->parent = parent; n->rebased = rebased; return IterVarRelation(n); } IterVarRelation SingletonNode::make(IterVar iter) { - auto n = make_node(); + auto n = make_object(); n->iter = iter; return IterVarRelation(n); } diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 3a9d0bcb2a98..0103410e6132 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -220,13 +220,13 @@ class SchedulePostProc : public IRMutator { } } } else if (op->attr_key == ir::attr::buffer_bind_scope) { - Array tuple = Downcast >(op->node); + Array tuple = Downcast >(op->node); Tensor tensor = Downcast(tuple[1]); auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { return AttrStmt::make( - Array{tuple[0], it->second.output(tensor->value_index)}, + Array{tuple[0], it->second.output(tensor->value_index)}, op->attr_key, op->value, Mutate(op->body)); } else { return this->Mutate(op->body); @@ -344,7 +344,7 @@ class SchedulePostProc : public IRMutator { replace_op_[src->op.get()] = repl_op; } // The thread extent scope. - std::unordered_map thread_extent_scope_; + std::unordered_map thread_extent_scope_; // The scan value std::unordered_map var_value_; // buffer replacement @@ -352,7 +352,7 @@ class SchedulePostProc : public IRMutator { // buffere realization to be replaced std::unordered_map replace_realize_; // replace producer consumer. - std::unordered_map replace_op_; + std::unordered_map replace_op_; }; Stmt ScheduleOps( diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 4428642b281d..7aab3edb6aaf 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -220,8 +220,8 @@ TEST(Map, Iterator) { using namespace tvm; Expr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2(map1.begin(), - map1.end()); + std::unordered_map + map2(map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 7ecf4590ca12..debfb36f936b 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -25,7 +25,7 @@ TEST(Expr, Basic) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - NodeRef tmp = z; + ObjectRef tmp = z; Expr zz = Downcast(tmp); std::ostringstream os; os << z; @@ -39,7 +39,7 @@ TEST(ExprNodeRef, Basic) { Var x("x"); Expr z = max(x + 1 + 2, 100); const ir::Max* op = z.as(); - CHECK(GetRef(op).same_as(z)); + CHECK(GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc index 079be65079ca..4282a0026ee6 100644 --- a/tests/cpp/ir_visitor_test.cc +++ b/tests/cpp/ir_visitor_test.cc @@ -28,7 +28,7 @@ TEST(IRVisitor, CountVar) { Var x("x"), y; auto z = x + 1 + y + y; - ir::PostOrderVisit(z, [&n_var](const NodeRef& n) { + ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; }); CHECK_EQ(n_var, 2); diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 1b510a45661f..fa184bfee7d1 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -81,7 +81,7 @@ inline Array make_extern(const Array< Array >& out_shapes, FExtern fextern, std::string name, std::string tag, - ::tvm::Map attrs) { + ::tvm::Map attrs) { CHECK_EQ(out_shapes.size(), out_types.size()) << "make_extern: out_shapes and out_types must have equal size"; diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h index fa985d1b2086..c3124bbe6f58 100644 --- a/topi/include/topi/nn/softmax.h +++ b/topi/include/topi/nn/softmax.h @@ -61,7 +61,7 @@ inline Tensor softmax(const Tensor &x, auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2"); auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); - tvm::Map attrs; + tvm::Map attrs; attrs.Set("axis", Integer(axis)); auto insert_reduce_index = [axis, ndim](const Array &indices, diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 87837f82635b..11a90215d71f 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -757,7 +757,7 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { auto target = Target::Current(false); Array outs; - NodeRef argNodeRef = args[0]; + ObjectRef argNodeRef = args[0]; if (argNodeRef->type_index() == outs->type_index()) { outs = args[0]; } else {