From 02c4118532575ce2be3c20d821d2b227d00a75b6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 2 Apr 2020 12:36:55 -0700 Subject: [PATCH] [TIR] Introduce BufferLoad/Store (#5205) Co-authored-by: Siyuan Feng This PR introduces BufferLoad/Store to TIR. The new nodes will replace Provide and Call with Tensor arguments in the subsequent refactors. --- include/tvm/tir/buffer.h | 14 +- include/tvm/tir/expr.h | 356 +++--------------- include/tvm/tir/expr_functor.h | 4 + include/tvm/tir/stmt.h | 51 +++ include/tvm/tir/stmt_functor.h | 3 + include/tvm/tir/var.h | 343 +++++++++++++++++ python/tvm/tir/__init__.py | 4 +- python/tvm/tir/expr.py | 26 +- python/tvm/tir/stmt.py | 27 +- src/tir/ir/expr.cc | 16 + src/tir/ir/expr_functor.cc | 14 + src/tir/ir/stmt.cc | 15 + src/tir/ir/stmt_functor.cc | 15 + tests/python/unittest/test_tir_nodes.py | 11 + .../test_tir_structural_equal_hash.py | 17 + 15 files changed, 590 insertions(+), 326 deletions(-) create mode 100644 include/tvm/tir/var.h diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 7b15776d260e..08a8e69a4532 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -25,9 +25,8 @@ #define TVM_TIR_BUFFER_H_ #include -#include -#include - +#include +#include #include @@ -36,6 +35,9 @@ namespace tir { // Internal node container Buffer class BufferNode; +// forward declare Stmt +class Stmt; + /*! \brief buffer type */ enum BufferType : int { kDefault = 1, @@ -75,9 +77,9 @@ class Buffer : public ObjectRef { * \param offset The offset of ptr. */ TVM_DLL PrimExpr access_ptr(int access_mask, - DataType ptr_type = DataType::Handle(), - int content_lanes = 1, - PrimExpr offset = make_const(DataType::Int(32), 0)) const; + DataType ptr_type = DataType::Handle(), + int content_lanes = 1, + PrimExpr offset = IntImm(DataType::Int(32), 0)) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 7b8ab44036fd..6295a366c6ad 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -31,6 +31,8 @@ #include #include #include +#include +#include #include #include @@ -42,313 +44,6 @@ namespace tvm { namespace tir { -/*! - * \brief A variable node in the IR. - * - * A variable is uniquely identified by its address. - * - * Each variable is only binded once in the following nodes: - * - Allocate - * - For - * - Let - * - LetStmt - */ -class VarNode : public PrimExprNode { - public: - /*! - * \brief The hint to the variable name. - * \note Each variable is uniquely identified by its address. - */ - std::string name_hint; - /*! - * \brief type annotaion of the variable. - * - * It is an optional field that provides a refined type of the variable than dtype. - * - * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. - */ - Type type_annotation; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("name", &name_hint); - v->Visit("type_annotation", &type_annotation); - } - - bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - if (!equal(dtype, other->dtype)) return false; - if (!equal(type_annotation, other->type_annotation)) return false; - return equal.FreeVarEqualImpl(this, other); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - hash_reduce(type_annotation); - hash_reduce.FreeVarHashImpl(this); - } - - static constexpr const char* _type_key = "tir.Var"; - TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); -}; - -/*! \brief a named variable in TVM */ -class Var : public PrimExpr { - public: - explicit Var(ObjectPtr n) : PrimExpr(n) {} - /*! - * \brief Constructor - * \param name_hint variable name - * \param dtype data type - */ - TVM_DLL explicit Var(std::string name_hint = "v", - DataType dtype = DataType::Int(32)); - /*! - * \brief Constructor which provides a more detailed type annotation. - * \param name_hint variable name. - * \param type_annotation The type annotation. - */ - TVM_DLL explicit Var(std::string name_hint, Type type_annotation); - /*! - * \brief Make a new copy of var with same type, append suffix - * \param suffix The suffix to be appended. - * \return the new Var copy - */ - TVM_DLL Var copy_with_suffix(const std::string& suffix) const; - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const VarNode* operator->() const { - return get(); - } - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const VarNode* get() const { - return static_cast(data_.get()); - } - /*! \brief type indicate the container type */ - using ContainerType = VarNode; -}; - -/*! - * \brief A variable node represent a tensor index size, - * whose value must be non-negative. - */ -class SizeVarNode : public VarNode { - public: - static constexpr const char* _type_key = "tir.SizeVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); -}; - -/*! \brief a named variable represents a tensor index size */ -class SizeVar : public Var { - public: - explicit SizeVar(ObjectPtr n) : Var(n) {} - /*! - * \brief constructor - * \param name_hint variable name - * \param t data type - */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const SizeVarNode* operator->() const { - return get(); - } - /*! - * \brief Get pointer to the internal value. - * \return the corresponding Variable. - */ - const SizeVarNode* get() const { - return static_cast(data_.get()); - } - /*! \brief type indicate the container type */ - using ContainerType = SizeVarNode; -}; - - -/*! \brief container class of iteration variable. */ -class IterVarNode; - -using Region = Array; - -/*! - * \brief Type of iteration variable. - * Each IterVar have a specific type. - * - * The type of iter var can be overriden via - * stage.iter_var_attrs given they are compatible. - */ -enum IterVarType : int { - /*! - * \brief Data parallel iteration. - * This normally corresponds to axis of Tensor. - * Allow all IterVar manipulations. - * - * \note This does not mean the loop - * have to be executed in parallel fashion. - */ - kDataPar = 0, - /*! - * \brief The IterVar itself is a thread-index - * of a fixed thread launching group. - * Note that this is already assumed to be paralellized. - * - * Disallow: split/fuse/vectorize/parallel - */ - kThreadIndex = 1, - /*! - * \brief Communicative reduction. - * Cannot be directly parallelized. - * - * Disallow: parallel/vectorize - */ - kCommReduce = 2, - /*! - * \brief Serial loops with loop carry dependency, - * the iteration must execute in order. - * Cannot be re-ordered. - * - * Disallow: reorder/parallel/vectorize - */ - kOrdered = 3, - /*! - * \brief IterVar is opaque, - * - * May not corresponds to any generated loop - * Disallow all IterVar manipulations and compute_at - * - * \note This is usually used to implement composite op - * or external op, where the - */ - kOpaque = 4, - // The following are possible additional - // types that are provided during schedule - /*! - * \brief The execution is unrolled. - */ - kUnrolled = 5, - /*! - * \brief The loop is vectorized. - */ - kVectorized = 6, - /*! - * \brief The loop is parallelized. - */ - kParallelized = 7, - /*! - * \brief Marks boundary of tensorization intrinsic. - */ - kTensorized = 8 -}; - -/*! - * \brief Iteration Variable, - * represents an iteration over an integer interval. - */ -class IterVar : public ObjectRef { - public: - // construct a new iter var without a domain - IterVar() {} - // construct from shared ptr. - explicit IterVar(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const IterVarNode* operator->() const; - /*! - * \return the corresponding var in the IterVar. - */ - inline operator PrimExpr() const; - /*! \brief specify container node */ - using ContainerType = IterVarNode; -}; - -using Domain = Array; - -/*! - * \brief An iteration variable representing an iteration - * over a one dimensional interval. - */ -class IterVarNode : public Object { - public: - /*! - * \brief the domain of iteration, if known, can be None - * For the intermediate schedule node, before schedule. - */ - Range dom; - /*! \brief The looping variable */ - Var var; - /*! \brief The type of the IterVar */ - IterVarType iter_type; - /*! - * \brief additional tag on the iteration variable, - * set this if this is binded already to a known thread tag. - */ - std::string thread_tag; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dom", &dom); - v->Visit("var", &var); - v->Visit("iter_type", &iter_type); - v->Visit("thread_tag", &thread_tag); - } - - bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { - return - equal(dom, other->dom) && - equal.DefEqual(var, other->var) && - equal(iter_type, other->iter_type) && - equal(thread_tag, other->thread_tag); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dom); - hash_reduce.DefHash(var); - hash_reduce(iter_type); - hash_reduce(thread_tag); - } - - TVM_DLL static IterVar make(Range dom, Var var, - IterVarType iter_type, - std::string thread_tag = ""); - - static constexpr const char* _type_key = "IterVar"; - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; - TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); -}; - -// inline implementations -inline const IterVarNode* IterVar::operator->() const { - return static_cast(data_.get()); -} - -inline IterVar::operator PrimExpr() const { - return (*this)->var; -} - -inline const char* IterVarType2String(IterVarType t) { - switch (t) { - case kDataPar: return "DataPar"; - case kThreadIndex: return "ThreadIndex"; - case kCommReduce: return "CommReduce"; - case kOrdered: return "Ordered"; - case kOpaque: return "Opaque"; - case kUnrolled: return "Unrolled"; - case kVectorized: return "Vectorized"; - case kParallelized: return "Parallelized"; - case kTensorized: return "Tensorized"; - } - return "Unknown"; -} - using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; @@ -733,6 +428,53 @@ class SelectNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; +/*! + * \brief Load value from the high dimension buffer. + * + * \code + * + * value = buffer[i, j]; + * + * \endcode + * \sa BufferStore + */ +class BufferLoadNode : public PrimExprNode { + public: + /*! \brief The buffer variable. */ + Buffer buffer; + /*! \brief The indices location to be loaded. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &(this->dtype)); + v->Visit("buffer", &buffer); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { + return + equal(dtype, other->dtype) && + equal(buffer, other->buffer) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(buffer); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "BufferLoad"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); +}; + +class BufferLoad : public PrimExpr { + public: + TVM_DLL explicit BufferLoad(Buffer buffer, + Array indices); + TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); +}; + /*! * \brief Load the value from buffer_var. * diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index 0de05a682703..dcf04c346454 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -121,6 +121,7 @@ class ExprFunctor { virtual R VisitExpr_(const SizeVarNode* op, Args... args) { return VisitExpr_(static_cast(op), std::forward(args)...); } + virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -164,6 +165,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(VarNode); IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); IR_EXPR_FUNCTOR_DISPATCH(LoadNode); + IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode); IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(AddNode); @@ -214,6 +216,7 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const VarNode* op) override; void VisitExpr_(const SizeVarNode* op) override; void VisitExpr_(const LoadNode* op) override; + void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const AddNode* op) override; @@ -259,6 +262,7 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const SizeVarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; + PrimExpr VisitExpr_(const BufferLoadNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; PrimExpr VisitExpr_(const CallNode* op) override; PrimExpr VisitExpr_(const AddNode* op) override; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 47ec30546f51..fe0d9ed44ae6 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -274,6 +274,57 @@ class StoreNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); }; +/*! + * \brief Store value to the high dimension buffer. + * + * \code + * + * buffer[i, j] = value; + * + * \endcode + * \sa BufferLoad + */ +class BufferStore; +class BufferStoreNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Buffer buffer; + /*! \brief The value to be stored. */ + PrimExpr value; + /*! \brief The indices location to be stored. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("value", &value); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { + return + equal(buffer, other->buffer) && + equal(value, other->value) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(value); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "BufferStore"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); +}; + +class BufferStore : public Stmt { + public: + TVM_DLL explicit BufferStore(Buffer buffer, + PrimExpr value, + Array indices); + TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); +}; + /*! * \brief Store value into mult-dimensional array defined by func. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index c880a4847356..682402221b3a 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -91,6 +91,7 @@ class StmtFunctor { virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -154,6 +155,7 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const ForNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const ProducerConsumerNode* op) override; @@ -248,6 +250,7 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; + Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const ProducerConsumerNode* op) override; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h new file mode 100644 index 000000000000..19c904a1230f --- /dev/null +++ b/include/tvm/tir/var.h @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/var.h + * \brief Variables in the TIR. + */ +#ifndef TVM_TIR_VAR_H_ +#define TVM_TIR_VAR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief A variable node in the IR. + * + * A variable is uniquely identified by its address. + * + * Each variable is only binded once in the following nodes: + * - Allocate + * - For + * - Let + * - LetStmt + */ +class VarNode : public PrimExprNode { + public: + /*! + * \brief The hint to the variable name. + * \note Each variable is uniquely identified by its address. + */ + std::string name_hint; + /*! + * \brief type annotaion of the variable. + * + * It is an optional field that provides a refined type of the variable than dtype. + * + * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. + */ + Type type_annotation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("name", &name_hint); + v->Visit("type_annotation", &type_annotation); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + if (!equal(dtype, other->dtype)) return false; + if (!equal(type_annotation, other->type_annotation)) return false; + return equal.FreeVarEqualImpl(this, other); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(type_annotation); + hash_reduce.FreeVarHashImpl(this); + } + + static constexpr const char* _type_key = "tir.Var"; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); +}; + +/*! \brief a named variable in TVM */ +class Var : public PrimExpr { + public: + explicit Var(ObjectPtr n) : PrimExpr(n) {} + /*! + * \brief Constructor + * \param name_hint variable name + * \param dtype data type + */ + TVM_DLL explicit Var(std::string name_hint = "v", + DataType dtype = DataType::Int(32)); + /*! + * \brief Constructor which provides a more detailed type annotation. + * \param name_hint variable name. + * \param type_annotation The type annotation. + */ + TVM_DLL explicit Var(std::string name_hint, Type type_annotation); + /*! + * \brief Make a new copy of var with same type, append suffix + * \param suffix The suffix to be appended. + * \return the new Var copy + */ + TVM_DLL Var copy_with_suffix(const std::string& suffix) const; + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const VarNode* operator->() const { + return get(); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const VarNode* get() const { + return static_cast(data_.get()); + } + /*! \brief type indicate the container type */ + using ContainerType = VarNode; +}; + +/*! + * \brief A variable node represent a tensor index size, + * whose value must be non-negative. + */ +class SizeVarNode : public VarNode { + public: + static constexpr const char* _type_key = "tir.SizeVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); +}; + +/*! \brief a named variable represents a tensor index size */ +class SizeVar : public Var { + public: + explicit SizeVar(ObjectPtr n) : Var(n) {} + /*! + * \brief constructor + * \param name_hint variable name + * \param t data type + */ + TVM_DLL explicit SizeVar(std::string name_hint = "s", + DataType t = DataType::Int(32)); + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const SizeVarNode* operator->() const { + return get(); + } + /*! + * \brief Get pointer to the internal value. + * \return the corresponding Variable. + */ + const SizeVarNode* get() const { + return static_cast(data_.get()); + } + /*! \brief type indicate the container type */ + using ContainerType = SizeVarNode; +}; + + +/*! \brief container class of iteration variable. */ +class IterVarNode; + +using Region = Array; + +/*! + * \brief Type of iteration variable. + * Each IterVar have a specific type. + * + * The type of iter var can be overriden via + * stage.iter_var_attrs given they are compatible. + */ +enum IterVarType : int { + /*! + * \brief Data parallel iteration. + * This normally corresponds to axis of Tensor. + * Allow all IterVar manipulations. + * + * \note This does not mean the loop + * have to be executed in parallel fashion. + */ + kDataPar = 0, + /*! + * \brief The IterVar itself is a thread-index + * of a fixed thread launching group. + * Note that this is already assumed to be paralellized. + * + * Disallow: split/fuse/vectorize/parallel + */ + kThreadIndex = 1, + /*! + * \brief Communicative reduction. + * Cannot be directly parallelized. + * + * Disallow: parallel/vectorize + */ + kCommReduce = 2, + /*! + * \brief Serial loops with loop carry dependency, + * the iteration must execute in order. + * Cannot be re-ordered. + * + * Disallow: reorder/parallel/vectorize + */ + kOrdered = 3, + /*! + * \brief IterVar is opaque, + * + * May not corresponds to any generated loop + * Disallow all IterVar manipulations and compute_at + * + * \note This is usually used to implement composite op + * or external op, where the + */ + kOpaque = 4, + // The following are possible additional + // types that are provided during schedule + /*! + * \brief The execution is unrolled. + */ + kUnrolled = 5, + /*! + * \brief The loop is vectorized. + */ + kVectorized = 6, + /*! + * \brief The loop is parallelized. + */ + kParallelized = 7, + /*! + * \brief Marks boundary of tensorization intrinsic. + */ + kTensorized = 8 +}; + +/*! + * \brief Iteration Variable, + * represents an iteration over an integer interval. + */ +class IterVar : public ObjectRef { + public: + // construct a new iter var without a domain + IterVar() {} + // construct from shared ptr. + explicit IterVar(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const IterVarNode* operator->() const; + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; + /*! \brief specify container node */ + using ContainerType = IterVarNode; +}; + +using Domain = Array; + +/*! + * \brief An iteration variable representing an iteration + * over a one dimensional interval. + */ +class IterVarNode : public Object { + public: + /*! + * \brief the domain of iteration, if known, can be None + * For the intermediate schedule node, before schedule. + */ + Range dom; + /*! \brief The looping variable */ + Var var; + /*! \brief The type of the IterVar */ + IterVarType iter_type; + /*! + * \brief additional tag on the iteration variable, + * set this if this is binded already to a known thread tag. + */ + std::string thread_tag; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dom", &dom); + v->Visit("var", &var); + v->Visit("iter_type", &iter_type); + v->Visit("thread_tag", &thread_tag); + } + + bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { + return + equal(dom, other->dom) && + equal.DefEqual(var, other->var) && + equal(iter_type, other->iter_type) && + equal(thread_tag, other->thread_tag); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dom); + hash_reduce.DefHash(var); + hash_reduce(iter_type); + hash_reduce(thread_tag); + } + + TVM_DLL static IterVar make(Range dom, Var var, + IterVarType iter_type, + std::string thread_tag = ""); + + static constexpr const char* _type_key = "IterVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); +}; + +// inline implementations +inline const IterVarNode* IterVar::operator->() const { + return static_cast(data_.get()); +} + +inline IterVar::operator PrimExpr() const { + return (*this)->var; +} + +inline const char* IterVarType2String(IterVarType t) { + switch (t) { + case kDataPar: return "DataPar"; + case kThreadIndex: return "ThreadIndex"; + case kCommReduce: return "CommReduce"; + case kOrdered: return "Ordered"; + case kOpaque: return "Opaque"; + case kUnrolled: return "Unrolled"; + case kVectorized: return "Vectorized"; + case kParallelized: return "Parallelized"; + case kTensorized: return "Tensorized"; + } + return "Unknown"; +} +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_VAR_H_ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 653c3954f489..bd8e33fe4c3b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -24,11 +24,11 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let +from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For -from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt +from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list from .function import PrimFunc diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index a192fce6439a..20a3bca75368 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -14,18 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Expression AST Node in TVM. +# pylint: disable=redefined-builtin +"""TIR expression nodes. -User do not need to deal with expression AST node directly. -But they can be helpful for developer to do quick proptyping. -While not displayed in the document and python file. Each expression node have subfields that can be visited from python side. - For example, you can use addexp.a to get the left operand of an Add node. .. code-block:: python - x = te.var("n") + x = tvm.tir.Var("n", "int32") y = x + 2 assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) @@ -858,6 +855,23 @@ def __init__(self, dtype, buffer_var, index, predicate=None): _ffi_api.Load, dtype, buffer_var, index, *args) +@tvm._ffi.register_object +class BufferLoad(PrimExprWithOp): + """Buffer load node. + + Parameters + ---------- + buffer : Buffer + The buffer to be loaded. + + indices : List[PrimExpr] + The buffer indices. + """ + def __init__(self, buffer, indices): + self.__init_handle_by_constructor__( + _ffi_api.BufferLoad, buffer, indices) + + @tvm._ffi.register_object class Ramp(PrimExprWithOp): """Ramp node. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 65c72ddfeb36..0badad3c092f 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -16,15 +16,12 @@ # under the License. """Statement AST Node in TVM. -User do not need to deal with AST node directly. -But they can be helpful for developer to do quick proptyping. -While not displayed in the document and python file. Each statement node have subfields that can be visited from python side. .. code-block:: python - x = te.var("n") - a = te.var("array", "handle") + x = tvm.tir.Var("n", "int32") + a = tvm.tir.Var("array", "handle") st = tvm.tir.stmt.Store(a, x + 1, 1) assert isinstance(st, tvm.tir.stmt.Store) assert(st.buffer_var == a) @@ -163,6 +160,26 @@ def __init__(self, buffer_var, value, index, predicate=None): _ffi_api.Store, buffer_var, value, index, *args) +@tvm._ffi.register_object +class BufferStore(Stmt): + """Buffer store node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + value : PrimExpr + The value we to be stored. + + indices : List[PrimExpr] + The indices location to be stored. + """ + def __init__(self, buffer, value, indices): + self.__init_handle_by_constructor__( + _ffi_api.BufferStore, buffer, value, indices) + + @tvm._ffi.register_object class Provide(Stmt): """Provide node. diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index bee025687173..891d13723d9a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -407,6 +407,22 @@ PrimExpr AnyNode::make() { return PrimExpr(n); } +BufferLoad::BufferLoad(Buffer buffer, Array indices) { + ObjectPtr node = make_object(); + node->dtype = buffer->dtype; + node->buffer = std::move(buffer); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.BufferLoad") +.set_body_typed([](Buffer buffer, Array indices) { + return BufferLoad(buffer, indices); +}); + +TVM_REGISTER_NODE_TYPE(BufferLoadNode); + + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index f8371f3765a4..57ff627ceaf1 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -36,6 +36,10 @@ void ExprVisitor::VisitExpr_(const LoadNode* op) { this->VisitExpr(op->predicate); } +void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void ExprVisitor::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->body); @@ -128,6 +132,16 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { } } +PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array indices = MutateArray(op->indices, fmutate); + if (indices.same_as(op->indices)) { + return GetRef(op); + } else { + return BufferLoad(op->buffer, indices); + } +} + PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index a8fe9cd2bad3..64e7ef572673 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -324,6 +324,21 @@ Stmt EvaluateNode::make(PrimExpr value) { TVM_REGISTER_GLOBAL("tir.Evaluate") .set_body_typed(EvaluateNode::make); +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->value = std::move(value); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.BufferStore") +.set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { + return BufferStore(buffer, value, indices); +}); + +TVM_REGISTER_NODE_TYPE(BufferStoreNode); + // Printers TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index b4b27b9abef9..ea199821b236 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -160,6 +160,10 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { this->VisitExpr(op->predicate); } +void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); @@ -343,6 +347,17 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { } } +Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { + Array indices = Internal::Mutate(this, op->indices); + if (indices.same_as(op->indices)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->indices = std::move(indices); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Array args = Internal::Mutate(this, op->args); PrimExpr value = this->VisitExpr(op->value); diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 290495381283..2e23a6108ccd 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -292,7 +292,18 @@ def test_vars(): assert isinstance(ptype.element_type, tvm.ir.PrimType) +def test_buffer_load_store(): + b = tvm.tir.decl_buffer((10,), "float32") + x = tvm.tir.BufferLoad(b, [0]) + assert isinstance(x, tvm.tir.BufferLoad) + assert x.dtype == "float32" + assert x.buffer == b + s = tvm.tir.BufferStore(b, 0.1, [0]) + assert isinstance(s, tvm.tir.BufferStore) + + if __name__ == "__main__": + test_buffer_load_store() test_vars() test_prim_func() test_cast() diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index 3fcdc65c30ce..593b845396ef 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -166,6 +166,22 @@ def func2(): assert consistent_equal(func2(), func2()) +def test_buffer_load_store(): + b = tvm.tir.decl_buffer((10, 10), "float32") + x = tvm.tir.BufferLoad(b, [0, 1]) + y = tvm.tir.BufferLoad(b, [0, 1]) + z = tvm.tir.BufferLoad(b, [1, 2]) + assert consistent_equal(y, x) + assert not consistent_equal(y, z) + + i = tvm.tir.Var("x", "int32") + sx = tvm.tir.BufferStore(b, 0.1, [0, i]) + sy = tvm.tir.BufferStore(b, 0.1, [0, i]) + sz = tvm.tir.BufferStore(b, 0.1, [1, i]) + assert consistent_equal(sy, sx) + assert not consistent_equal(sy, sz) + + if __name__ == "__main__": test_exprs() test_prim_func() @@ -173,3 +189,4 @@ def func2(): test_array() test_env_func() test_stmt() + test_buffer_load_store()