From d2c950685fbffb89f32093604af34b9850145f34 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 11 Jun 2020 16:56:42 -0700 Subject: [PATCH] [REFACTOR][API-Change] Migrate all Object construction to constructor. This PR migrates all the remaining object constructions to the new constructor style that is consistent with the rest of the codebase and changes the affected files accordingly. Other changes: - ThreadScope::make -> ThreadScope::Create - StorageScope::make -> StorageScope::Create --- docs/dev/codebase_walkthrough.rst | 2 +- docs/dev/relay_add_pass.rst | 2 +- docs/dev/relay_pass_infra.rst | 10 +- include/tvm/ir/span.h | 4 +- include/tvm/te/operation.h | 90 ++++++++++--- include/tvm/te/schedule.h | 68 +++++++--- include/tvm/te/tensor.h | 123 ++++++++---------- include/tvm/te/tensor_intrin.h | 36 ++--- include/tvm/tir/buffer.h | 123 ++++++++---------- include/tvm/tir/data_layout.h | 58 +++------ src/driver/driver_api.cc | 4 +- src/ir/span.cc | 8 +- src/relay/backend/compile_engine.cc | 3 +- src/relay/backend/interpreter.cc | 13 +- src/runtime/thread_storage_scope.h | 10 +- src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/source/codegen_metal.cc | 2 +- src/target/source/codegen_opencl.cc | 2 +- src/target/spirv/codegen_spirv.cc | 4 +- src/te/autodiff/jacobian.cc | 5 +- src/te/operation/compute_op.cc | 25 ++-- src/te/operation/compute_op.h | 6 +- src/te/operation/extern_op.cc | 15 ++- src/te/operation/hybrid_op.cc | 12 +- src/te/operation/op_util.cc | 4 +- src/te/operation/placeholder_op.cc | 6 +- src/te/operation/scan_op.cc | 18 ++- src/te/operation/tensor_compute_op.cc | 21 ++- src/te/operation/tensorize.cc | 2 +- src/te/schedule/bound.cc | 6 +- src/te/schedule/schedule_dataflow_rewrite.cc | 27 ++-- src/te/schedule/schedule_lang.cc | 38 +++--- src/te/tensor.cc | 38 +++--- src/tir/ir/buffer.cc | 20 +-- src/tir/ir/data_layout.cc | 10 +- src/tir/transforms/inject_copy_intrin.cc | 12 +- src/tir/transforms/loop_partition.cc | 4 +- .../lower_device_storage_access_info.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 4 +- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/storage_access.cc | 6 +- src/tir/transforms/storage_flatten.cc | 10 +- src/tir/transforms/storage_rewrite.cc | 2 +- src/tir/transforms/thread_storage_sync.cc | 6 +- tests/cpp/utvm_runtime_standalone_test.cc | 6 +- topi/include/topi/detail/extern.h | 6 +- topi/include/topi/transform.h | 4 +- 49 files changed, 469 insertions(+), 416 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index a66328fef7c9..8674c8e2c07e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``. :: inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); + return Schedule(ops); } ``Schedule`` consists of collections of ``Stage`` and output ``Operation``. diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst index 2fc463695259..a82ae4ff717a 100644 --- a/docs/dev/relay_add_pass.rst +++ b/docs/dev/relay_add_pass.rst @@ -138,7 +138,7 @@ is shown below. if (g->tuple == t) { return GetRef(g); } else { - return TupleGetItemNode::make(t, g->index); + return TupleGetItem(t, g->index); } } diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 6c2b13947011..446a91bceff7 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -344,13 +344,13 @@ registration. .. code:: c++ // Create a simple Relay program. - auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool()); - auto x = relay::VarNode::make("x", relay::Type()); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + auto tensor_type = relay::TensorType({}, tvm::Bool()); + auto x = relay::Var("x", relay::Type()); + auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); - auto y = relay::VarNode::make("y", tensor_type); + auto y = relay::Var("y", tensor_type); auto call = relay::Call(f, tvm::Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); + auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); // Create a module for optimization. auto mod = IRModule::FromExpr(fx); diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 1ed6848eb9e1..84d6a7b0f877 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -97,14 +97,14 @@ class SpanNode : public Object { equal(col_offset, other->col_offset); } - TVM_DLL static Span make(SourceName source, int lineno, int col_offset); - static constexpr const char* _type_key = "Span"; TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; class Span : public ObjectRef { public: + TVM_DLL Span(SourceName source, int lineno, int col_offset); + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index c161cc9708df..4b7037aec7dc 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -177,12 +177,22 @@ class PlaceholderOpNode : public OperationNode { v->Visit("shape", &shape); v->Visit("dtype", &dtype); } - static Operation make(std::string name, Array shape, DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; +/*! + * \brief Managed reference to PlaceholderOpNode + * \sa PlaceholderOpNode + */ +class PlaceholderOp : public Operation { + public: + TVM_DLL PlaceholderOp(std::string name, Array shape, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode); +}; + /*! * \brief A Compute op that compute a tensor on certain domain. * This is the base class for ComputeOp (operating on a scalar at a time) and @@ -237,13 +247,23 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } - static Operation make(std::string name, std::string tag, Map attrs, - Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; +/*! + * \brief Managed reference to ComputeOpNode + * \sa ComputeOpNode + */ +class ComputeOp : public Operation { + public: + TVM_DLL ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); +}; + /*! * \brief A TenorCompute op that compute a tensor with an tensor intrinsic. */ @@ -285,15 +305,25 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("input_regions", &input_regions); v->Visit("scalar_inputs", &scalar_inputs); } - static Operation make(std::string name, std::string tag, Array axis, - Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, - Array tensors, Array regions, - Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); }; +/*! + * \brief Managed reference to TensorComputeOpNode + * \sa TensorComputeOpNode + */ +class TensorComputeOp : public Operation { + public: + TVM_DLL TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode); +}; + /*! * \brief Symbolic scan. */ @@ -353,14 +383,24 @@ class ScanOpNode : public OperationNode { v->Visit("inputs", &inputs); v->Visit("spatial_axis_", &spatial_axis_); } - static Operation make(std::string name, std::string tag, Map attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array input); static constexpr const char* _type_key = "ScanOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; +/*! + * \brief Managed reference to ScanOpNode + * \sa ScanOpNode + */ +class ScanOp : public Operation { + public: + TVM_DLL ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array input); + + TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode); +}; + /*! * \brief External computation that cannot be splitted. */ @@ -404,14 +444,24 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body); static constexpr const char* _type_key = "ExternOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; +/*! + * \brief Managed reference to ExternOpNode + * \sa ExternOpNode + */ +class ExternOp : public Operation { + public: + TVM_DLL ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode); +}; + /*! * \brief A computation operator that generated by hybrid script. */ @@ -459,13 +509,23 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, std::string tag, Map attrs, - Array inputs, Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; +/*! + * \brief Managed reference to HybridOpNode + * \sa HybridOpNode + */ +class HybridOp : public Operation { + public: + TVM_DLL HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode); +}; + /*! * \brief Construct a new Var expression * \param name_hint The name hint for the expression diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index f74a008a4c74..ee4fb33349f7 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -277,6 +277,12 @@ class Schedule : public ObjectRef { public: Schedule() {} explicit Schedule(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief Create a schedule for array of ops(and their dependencies). + * \param ops The ops to be scheduled. + * \return sch The created Schedule. + */ + TVM_DLL explicit Schedule(Array ops); /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -553,13 +559,6 @@ class ScheduleNode : public Object { */ TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } - /*! - * \brief Create a schedule for array of ops(and their dependencies). - * \param ops The ops to be scheduled. - * \return sch The created Schedule. - */ - TVM_DLL static Schedule make(Array ops); - static constexpr const char* _type_key = "Schedule"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object); }; @@ -569,7 +568,7 @@ class ScheduleNode : public Object { * \param ops The ops to be scheduled. * \return sch The created Schedule. */ -inline Schedule create_schedule(Array ops) { return ScheduleNode::make(ops); } +inline Schedule create_schedule(Array ops) { return Schedule(ops); } /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Object { @@ -648,13 +647,21 @@ class SplitNode : public IterVarRelationNode { v->Visit("nparts", &nparts); } - static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, - PrimExpr nparts); - static constexpr const char* _type_key = "Split"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to SplitNode + * \sa SplitNode + */ +class Split : public IterVarRelation { + public: + TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); + + TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode); +}; + /*! * \brief Fuse two domains into one domain. */ @@ -673,12 +680,21 @@ class FuseNode : public IterVarRelationNode { v->Visit("fused", &fused); } - static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused); - static constexpr const char* _type_key = "Fuse"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to FuseNode + * \sa FuseNode + */ +class Fuse : public IterVarRelation { + public: + TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused); + + TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode); +}; + /*! * \brief Rebase the iteration to make min to be 0. * This is useful to normalize the Schedule @@ -696,12 +712,21 @@ class RebaseNode : public IterVarRelationNode { v->Visit("rebased", &rebased); } - static IterVarRelation make(IterVar parent, IterVar rebased); - static constexpr const char* _type_key = "Rebase"; TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to RebaseNode + * \sa RebaseNode + */ +class Rebase : public IterVarRelation { + public: + TVM_DLL Rebase(IterVar parent, IterVar rebased); + + TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode); +}; + /*! * \brief Singleton iterator [0, 1) */ @@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode { void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } - static IterVarRelation make(IterVar iter); - static constexpr const char* _type_key = "Singleton"; TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to SingletonNode + * \sa SingletonNode + */ +class Singleton : public IterVarRelation { + public: + TVM_DLL explicit Singleton(IterVar iter); + + TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); +}; + /*! \brief Container for specialization conditions. */ class SpecializedConditionNode : public Object { public: diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 045d186f58f5..0c4af4bc636a 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -40,25 +40,68 @@ namespace te { using arith::IntSet; using namespace tvm::tir; -// Internal node container of Tensor -class TensorNode; // internal node container for Operation class OperationNode; +/*! \brief Operation that produces tensors */ +class Operation : public tir::FunctionRef { + public: + /*! \brief default constructor */ + Operation() {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OperationNode* operator->() const; + /*! + * \brief get the i-th output of the operation. + * \param i the output index. + * \return The i-th output. + */ + TVM_DLL Tensor output(size_t i) const; + /*! \brief specify container node */ + using ContainerType = OperationNode; +}; + +/*! \brief Node to represent a tensor */ +class TensorNode : public DataProducerNode { + public: + /*! \brief The shape of the tensor */ + Array shape; + /*! \brief data type in the content of the tensor */ + DataType dtype; + /*! \brief the source operation, can be None */ + Operation op; + /*! \brief the output index from source operation */ + int value_index{0}; + /*! \brief constructor */ + TensorNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("op", &op); + v->Visit("value_index", &value_index); + } + + Array GetShape() const final { return shape; } + + DataType GetDataType() const final { return dtype; } + + TVM_DLL String GetNameHint() const final; + + static constexpr const char* _type_key = "Tensor"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); +}; + /*! * \brief Tensor structure representing a possible input, * or intermediate computation result. */ class Tensor : public DataProducer { public: - /*! \brief default constructor, used internally */ - Tensor() {} - explicit Tensor(ObjectPtr n) : DataProducer(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorNode* operator->() const; + TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. @@ -131,69 +174,11 @@ class Tensor : public DataProducer { * \return the subsequent slice. */ inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } - /*! \brief specify container node */ - using ContainerType = TensorNode; -}; -/*! \brief Operation that produces tensors */ -class Operation : public tir::FunctionRef { - public: - /*! \brief default constructor */ - Operation() {} - explicit Operation(ObjectPtr n) : FunctionRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const OperationNode* operator->() const; - /*! - * \brief get the i-th output of the operation. - * \param i the output index. - * \return The i-th output. - */ - TVM_DLL Tensor output(size_t i) const; - /*! \brief specify container node */ - using ContainerType = OperationNode; -}; - -/*! \brief Node to represent a tensor */ -class TensorNode : public DataProducerNode { - public: - /*! \brief The shape of the tensor */ - Array shape; - /*! \brief data type in the content of the tensor */ - DataType dtype; - /*! \brief the source operation, can be None */ - Operation op; - /*! \brief the output index from source operation */ - int value_index{0}; - /*! \brief constructor */ - TensorNode() {} - - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("value_index", &value_index); - } - - Array GetShape() const final { return shape; } - - DataType GetDataType() const final { return dtype; } - - TVM_DLL String GetNameHint() const final; - - TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); - - static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); + TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); }; // Implementations of inline functions -inline const TensorNode* Tensor::operator->() const { - return static_cast(get()); -} - inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index 7e76efe69691..22f29defbb64 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -32,24 +32,6 @@ namespace tvm { namespace te { -// Internal node container of tensor intrinsics. -class TensorIntrinNode; - -/*! \brief Tensor intrinsic node. */ -class TensorIntrin : public ObjectRef { - public: - TensorIntrin() {} - explicit TensorIntrin(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorIntrinNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = TensorIntrinNode; -}; - /*! \brief Node to represent a Tensor intrinsic operator */ class TensorIntrinNode : public Object { public: @@ -100,17 +82,21 @@ class TensorIntrinNode : public Object { v->Visit("reduce_update", &reduce_update); } - TVM_DLL static TensorIntrin make(std::string name, Operation op, Array inputs, - Array buffers, Array scalar_params, Stmt body, - Stmt reduce_init, Stmt reduce_update); - static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; -inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(get()); -} +/*! + * \brief Managed reference to TensorIntrinNode + * \sa TensorIntrinNode + */ +class TensorIntrin : public ObjectRef { + public: + TVM_DLL TensorIntrin(std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); +}; class TensorIntrinCallNode : public Object { public: diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 6904f2a4ed40..5b07cc5ce7d6 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -32,8 +32,6 @@ namespace tvm { namespace tir { -// Internal node container Buffer -class BufferNode; // forward declare Stmt class Stmt; @@ -45,62 +43,6 @@ enum BufferType : int { kAutoBroadcast = 2, }; -/*! - * \brief Buffer is a symbolic n-darray structure. - * It is a composition of primitive symbolic types, - * used to specify the memory layout of the Tensor used in program input. - */ -class Buffer : public ObjectRef { - public: - Buffer() {} - explicit Buffer(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Return a new buffer that is equivalent with current one - * but always add stride field. - * \return The strided version of the buffer. - */ - TVM_DLL Buffer MakeStrideView() const; - /*! - * \brief Make a new symbolic buffer representing a slice of the buffer. - * \param begins The beginning position of each dimension. - * \param extents The extent of each dimension. - * \note This function will make target buffer as compact as possible. - * If stride is not needed in the slice, it won't be presented - * \return the result buffer. - */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; - /*! - * \brief Get access ptr to the entire buffer. - * \param access_mask The access mask - * \param ptr_type The type of the pointer. - * \param content_lanes The number of lanes for the (data) type. - * \param offset The offset of ptr. - */ - TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), - int content_lanes = 1, - PrimExpr offset = IntImm(DataType::Int(32), 0)) const; - /*! - * \brief Create an Expr that does a vector load at begin index. - * \param begin The beginning index - * \param dtype The data type to be loaded. - */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; - /*! - * \brief Create a Stmt that does a vector store at begin index. - * \param begin The beginning index - * \param value The value to be stored. - */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BufferNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BufferNode; -}; - /*! \brief Node to represent a buffer */ class BufferNode : public Object { public: @@ -176,22 +118,65 @@ class BufferNode : public Object { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } - // User can specify data_alignment and offset_factor to be 0 - // A default value will be picked. - TVM_DLL static Buffer make(Var ptr, DataType dtype, Array shape, - Array strides, PrimExpr elem_offset, std::string name, - std::string scope, int data_alignment, int offset_factor, - BufferType buffer_type); - static constexpr const char* _type_key = "Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; -inline const BufferNode* Buffer::operator->() const { - return static_cast(get()); -} +/*! + * \brief Buffer is a symbolic n-darray structure. + * It is a composition of primitive symbolic types, + * used to specify the memory layout of the Tensor used in program input. + */ +class Buffer : public ObjectRef { + public: + // User can specify data_alignment and offset_factor to be 0 + // A default value will be picked. + TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + int offset_factor, BufferType buffer_type); + + /*! + * \brief Return a new buffer that is equivalent with current one + * but always add stride field. + * \return The strided version of the buffer. + */ + TVM_DLL Buffer MakeStrideView() const; + /*! + * \brief Make a new symbolic buffer representing a slice of the buffer. + * \param begins The beginning position of each dimension. + * \param extents The extent of each dimension. + * \note This function will make target buffer as compact as possible. + * If stride is not needed in the slice, it won't be presented + * \return the result buffer. + */ + TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + /*! + * \brief Get access ptr to the entire buffer. + * \param access_mask The access mask + * \param ptr_type The type of the pointer. + * \param content_lanes The number of lanes for the (data) type. + * \param offset The offset of ptr. + */ + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), + int content_lanes = 1, + PrimExpr offset = IntImm(DataType::Int(32), 0)) const; + /*! + * \brief Create an Expr that does a vector load at begin index. + * \param begin The beginning index + * \param dtype The data type to be loaded. + */ + TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; + /*! + * \brief Create a Stmt that does a vector store at begin index. + * \param begin The beginning index + * \param value The value to be stored. + */ + TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); +}; /*! * \brief Construct a new buffer given shape, and dtype. @@ -199,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const { * \param dtype The content data type. * \param name The name of the buffer * \return The created buffer. - * \sa BufferNode::make for complete constructor. + * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 0a20db6a0a63..f705247f6986 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -37,6 +37,8 @@ namespace tvm { namespace tir { +class Layout; + class LayoutAxis { public: static const LayoutAxis& Get(const char name); @@ -45,7 +47,7 @@ class LayoutAxis { static const LayoutAxis& Get(const tir::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). - static const LayoutAxis& make(const std::string& name); + static const LayoutAxis& Get(const std::string& name); inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } inline std::string name() const { return std::string(1, name_); } @@ -83,8 +85,16 @@ class LayoutAxis { const char name_; }; -class Layout; -// Internal node container Buffer +/*! + * \brief Layout is to describe how data is organized within an N-dimention tensor. + * It is composed of upper cases, lower cases and numbers, + * where upper case indicates a primal axis and + * the corresponding lower case with factor size indicates the subordinate axis. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). + * Layout for scalar is defined, while both its name and axes have size 0. + */ class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ @@ -102,29 +112,16 @@ class LayoutNode : public Object { v->Visit("axes", &axes); } - TVM_DLL static Layout make(const std::string& layout); - static constexpr const char* _type_key = "Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; /*! - * \brief Layout is to describe how data is organized within an N-dimention tensor. - * It is composed of upper cases, lower cases and numbers, - * where upper case indicates a primal axis and - * the corresponding lower case with factor size indicates the subordinate axis. - * For example, NCHW16c can describe a 5-D tensor of - * [batch_size, channel, height, width, channel_block]. - * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). - * Layout for scalar is defined, while both its name and axes have size 0. + * \brief Managed reference to LayoutNode + * \sa LayoutNode */ class Layout : public ObjectRef { public: - explicit Layout(ObjectPtr n) : ObjectRef(n) {} - - /*! \brief default constructor */ - Layout() = default; - explicit Layout(const Array& axes); /*! \brief construct from a string */ @@ -138,13 +135,7 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) - - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - const LayoutNode* operator->() const { return static_cast(get()); } + TVM_DLL Layout(const std::string& name); // NOLINT(*) /*! * \brief access the internal node container @@ -292,10 +283,9 @@ class Layout : public ObjectRef { return os; } - using ContainerType = LayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); }; -class BijectiveLayout; // Internal node container BijectiveLayout class BijectiveLayoutNode : public Object { public: @@ -329,8 +319,6 @@ class BijectiveLayoutNode : public Object { */ class BijectiveLayout : public ObjectRef { public: - BijectiveLayout() = default; - explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} /*! * \brief The constructor * \param src_layout The source layout @@ -347,19 +335,9 @@ class BijectiveLayout : public ObjectRef { // Given the destination indices, recover the source indices. TVM_DLL Array BackwardIndex(const Array& dst_index) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BijectiveLayoutNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BijectiveLayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; -inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(get()); -} } // namespace tir } // namespace tvm diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c667b49483e7..9d2a11c265dd 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -88,8 +88,8 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor, buffer_type); + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + offset_factor, buffer_type); } void GetBinds(const Array& args, bool compact, diff --git a/src/ir/span.cc b/src/ir/span.cc index 742c9858950c..565439f2ad74 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -61,17 +61,19 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) return static_cast(n)->name; }); -Span SpanNode::make(SourceName source, int lineno, int col_offset) { +Span::Span(SourceName source, int lineno, int col_offset) { auto n = make_object(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; - return Span(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { + return Span(source, lineno, col_offset); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index be749fdd3a97..3687b75c8ce8 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -217,8 +217,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> // Skip fcompute for device copy operators as it is not registered. if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); - outputs.push_back( - te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d9be91d78af3..9a75c0ab76ee 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -181,22 +181,25 @@ class InterpreterStateObj : public Object { v->Visit("stack", &stack); } - static InterpreterState make(Expr current_expr, Stack stack); - static constexpr const char* _type_key = "relay.InterpreterState"; TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); }; class InterpreterState : public ObjectRef { public: + using Frame = tvm::Map; + using Stack = tvm::Array; + + InterpreterState(Expr current_expr, Stack stack); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); }; -InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { +InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) { ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); - return InterpreterState(n); + data_ = std::move(n); } // NOTE: the current interpreter assumes A-normal form. @@ -701,7 +704,7 @@ class Interpreter : public ExprFunctor, InterpreterStateObj::Frame frame = fr.locals; stack.push_back(frame); } - auto state = InterpreterStateObj::make(e, stack); + auto state = InterpreterState(e, stack); return state; } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 92e12b5f3a38..1917096bb24c 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -112,11 +112,11 @@ struct StorageScope { } } /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static StorageScope make(const std::string& s) { + static StorageScope Create(const std::string& s) { StorageScope r; if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; @@ -153,11 +153,11 @@ struct ThreadScope { /*! \brief the dimension index under the rank */ int dim_index{0}; /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static ThreadScope make(const std::string& s) { + static ThreadScope Create(const std::string& s) { ThreadScope r; if (s == "vthread" || s == "cthread") { // virtual thread at the same level as local @@ -199,7 +199,7 @@ class ThreadAxisConfig { std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); filled[ts.rank * 3 + ts.dim_index] = true; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 280c9998a4b0..8e6b3a2ff22c 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -125,7 +125,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Return the thread index via intrinsics. llvm::Value* GetThreadIndex(const IterVar& iv) final { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; if (ts.rank == 1) { switch (ts.dim_index) { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index b43e9889a4ae..3af9fc3f4519 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1260,7 +1260,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 353f322eade2..bc47ce1b1014 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -101,7 +101,7 @@ class CodeGenNVPTX : public CodeGenLLVM { // Return the thread index via intrinsics. llvm::Value* GetThreadIndex(const IterVar& iv) final { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; if (ts.rank == 1) { switch (ts.dim_index) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index e381afb4db84..2c26ee977639 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -122,7 +122,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); work_dim = std::max(work_dim, scope.dim_index + 1); } if (work_dim != 0) { diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 746d418b6a37..8616853d8883 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -75,7 +75,7 @@ std::string CodeGenOpenCL::Finish() { void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); std::ostringstream os; if (ts.rank == 1) { os << "get_local_id(" << ts.dim_index << ")"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 364a62fa0e3e..699d3953f04c 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -92,7 +92,7 @@ void CodeGenSPIRV::InitFuncState() { } spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); spirv::Value v; if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); @@ -580,7 +580,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); - storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); + storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index a8a9a0be2321..1834aa3decf7 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -340,8 +340,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_bodies.push_back(new_body); } - auto new_op = - ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); // Jacobian shape = output.shape + input.shape Array new_shape = output->shape; @@ -349,7 +348,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_shape.push_back(e); } - return TensorNode::make(new_shape, output->dtype, new_op, value_index); + return Tensor(new_shape, output->dtype, new_op, value_index); } } // namespace te diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 7f957b584c57..1fc0520143fb 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -99,7 +99,7 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: args.push_back(axis.back()->var); } - return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0); + return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); } Array compute(Array shape, FBatchCompute fcompute, std::string name, @@ -116,7 +116,7 @@ Array compute(Array shape, FBatchCompute fcompute, std::string args.push_back(axis.back()->var); } - Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args)); + Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); Array outputs; for (int idx = 0; idx < op->num_outputs(); ++idx) { outputs.push_back(op.output(idx)); @@ -124,8 +124,8 @@ Array compute(Array shape, FBatchCompute fcompute, std::string return outputs; } -Operation ComputeOpNode::make(std::string name, std::string tag, Map attrs, - Array axis, Array body) { +ComputeOp::ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body) { if (!attrs.defined()) { attrs = Map(); } @@ -140,10 +140,13 @@ Operation ComputeOpNode::make(std::string name, std::string tag, Mapreduce_axis = reduce->axis; } VerifyComputeOp(n.get()); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.ComputeOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array axis, + Array body) { return ComputeOp(name, tag, attrs, axis, body); }); // The schedule related logics Array ComputeOpNode::InputTensors() const { @@ -188,7 +191,7 @@ Operation ComputeOpNode::ReplaceInputs(const Operation& self, UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); } if (!arr.same_as(this->body)) { - return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr); + return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr); } else { return self; } @@ -331,7 +334,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { // grab the nest structure - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); // Normal loop structure n.init_nest.emplace_back(MakeIfNest(n.init_predicates)); n.main_nest.emplace_back(MakeIfNest(n.main_predicates)); @@ -424,9 +427,9 @@ Stmt ComputeOpNode::BuildProvide(const Stage& stage, } } -ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 610c01468509..2661eb976f2e 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -59,9 +59,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ - static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); + static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); }; /*! diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 0933e303295c..ef55c44241b0 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -50,9 +50,9 @@ DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } -Operation ExternOpNode::make(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { +ExternOp::ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -73,10 +73,15 @@ Operation ExternOpNode::make(std::string name, std::string tag, Mapinput_placeholders = std::move(input_placeholders); n->output_placeholders = std::move(output_placeholders); n->body = std::move(body); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make); +TVM_REGISTER_GLOBAL("te.ExternOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body); + }); Array ExternOpNode::InputTensors() const { return inputs; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 9b3a79f33a4a..9be474d7d941 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -57,8 +57,8 @@ DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } -Operation HybridOpNode::make(std::string name, std::string tag, Map attrs, - Array inputs, Array outputs, Stmt body) { +HybridOp::HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -70,11 +70,13 @@ Operation HybridOpNode::make(std::string name, std::string tag, Mapoutputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); - Operation res = Operation(n); - return res; + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make); +TVM_REGISTER_GLOBAL("te.HybridOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, + Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); }); Array HybridOpNode::InputTensors() const { // Because input tensors could be potentially inlined into hybrid scripts, diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index f1b0527839e5..61b782629d19 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -156,9 +156,9 @@ std::vector > MakeLoopNest(const Stage& stage, if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { - runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); if (stage->scope == "" || - static_cast(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) { + static_cast(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) { value_map[iv] = var; } else if (stage->scope == "warp" && ts.rank == 1) { // To determine whether a thread index is inside or outside a warp, we need diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 9c536ebb8785..5b7ede314e49 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -50,16 +50,16 @@ Array PlaceholderOpNode::output_shape(size_t i) const { return shape; } -Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { +PlaceholderOp::PlaceholderOp(std::string name, Array shape, DataType dtype) { auto n = make_object(); n->name = name; n->shape = shape; n->dtype = dtype; - return Operation(n); + data_ = std::move(n); } Tensor placeholder(Array shape, DataType dtype, std::string name) { - return PlaceholderOpNode::make(name, shape, dtype).output(0); + return PlaceholderOp(name, shape, dtype).output(0); } TVM_REGISTER_GLOBAL("te.Placeholder") diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 45e86e24d4ea..cc86d0f46e3b 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -55,9 +55,9 @@ Array ScanOpNode::output_shape(size_t i) const { return state_placeholder[i]->shape; } -Operation ScanOpNode::make(std::string name, std::string tag, Map attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { +ScanOp::ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array inputs) { if (!attrs.defined()) { attrs = Map(); } @@ -104,10 +104,15 @@ Operation ScanOpNode::make(std::string name, std::string tag, Mapupdate = std::move(update); n->state_placeholder = std::move(state_placeholder); n->inputs = std::move(inputs); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make); +TVM_REGISTER_GLOBAL("te.ScanOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { + return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); + }); Array scan(Array init, Array update, Array state_placeholder, Array inputs, std::string name, std::string tag, @@ -115,8 +120,7 @@ Array scan(Array init, Array update, Array state IterVar scan_axis = IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), Var(name + ".idx"), kOrdered); - Operation op = - ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); + Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index c8dfce8ea1ba..8d5265bcb14f 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -52,10 +52,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Operation TensorComputeOpNode::make(std::string name, std::string tag, Array axis, - Array reduce_axis, int schedulable_ndim, - TensorIntrin intrin, Array tensors, - Array regions, Array scalar_inputs) { +TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, + TensorIntrin intrin, Array tensors, Array regions, + Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -66,10 +66,17 @@ Operation TensorComputeOpNode::make(std::string name, std::string tag, Arrayinputs = std::move(tensors); n->input_regions = std::move(regions); n->scalar_inputs = std::move(scalar_inputs); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.TensorComputeOp") + .set_body_typed([](std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs) { + return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors, + regions, scalar_inputs); + }); Array TensorComputeOpNode::InputTensors() const { return inputs; } @@ -191,7 +198,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, binder.BindArray(sp_expr, user_expr, this->name); size_t tloc = stage->leaf_iter_vars.size(); - ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop); if (this->reduce_axis.size() == 0) { std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index af4b08e6b9a9..82832c927785 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -347,7 +347,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; CHECK(intrin.defined()); - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); // Start bind data. diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 01d4f93db45a..099f4882f16c 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -59,7 +59,7 @@ bool NeedRelax(const IterVar& iv, bool found_attach, if (tag.length() == 0 || tag == "pipeline") { return !found_attach; } - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); // When there is warp memory // threadIdx.x must be set to be warp index. @@ -72,14 +72,14 @@ bool NeedRelax(const IterVar& iv, bool found_attach, // infer storage scope, if not given StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { if (stage->scope.length() != 0) { - return StorageScope::make(stage->scope); + return StorageScope::Create(stage->scope); } int max_rank = -1; for (IterVar iv : ctx.attach_path.at(stage->op)) { auto it = ctx.bind_map.find(iv); const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag != "pipeline" && tag.length() != 0) { - max_rank = std::max(max_rank, ThreadScope::make(tag).rank); + max_rank = std::max(max_rank, ThreadScope::Create(tag).rank); } } StorageScope s; diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c36051341cc3..af72d3b1a1df 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -324,16 +324,16 @@ Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_a args.push_back(value_map.at(iv)); } } - Operation cache_op = ComputeOpNode::make(compute->name + "." + scope, compute->tag, - compute->attrs, new_axis, body_list); + Operation cache_op = + ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, - compute->axis, cache_expr_list); + Operation orig_new_op = + ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } @@ -380,10 +380,10 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } - Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." + scope, tensor_op->tag, - new_axis, tensor_op->reduce_axis, - tensor_op->schedulable_ndim, tensor_op->intrin, - tensor_op->inputs, new_regions, new_scalar_inputs); + Operation cache_op = + TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin, + tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; @@ -419,7 +419,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = - ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); + ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } @@ -468,7 +468,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { if (idx < leaf_vars->size()) { // insert rebase IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type); - s->relations.push_back(RebaseNode::make(iv, rebased)); + s->relations.push_back(te::Rebase(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); } @@ -583,8 +583,7 @@ void InjectInline(ScheduleNode* sch) { CHECK(compute); Operation op = s->op; if (changed[i]) { - op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis, - new_body[i]); + op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]); } op = op->ReplaceInputs(op, repl); if (!op.same_as(s->op)) { @@ -596,8 +595,8 @@ void InjectInline(ScheduleNode* sch) { } else if (hybrid_changed[i]) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); CHECK(hybrid); - Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 24d910237627..707d52fb186a 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -55,8 +55,8 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) return 0; } -void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner) { +void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, + IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) @@ -69,7 +69,7 @@ void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, It Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); - self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts)); + self->relations.push_back(Split(parent, outer, inner, factor, nparts)); // add vars to all vars all_vars.push_back(outer); all_vars.push_back(inner); @@ -206,13 +206,13 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); + SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); + SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } @@ -242,7 +242,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT } CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; - self->relations.push_back(FuseNode::make(outer, inner, fused)); + self->relations.push_back(Fuse(outer, inner, fused)); all_vars.push_back(fused); leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1); leaf_vars.insert(leaf_vars.begin() + pos_outer, fused); @@ -263,7 +263,7 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* // insert at the outer most loop IterVar singleton = IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); - self->relations.push_back(SingletonNode::make(singleton)); + self->relations.push_back(Singleton(singleton)); Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; all_vars.push_back(singleton); @@ -624,9 +624,9 @@ bool ScheduleNode::Contain(const Operation& op) const { return stage_map.find(op) != stage_map.end(); } -Schedule ScheduleNode::make(Array ops) { +Schedule::Schedule(Array ops) { auto n = make_object(); - Schedule sch(n); + data_ = n; n->outputs = ops; auto g = te::CreateReadGraph(n->outputs); Array post_order = te::PostDFSOrder(n->outputs, g); @@ -650,7 +650,7 @@ Schedule ScheduleNode::make(Array ops) { inputs.push_back(t); } // Create the scan group. - Stage scan_group = sch.create_group(scan->update, inputs, false); + Stage scan_group = this->create_group(scan->update, inputs, false); scan_group->attach_type = kScanUpdate; scan_group->attach_stage = stage; @@ -660,39 +660,37 @@ Schedule ScheduleNode::make(Array ops) { } } } - return sch; } -IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, - PrimExpr nparts) { +Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; n->factor = factor; n->nparts = nparts; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) { +Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) { auto n = make_object(); n->outer = outer; n->inner = inner; n->fused = fused; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { +Rebase::Rebase(IterVar parent, IterVar rebased) { auto n = make_object(); n->parent = parent; n->rebased = rebased; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation SingletonNode::make(IterVar iter) { +Singleton::Singleton(IterVar iter) { auto n = make_object(); n->iter = iter; - return IterVarRelation(n); + data_ = std::move(n); } SpecializedCondition::SpecializedCondition(Array conditions) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 7e7f648f6594..e66b9632d8a2 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -66,28 +66,32 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { +Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; n->value_index = value_index; - return Tensor(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.Tensor") + .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); + +TVM_REGISTER_NODE_TYPE(TensorNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* t = static_cast(node.get()); p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; }); -TVM_REGISTER_NODE_TYPE(TensorNode); - // TensorIntrin - -TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, - Array buffers, Array scalar_params, Stmt body, - Stmt reduce_init, Stmt reduce_update) { +TensorIntrin::TensorIntrin(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); @@ -97,17 +101,24 @@ TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Arraybody = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); - return TensorIntrin(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.TensorIntrin") + .set_body_typed([](std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { + return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, + reduce_update); + }); + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; }); -TVM_REGISTER_NODE_TYPE(TensorIntrinNode); - // TensorIntrinCall TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis, @@ -135,10 +146,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); -TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make); - -TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make); - +// Other tensor ops. TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 46f4160557ec..2b64f11ccc3c 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,8 +45,8 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { } Buffer decl_buffer(Array shape, DataType dtype, std::string name) { - return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), - PrimExpr(), name, "", 0, 0, kDefault); + return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault); } // Split the given expression w.r.t the add operator @@ -348,8 +348,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", - n->scope, n->data_alignment, 0, n->buffer_type); + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + n->data_alignment, 0, n->buffer_type); } PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, @@ -379,9 +379,9 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } -Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, std::string name, std::string scope, - int data_alignment, int offset_factor, BufferType buffer_type) { +Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); n->dtype = dtype; @@ -410,7 +410,7 @@ Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array

strides.push_back(Var("stride", n->shape[i].dtype())); } } - return Buffer(n); + data_ = std::move(n); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -425,8 +425,8 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size(), 10); auto buffer_type = args[9].operator std::string(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], - args[8], type); + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 1f17c35f7fd4..bc777db55dbe 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -66,7 +66,7 @@ const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { return LayoutAxis::Get(axis[0]); } -const LayoutAxis& LayoutAxis::make(const std::string& name) { +const LayoutAxis& LayoutAxis::Get(const std::string& name) { CHECK_EQ(name.length(), 1) << "Invalid axis " << name; return LayoutAxis::Get(name[0]); } @@ -144,8 +144,6 @@ Layout::Layout(const std::string& name) { // NOLINT(*) data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { return Layout(layout); } - Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); if (len == 0) return Layout(Array()); @@ -365,15 +363,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make); +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); }); TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::make(axis)); + return layout.IndexOf(LayoutAxis::Get(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") .set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::make(axis)); + return layout.FactorOf(LayoutAxis::Get(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index c201b8fa103a..416358cebce0 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -147,12 +147,12 @@ class CopyIntrinInjector : public StmtMutator { src_strides.push_back(make_const(DataType::Int(32), 1)); dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); - Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides, - src_elem_offset, load->buffer_var->name_hint, - GetStorageScope(load->buffer_var.get()), 0, 0, kDefault); + Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, + store_strides[loop_var_size], store->buffer_var->name_hint, + GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, + load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, + kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 7dbf0fc6391d..3b2580c60074 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -113,7 +113,7 @@ class CandidateSelector final : public StmtExprVisitor { const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); @@ -361,7 +361,7 @@ class LoopPartitioner : public StmtMutator { } // normal path when loop parittion fails. - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 0b8775761608..9d6b47a1ca37 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -63,7 +63,7 @@ class StorageAccessInfoLower : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); StorageEntry e; e.scope = scope; if (scope.tag.length() != 0) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 860401735896..dae428258541 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -157,7 +157,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (const AttrStmtNode* attr : thread_extents_) { ThreadEntry e; IterVar iv = Downcast(attr->node); - e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.iv = iv; CHECK_LE(e.scope.rank, 1); CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; @@ -503,7 +503,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { IterVar iv = Downcast(op->node); ThreadEntry e; - e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.extent = 0; if (auto ptr = op->value.as()) { e.extent = static_cast(ptr->value); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index a0ddf26e0dcc..92f9ab54adb4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -367,7 +367,7 @@ class WarpMemoryRewriter : private StmtMutator { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 3a4213785a16..20cc6402135f 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -92,7 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); @@ -215,11 +215,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { - StorageScope scope = StorageScope::make(s); + StorageScope scope = StorageScope::Create(s); AccessEntry e; e.threads = env_threads(); e.type = kSync; - e.scope = StorageScope::make(s); + e.scope = StorageScope::Create(s); curr_stmt_.access.emplace_back(std::move(e)); } } else { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 4c3de580160d..e29d978e0d42 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -91,7 +91,7 @@ class StorageFlattener : public StmtExprMutator { return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ThreadScope ts = ThreadScope::make(iv->thread_tag); + ThreadScope ts = ThreadScope::Create(iv->thread_tag); curr_thread_scope_.push_back(ts); Stmt stmt = StmtExprMutator::VisitStmt_(op); curr_thread_scope_.pop_back(); @@ -165,7 +165,7 @@ class StorageFlattener : public StmtExprMutator { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } } else { - skey = StorageScope::make(strkey); + skey = StorageScope::Create(strkey); } // use small alignment for small arrays @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - skey.to_string(), align, 0, kDefault); + e.buffer = + Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, + strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 2d09e8bae64d..283ab0f6f703 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -178,7 +178,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); + alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index e5b4bdde7d90..b8575d28c8ce 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -251,7 +251,7 @@ class ThreadSyncInserter : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -321,7 +321,7 @@ class ThreadSyncInserter : public StmtExprMutator { num_work_dim_ = thread_extents_.size(); for (const AttrStmtNode* attr : thread_extents_) { IterVar iv = Downcast(attr->node); - runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); if (s.rank == 0) { num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { @@ -353,7 +353,7 @@ class ThreadSyncInserter : public StmtExprMutator { }; Stmt ThreadSync(Stmt stmt, std::string storage_scope) { - StorageScope sync_scope = StorageScope::make(storage_scope); + StorageScope sync_scope = StorageScope::Create(storage_scope); ThreadSyncPlanner planner(sync_scope); planner(stmt); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 8823134faca7..70709b0f96a1 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -51,11 +51,11 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* TEST(MicroStandaloneRuntime, BuildModule) { using namespace tvm; auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32)); - auto a = relay::VarNode::make("a", tensor_type); - auto b = relay::VarNode::make("b", tensor_type); + auto a = relay::Var("a", tensor_type); + auto b = relay::Var("b", tensor_type); auto add_op = relay::Op::Get("add"); auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); - auto c = relay::VarNode::make("c", tensor_type); + auto c = relay::Var("c", tensor_type); auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 25b38008b6ed..b84fbc7722a1 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -46,8 +46,7 @@ using namespace tvm::te; inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, - kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); } /*! @@ -93,8 +92,7 @@ inline Array make_extern(const Array >& out_shapes, auto body = fextern(input_placeholders, output_placeholders); auto body_stmt = tvm::tir::Evaluate(body); - auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders, - body_stmt); + auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 9aa4e358eaf3..e830e099b0c0 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1194,8 +1194,8 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, const std::string name = "T_layout_trans", const std::string tag = kInjective) { - Layout src_layout_struct = LayoutNode::make(src_layout); - Layout dst_layout_struct = LayoutNode::make(dst_layout); + Layout src_layout_struct(src_layout); + Layout dst_layout_struct(dst_layout); if (src_layout_struct.Equals(dst_layout_struct)) { return src;