diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 6165a2ab546f..b1cb779b4227 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, /*! * \brief Infer a regular domain that covers all the calls or provides within the given statement. * \param body The given statement. - * \param tensor The name of the calls or provides. - * \param consider_calls If calls (read) are considered. - * \param consider_provides If provides (write) are considered. + * \param buffer The buffer to check the access info. + * \param consider_loads If loads are considered. + * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(Stmt body, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides); +Domain DomainTouched(const Stmt& body, + const tir::Buffer& buffer, + bool consider_loads, + bool consider_stores); } // namespace arith } // namespace tvm diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 121dbdde37a6..b9b420ad02b6 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -70,7 +70,7 @@ class ObjAllocatorBase { static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), - std::forward(args)...); + std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index b3ecbf8c08e1..e64ea21f5f7e 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -29,6 +29,7 @@ #define TVM_TE_SCHEDULE_PASS_H_ #include +#include namespace tvm { namespace te { @@ -54,6 +55,26 @@ Map InferBound(const Schedule& sch); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); +/*! + * \brief Postprocessing the Stmt generated by ScheduleOps to create + * a PrimFunc that can then be used for further TIR optimizations. + * + * Perform this translation before running any TIR optimizations. + * + * List of actions taken by the function: + * - Remove occurences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. + * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + * + * \param arg_list Array of Tensor/Var/Buffer arguments to the function. + * \param body The body of the function. + * \param bindings potential Tensor to Buffer bindings for the Tensors in the body. + */ +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, + Stmt body, + Optional> bindings); + /*! * \brief To automatically inline the element-wise operations. * diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6764178cc23c..bf0d4f985a92 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -694,7 +694,10 @@ class CallNode : public PrimExprNode { ExternCPlusPlus = 1, /*! \brief Extern "C" without side-effect. */ PureExtern = 2, - /*! \brief Halide-style call, evaluates func(args). */ + /*! + * \brief Halide-style call, evaluates func(args). + * \note Deprecated, move to BufferLoad in the future. + */ Halide = 3, /*! \brief Intrinsic functions. */ Intrinsic = 4, @@ -707,9 +710,15 @@ class CallNode : public PrimExprNode { Array args; /*! \brief Type of calls. */ CallType call_type; - /*! \brief The function to be called. */ + /*! 
+ * \brief The function to be called. + * \note Deprecated, move to BufferLoad in the future. + */ FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ + /*! + * \brief The output value index if func's value is a tuple. + * \note Deprecated, move to BufferLoad in the future. + */ int value_index{0}; void VisitAttrs(AttrVisitor* v) { diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index f3d447e4524a..e6e2de6f24f6 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -164,22 +164,6 @@ Stmt Inline(Stmt stmt, Array args, PrimExpr body); -/*! - * \brief Flatten the multi-dimensional read/write - * to single dimensional Load/Store - * - * \param stmt The stmt to be trasnformed. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \param cache_line_size The size of CPU cache line. - * \param create_bound_attribute Whether to create bound attributes. - * \return Transformed stmt. - */ -Stmt StorageFlatten(Stmt stmt, - Map extern_buffer, - int cache_line_size, - bool create_bound_attribute = false); - /*! * \brief Try to modify the AST to support TensorCore * @@ -202,13 +186,6 @@ Stmt RewriteForTensorCore(Stmt stmt, */ bool VerifyCompactBuffer(Stmt stmt); -/*! - * \brief Inject prefetch instructions into stmt. - * \param stmt The statement to be transformed. - * \return Transformed stmt. - */ -Stmt InjectPrefetch(Stmt stmt); - /*! * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5bc492fcefb8..20c2d009b93c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -248,7 +248,6 @@ class StoreNode : public StmtNode { * \endcode * \sa BufferLoad */ -class BufferStore; class BufferStoreNode : public StmtNode { public: /*! \brief The buffer variable. */ @@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); }; +/*! + * \brief Managed reference to BufferStoreNode. + * \sa BufferStoreNode + */ class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, @@ -289,8 +292,80 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; +/*! + * \brief Annotate the region where the buffer need to + * be read and write in the body. + * We only need to allocate the space for the corresponding region. + * + * \note There should be at most one BufferRealize for each buffer. + * BufferRealize is not necessary for external buffers, + * since they are assumed to be fully allocated. + * + * \sa BufferLoad, BufferStore + */ +class BufferRealizeNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Buffer buffer; + /*! \brief Bounds to be realized */ + Array bounds; + /*! \brief Only realize if condition holds. */ + PrimExpr condition; + /*! \brief The body of realization. 
*/ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { + return + equal(buffer, other->buffer) && + equal(bounds, other->bounds) && + equal(condition, other->condition) && + equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(bounds); + hash_reduce(condition); + hash_reduce(body); + } + + BufferRealizeNode() = default; + BufferRealizeNode(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) + : buffer(buffer), bounds(bounds), + condition(condition), body(body) {} + + static constexpr const char* _type_key = "BufferRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); +}; + +/*! + * \brief Managed reference to BufferRealizeNode. + * \sa BufferRealizeNode + */ +class BufferRealize : public Stmt { + public: + TVM_DLL explicit BufferRealize(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); +}; + /*! * \brief Store value into mult-dimensional array defined by func. + * + * \note Deprecated, move to BufferStore in the future. */ class ProvideNode : public StmtNode { public: @@ -430,6 +505,8 @@ class FreeNode : public StmtNode { /*! * \brief Annotate the bounds where func need to be written and read in body. * We will need to allocate space for the corresponding regions. + * + * \note Deprecated, move to BufferRealize in the future. */ class RealizeNode : public StmtNode { public: @@ -747,50 +824,50 @@ class ForNode : public StmtNode { }; /*! - * \brief A prefetch hint of func. + * \brief A prefetch hint for abuffer */ class PrefetchNode : public StmtNode { public: /*! \brief The function to be prefetched. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; + Buffer buffer; /*! \brief Bounds to be prefetched. */ - Region bounds; + Array bounds; void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); + v->Visit("buffer", &buffer); v->Visit("bounds", &bounds); } bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && + equal(buffer, other->buffer) && equal(bounds, other->bounds); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); - hash_reduce(dtype); + hash_reduce(buffer); hash_reduce(bounds); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds); + PrefetchNode() = default; + PrefetchNode(Buffer buffer, Array bounds) + : buffer(buffer), bounds(bounds) {} static constexpr const char* _type_key = "Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); }; +/*! + * \brief Managed reference to PrefetchNode. + * \sa PrefetchNode + */ +class Prefetch : public Stmt { + public: + TVM_DLL explicit Prefetch(Buffer buffer, Array bounds); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); +}; + /*! * \brief Auxiliary data structure used in IR Pass to indicate a tensor. 
*/ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index f93e9080a377..a87ff9737d0c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -92,6 +92,7 @@ class StmtFunctor { virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -121,6 +122,8 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); + IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); + IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); return vtable; } }; @@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const ProvideNode* op) override; @@ -248,6 +252,7 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const ProvideNode* op) override; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e593e1bf0fbc..09ea09731f51 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const std::string& name, const tvm::Array& required); + +/*! + * \brief Inject prefetch instructions into stmt. + * + * \return The pass. + */ +TVM_DLL Pass InjectPrefetch(); + +// TODO(tvm-team): consolidate configs to the PassContext +/*! + * \brief Flatten the multi-dimensional read/write + * to single dimensional Load/Store + * + * \param cache_line_size The size of CPU cache line. + * \param create_bound_attribute Whether to create bound attributes. + * + * \return The Pass + */ +TVM_DLL Pass StorageFlatten(int cache_line_size, + bool create_bound_attribute = false); + /*! * \brief Inject copy intrinsics with optional pad. 
* diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index c576ffd76e56..0c0591ccf2a1 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm import target as _target -from tvm.tir import ir_pass from tvm.te import schedule from tvm.driver import build_module @@ -46,10 +45,12 @@ def ana_lower(sch, args, # Phase 0 bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds, True) - stmt = ir_pass.StorageFlatten(stmt, binds, 64) - stmt = ir_pass.CanonicalSimplify(stmt) + func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func._move()) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode - return stmt + return mod["main"].body try: _get_buffer_curve_sample_flatten = tvm._ffi.get_global_func( diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 18a8a47ad439..eea372733ca7 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None): def form_body(sch): - """According to the given schedule, form the raw body + """According to the given schedule, form a function. + Parameters ---------- sch : tvm.te.schedule.Schedule @@ -99,13 +100,31 @@ def form_body(sch): sch = sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) - stmt = ir_pass.InjectPrefetch(stmt) return stmt +def _wrap_as_prim_func_pass(flist, name): + """Wrap flist as a function pass. + + This is an temporary adapter before we fully + migrate to the new pass manager. + """ + def _transform(func, *_): + stmt = func.body + for f in flist: + stmt = f(stmt) + # create a new function with updated body. + return tvm.tir.PrimFunc(func.params, + stmt, + func.ret_type, + func.buffer_map, + func.attrs) + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name) + + def lower(sch, args, - name="default_function", + name="main", binds=None, simple_mode=False): """Lowering step before build into target. @@ -154,56 +173,57 @@ def lower(sch, compact = ir_pass.VerifyCompactBuffer(stmt) binds, arg_list = get_binds(args, compact, binds) + stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) + + # Start the new style pass manager. 
+ func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + func = func.with_attr("global_symbol", name) + if cfg.restricted_func: + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) # Phase 1 - stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) - stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.NarrowDataType(stmt, 32) - stmt = ir_pass.CanonicalSimplify(stmt) - for f in lower_phase1: - stmt = f(stmt) + pass_list = [ + tvm.tir.transform.InjectPrefetch(), + tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), + tvm.tir.transform.NarrowDataType(32), + tvm.tir.transform.Simplify(), + _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"), + ] # Phase 2 if not simple_mode: - stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - if cfg.disable_vectorize: - stmt = ir_pass.SkipVectorize(stmt) - else: - stmt = ir_pass.VectorizeLoop(stmt) - stmt = ir_pass.InjectVirtualThread(stmt) - stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) - stmt = ir_pass.StorageRewrite(stmt) - stmt = ir_pass.UnrollLoop( - stmt, - cfg.auto_unroll_max_step, - cfg.auto_unroll_max_depth, - cfg.auto_unroll_max_extent, - cfg.unroll_explicit) - - for f in lower_phase2: - stmt = f(stmt) + pass_list += [(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))] + + pass_list += [ + tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize), + tvm.tir.transform.InjectVirtualThread(), + tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.UnrollLoop( + cfg.auto_unroll_max_step, + cfg.auto_unroll_max_depth, + cfg.auto_unroll_max_extent, + cfg.unroll_explicit), + _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"), + ] # Phase 3 - stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.RemoveNoOp(stmt) - if not cfg.disable_select_rewriting: - stmt = ir_pass.RewriteUnsafeSelect(stmt) + pass_list += [ + tvm.tir.transform.Simplify(), + tvm.tir.transform.RemoveNoOp(), + ] - for f in lower_phase3: - stmt = f(stmt) + if not cfg.disable_select_rewriting: + pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] + pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] # Instrument BoundCheckers if cfg.instrument_bound_checkers: - stmt = ir_pass.InstrumentBoundCheckers(stmt) + pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - if simple_mode: - return stmt - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String(name)) - if cfg.restricted_func: - f = f.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: f}) + optimize = tvm.transform.Sequential(pass_list) + mod = optimize(mod) return mod diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 614f9690903a..af0be45b624e 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -157,11 +157,6 @@ class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. - Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well. 
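For context, a minimal sketch of the lowering flow that the build_module.py and pass-manager changes above set up, assuming only the APIs this patch introduces (SchedulePostProcToPrimFunc plus the tir.transform wrappers); the toy schedule and the particular pass list are illustrative, not the exact pipeline used by lower():

import tvm
from tvm import te

# Build a trivial schedule to lower.
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op).normalize()

# te-level scheduling, then translation into a formal PrimFunc.
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = tvm.IRModule({"main": func})

# Run the TIR transformations through the new pass-manager interface.
seq = tvm.transform.Sequential([
    tvm.tir.transform.InjectPrefetch(),
    tvm.tir.transform.StorageFlatten(64),
    tvm.tir.transform.Simplify(),
])
mod = seq(mod)
print(mod["main"].body)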
@@ -173,6 +168,9 @@ class Sequential(Pass): opt_level : Optional[int] The optimization level of this sequential pass. + The opt_level of a default sequential pass is set to 0. + Note that some of the passes within the Sequantial may still not be executed + if their opt_level is higher than the provided opt_level. name : Optional[str] The name of the sequential pass. diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index d2238ad754ac..ddfb6a5f69c1 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -28,7 +28,8 @@ from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For -from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt +from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt +from .stmt import Free, Realize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .function import PrimFunc diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index c5b2a7957319..eee5b0b002e0 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -160,6 +160,29 @@ def __init__(self, buffer, value, indices): _ffi_api.BufferStore, buffer, value, indices) +@tvm._ffi.register_object +class BufferRealize(Stmt): + """Buffer realize node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + bounds : List[Range] + The value we to be stored. + + condition : PrimExpr + The realize condition. + + body : Stmt + The body of the statement. + """ + def __init__(self, buffer, bounds, condition, body): + self.__init_handle_by_constructor__( + _ffi_api.BufferRealize, buffer, bounds, condition, body) + + @tvm._ffi.register_object class Provide(Stmt): """Provide node. @@ -348,21 +371,15 @@ class Prefetch(Stmt): Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index - - dtype : str - The data type to be prefetched. + buffer : Buffer + The buffer to be prefetched. bounds : list of Range The bounds to be prefetched. """ - def __init__(self, func, value_index, dtype, bounds): + def __init__(self, buffer, bounds): self.__init_handle_by_constructor__( - _ffi_api.Prefetch, func, value_index, dtype, bounds) + _ffi_api.Prefetch, buffer, bounds) def stmt_seq(*args): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f83bb11ad51e..bb39c1f69131 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -60,6 +60,38 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") +def InjectPrefetch(): + """Inject prefetch instructions into stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectPrefetch() + + +def StorageFlatten(cache_line_size, create_bound_attribute=False): + """Flatten the multi-dimensional read/write to 1D. + + + Parameters + ---------- + cache_line_size: int + The size of CPU cache line. + + create_bound_attribute: + Whether to create bound attributes. + + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) + + def InjectCopyIntrin(pragma_key, fintrin): """Inject virtual thread loops. 
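The next hunk switches DomainTouched from Tensor/Provide/Call(Halide) to Buffer/BufferStore/BufferLoad. A small hedged sketch of constructing those Buffer-based nodes with the Python classes added above; it mirrors the updated test_arith_domain_touched.py further down in this patch, and the buffer names and shapes are illustrative only:

import tvm
from tvm import te

i, j = te.var("i"), te.var("j")
n, m = tvm.runtime.convert(16), te.var("m")

a = tvm.tir.decl_buffer((n, m), name="a")
b = tvm.tir.decl_buffer((n, m), name="b")

# BufferStore/BufferLoad replace the old Provide/Call(Halide) pair.
body = tvm.tir.For(
    i, 0, n, 0, 0,
    tvm.tir.For(
        j, 0, m, 0, 0,
        tvm.tir.BufferStore(a, tvm.tir.BufferLoad(b, [i, j]) + 1.0, [i, j])))

# The reworked analysis operates on buffers directly:
# here, the region of `b` that is read within `body` (stores ignored).
dom = tvm.arith._ffi_api.DomainTouched(body, b, True, False)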
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index bda70fb67cba..2467e758c3e4 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -36,10 +36,14 @@ namespace arith { using namespace tir; // Find Read region of the tensor in the stmt. -class FuncTouchedDomain final : public StmtExprVisitor { +class BufferTouchedDomain final : public StmtExprVisitor { public: - FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool consider_provides) - : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} + BufferTouchedDomain(const Buffer &buffer, + bool consider_loads, + bool consider_stores) + : buffer_(buffer), + consider_loads_(consider_loads), + consider_stores_(consider_stores) {} Domain Find(const Stmt& stmt) { operator()(stmt); @@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { - if (consider_calls_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitExpr_(const BufferLoadNode* op) final { + if (consider_loads_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const ProvideNode* op) final { - if (consider_provides_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitStmt_(const BufferStoreNode* op) final { + if (consider_stores_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitStmt_(op); } @@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - const te::Tensor &tensor_; - bool consider_calls_, consider_provides_; + const Buffer &buffer_; + bool consider_loads_, consider_stores_; std::vector > bounds_; std::unordered_map dom_map_; }; -Domain DomainTouched(Stmt stmt, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides) { - return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt); +Domain DomainTouched(const Stmt& stmt, + const Buffer& buffer, + bool consider_loads, + bool consider_stores) { + return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } TVM_REGISTER_GLOBAL("arith.DomainTouched") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e38179e965f5..c3802b1a63f4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) { } -IRModule BuildIRModule(const Array& out_arg_list, - tir::Stmt stmt, - const std::string& name, - const BuildConfig& config) { - Array params; - Map buffer_map; - - for (auto var : out_arg_list) { - if (auto* n = var.as()) { - params.push_back(GetRef(n)); - } else { - tir::Buffer buffer = Downcast(var); - tir::Var bptr(buffer->name, DataType::Handle()); - params.push_back(bptr); - buffer_map.Set(bptr, buffer); - } - } - - auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - if (config->restricted_func) { - f = WithAttr(std::move(f), "tir.noalias", Integer(1)); - } - - return IRModule(Map({{GlobalVar(name), f}})); -} - - IRModule lower(te::Schedule sch, const Array& args, const std::string& name, @@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch, sch = sch.normalize(); - // Phase 0 + // Before TIR transformation. 
auto bounds = te::InferBound(sch); auto stmt = te::ScheduleOps(sch, bounds, false); - stmt = tir::InjectPrefetch(stmt); - bool compact = tir::VerifyCompactBuffer(stmt); + Map out_binds; GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); - // Phase 1 - stmt = tir::StorageFlatten(stmt, out_binds, 64, - config->instrument_bound_checkers); + // build the function + tir::PrimFunc f = te::SchedulePostProcToPrimFunc( + out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + if (config->restricted_func) { + f = WithAttr(std::move(f), "tir.noalias", Integer(1)); + } - // convert to IRModule. - auto mod = BuildIRModule(out_arg_list, stmt, name, config); + auto mod = IRModule(Map({{GlobalVar(name), f}})); auto pass_list = Array(); + // Phase 0 + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back( + tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); + // Phase 1 + pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop)); pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize)); diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 4ecfe9472901..d022134065c6 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -132,8 +132,8 @@ MakeLoopNest(const Stage& stage, for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { nest[i + 1].emplace_back( AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } } else if (bind_iv->thread_tag == "vthread" || diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc new file mode 100644 index 000000000000..bb52be4c0202 --- /dev/null +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file schedule_postproc_to_primfunc.cc + * + * \brief Translate the function body generated by ScheduleOps + * with te related dialects that incorporates Tensor + * into the Stmts to a PrimFunc. + * + * Perform this translation before running any TIR optimizations. + * + * Rationale: The body generated by ScheduleOps is not + * a formal PrimFunc and cannot be used for further optimization. + * This function canonicalize that body and creates a formal PrimFunc. + * + * List of actions taken by the function: + * - Remove occurences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. 
+ * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace te { + +// create a buffer for tensor. +Buffer CreateBufferFor(const Tensor& tensor) { + std::string name = tensor->op->name; + if (tensor->op->num_outputs() != 1) { + name += ".v" + std::to_string(tensor->value_index); + } + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + return buffer; +} + +// A remapper that maps tensor to buffer +class TensorToBufferMapper : public StmtExprMutator { + public: + explicit TensorToBufferMapper(std::unordered_map buffer_map) + : buffer_map_(buffer_map) { + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + // TODO(tvm-team): remove realize_scope, turn the info into + // Buffer's scope field in this pass. + if (op->attr_key == tir::attr::realize_scope || + op->attr_key == tir::attr::double_buffer_scope) { + Stmt body = op->body; + Operation operation = Downcast(op->node); + for (int i = operation->num_outputs(); i != 0; --i) { + Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); + body = AttrStmtNode::make( + buffer, op->attr_key, op->value, body); + } + return body; + } else if (op->attr_key == tir::attr::buffer_bind_scope) { + Array tuple = Downcast >(op->node); + Tensor tensor = Downcast(tuple[1]); + return AttrStmtNode::make( + Array{tuple[0], GetOrAllocBuffer(tensor)}, + op->attr_key, op->value, op->body); + } else if (op->attr_key == tir::attr::buffer_dim_align|| + op->attr_key == tir::attr::prefetch_scope) { + Tensor tensor = Downcast(op->node); + Buffer buffer = GetOrAllocBuffer(tensor); + return AttrStmtNode::make( + buffer, op->attr_key, op->value, op->body); + } else { + return ret; + } + } + + Stmt VisitStmt_(const RealizeNode* op) final { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetOrAllocBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferRealize(buffer, op->bounds, op->condition, op->body); + } + + Stmt VisitStmt_(const ProvideNode* op) final { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferStore(buffer, op->value, op->args); + } + + PrimExpr VisitExpr_(const CallNode* op) final { + auto ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + + if (op->call_type == CallNode::Halide) { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetBuffer(tensor); + return tir::BufferLoad(buffer, op->args); + } else { + return ret; + } + } + + private: + Buffer GetOrAllocBuffer(const Tensor& tensor) { + return GetBuffer(tensor, true); + } + + Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + auto it = buffer_map_.find(tensor); + if (it != buffer_map_.end()) return it->second; + CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; + + auto buffer = CreateBufferFor(tensor); + buffer_map_[tensor] = buffer; + return buffer; + } + + // maps tensor to buffer. 
+ std::unordered_map buffer_map_; +}; + + +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, + Stmt body, + Optional> extern_buffer_opt) { + std::unordered_map extern_buffer; + + if (extern_buffer_opt.defined()) { + auto v = extern_buffer_opt.value(); + extern_buffer = std::unordered_map(v.begin(), v.end()); + } + + Array params; + Map buffer_map; + + for (auto var : arg_list) { + if (auto* n = var.as()) { + params.push_back(GetRef(n)); + } else if (auto* n = var.as()) { + te::Tensor tensor = GetRef(n); + CHECK(!extern_buffer.count(tensor)); + + tir::Buffer buffer = CreateBufferFor(tensor); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + extern_buffer[tensor] = buffer; + } else { + tir::Buffer buffer = Downcast(var); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + } + } + + body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); + return tir::PrimFunc(params, body, VoidType(), buffer_map); +} + +TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") +.set_body_typed(SchedulePostProcToPrimFunc); + +} // namespace te +} // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 65d424e31212..03925ec3783a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -645,6 +645,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1f6a7dd027ea..f8e82ea787be 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize") .set_body_typed(RealizeNode::make); -Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - return Stmt(node); +Prefetch::Prefetch(Buffer buffer, Array bounds) { + data_ = make_object(buffer, bounds); } TVM_REGISTER_GLOBAL("tir.Prefetch") -.set_body_typed(PrefetchNode::make); +.set_body_typed([](Buffer buffer, Array bounds) { + return Prefetch(buffer, bounds); +}); SeqStmt::SeqStmt(Array seq) { @@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore") TVM_REGISTER_NODE_TYPE(BufferStoreNode); + +BufferRealize::BufferRealize(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) { + data_ = make_object( + buffer, bounds, condition, body); +} + +TVM_REGISTER_GLOBAL("tir.BufferRealize") +.set_body_typed([](Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) { + return BufferRealize(buffer, bounds, condition, body); +}); + +TVM_REGISTER_NODE_TYPE(BufferRealizeNode); + // Printers TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -432,6 +441,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, 
vtable) p->stream << '\n'; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -458,6 +482,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << "prefetch " << op->func->func_name() << "("; + p->stream << "prefetch " << op->buffer << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); @@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ed3c2c75ef47..5e584ebf55f4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { + this->VisitExpr(op->value); VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); + this->VisitExpr(op->condition); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); @@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { + PrimExpr value = this->VisitExpr(op->value); Array indices = Internal::Mutate(this, op->indices); - if (indices.same_as(op->indices)) { + + if (value.same_as(op->value) && + indices.same_as(op->indices)) { return GetRef(op); } else { auto n = CopyOnWrite(op); + n->value = std::move(value); n->indices = std::move(indices); return Stmt(n); } } +Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { + Region bounds = Internal::Mutate(this, op->bounds); + 
PrimExpr condition = this->VisitExpr(op->condition); + Stmt body = this->VisitStmt(op->body); + + if (bounds.same_as(op->bounds) && + condition.same_as(op->condition) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + n->condition = std::move(condition); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Array args = Internal::Mutate(this, op->args); PrimExpr value = this->VisitExpr(op->value); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 65981b9b62f5..4d7ed5dc3c85 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute") } }); -TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() <= 3) { - *ret = StorageFlatten(args[0], args[1], args[2]); - } else { - *ret = StorageFlatten(args[0], args[1], args[2], args[3]); - } - }); - TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") .set_body_typed ([](const Stmt& stmt, @@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); -REGISTER_PASS(InjectPrefetch); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(VerifyCompactBuffer); diff --git a/src/tir/pass/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc similarity index 79% rename from src/tir/pass/inject_prefetch.cc rename to src/tir/transforms/inject_prefetch.cc index 894ff3864864..e9dae0a5dfc9 100644 --- a/src/tir/pass/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -21,9 +21,12 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR +#include #include +#include #include -#include +#include +#include #include #include @@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator { Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); if (op && op->attr_key == attr::prefetch_scope) { - te::Tensor ts = Downcast(op->node); + Buffer buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U); - Domain domain = DomainTouched(op->body, ts, true, false); + Domain domain = DomainTouched(op->body, buffer, true, false); Region region; auto iter_var = loop_nest_.back().get(); @@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator { for (Range r : domain) { if (!r.defined()) { - LOG(WARNING) << "Cannot decide prefetch region for " << ts; + LOG(WARNING) << "Cannot decide prefetch region for " << buffer; return op->body; } Range res(EvalSet(r, vectorized_).cover_range(none)); @@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator { vectorized_.erase(iter_var); - Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region); + Stmt prefetch = Prefetch(buffer, region); return SeqStmt({prefetch, op->body}); } return ret; @@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } + +namespace transform { + +Pass InjectPrefetch() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = PrefetchInjector()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch") +.set_body_typed(InjectPrefetch); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc similarity index 77% rename 
from src/tir/pass/storage_flatten.cc rename to src/tir/transforms/storage_flatten.cc index f9533fa4820a..99d437d9c24e 100644 --- a/src/tir/pass/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -19,22 +19,24 @@ /*! * \file storage_flatten.cc + * \brief Flattens storage from multi-dimensional array to 1D buffer access */ -// Flattens storage from multi-dimensional array to 1D -// buffer access as in Halide pipeline. +// The pass definition originates from Halide pipeline. + +#include #include #include #include #include #include #include -#include +#include #include #include #include #include -#include "ir_util.h" -#include "arg_binder.h" +#include "../pass/ir_util.h" +#include "../pass/arg_binder.h" #include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" @@ -49,16 +51,17 @@ using intrinsic::tvm_address_of; class StorageFlattener : public StmtExprMutator { public: - explicit StorageFlattener(Map extern_buffer, - int cache_line_size, bool create_bound_attributes, - IRVisitorWithAnalyzer* bounded_analyzer) - : bounded_analyzer_(bounded_analyzer), + explicit StorageFlattener(const Map& extern_buffer_map, + int cache_line_size, + bool create_bound_attributes, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { - for (auto kv : extern_buffer) { + for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; e.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; + buf_map_[kv.second] = e; } cache_line_size_ = cache_line_size; } @@ -82,17 +85,14 @@ class StorageFlattener : public StmtExprMutator { storage_scope_[op->node.get()] = op->value.as()->value; return this->VisitStmt(op->body); } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { - auto func = Downcast(op->node); + op->node->IsInstance()) { + auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); - for (int i = 0; i < func->num_outputs(); ++i) { - TensorKey key{func, i}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - body = AttrStmtNode::make( - it->second.buffer->data, op->attr_key, op->value, body); - } + auto it = buf_map_.find(buffer); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << buffer; + body = AttrStmtNode::make( + it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -104,11 +104,10 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); } else if (op->attr_key == attr::buffer_dim_align) { - auto tensor = Downcast(op->node); + auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - auto& vinfo = dim_align_[key]; + auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { vinfo.resize(dim + 1); @@ -122,18 +121,21 @@ class StorageFlattener : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const ProvideNode* op) final { - if (create_bound_attributes_) - shape_collector_.clear(); + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (create_bound_attributes_) 
shape_collector_.clear(); Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - TensorKey key{op->func, op->value_index}; + op = stmt.as(); + + const auto& key = op->buffer; + auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; CHECK(!e.released) << "Read a buffer that is already out of scope"; + if (is_opengl_) { return EvaluateNode::make(CallNode::make( DataType(), @@ -141,7 +143,7 @@ class StorageFlattener : public StmtExprMutator { {e.buffer->data, op->value}, CallNode::Intrinsic)); } else { - Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value); + Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back( std::make_pair(e.buffer->data, e.buffer->shape)); @@ -158,8 +160,9 @@ class StorageFlattener : public StmtExprMutator { } } - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const BufferRealizeNode* op) final { + const auto& key = op->buffer; + if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); return this->VisitStmt(op->body); @@ -172,10 +175,9 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->func.get()); + auto it = storage_scope_.find(op->buffer.get()); CHECK(it != storage_scope_.end()) - << "Cannot find storage scope of " << op->func - << " value_index=" << op->value_index; + << "Cannot find storage scope of " << op->buffer; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { @@ -188,13 +190,14 @@ class StorageFlattener : public StmtExprMutator { } // use small alignment for small arrays + auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); - int align = GetTempAllocaAlignment(op->dtype, const_size); + int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); if (info.defined()) { - align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits(); - CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits) + align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); + CHECK_LE(const_size * dtype.bits(), info->max_num_bits) << "Allocation exceed bound of memory tag " << skey.to_string(); } } @@ -210,7 +213,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = bound_analyzer_->Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -219,9 +222,9 @@ class StorageFlattener : public StmtExprMutator { } e.buffer = BufferNode::make( - Var(key.GetName(), DataType::Handle()), - op->dtype, shape, strides, PrimExpr(), - key.GetName(), skey.to_string(), + Var(op->buffer->data->name_hint, DataType::Handle()), + op->buffer->dtype, shape, strides, PrimExpr(), + op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; @@ -285,36 +288,36 @@ class StorageFlattener : public StmtExprMutator { } } - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const BufferLoadNode* op) 
final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr && op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + op = expr.as(); - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { + const auto& key = op->buffer; + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; + CHECK(!e.released) + << "Read a buffer that is already out of scope"; + + if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back( std::make_pair(e.buffer->data, e.buffer->shape)); - } - return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype); - } else { - return expr; } + return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); } + Stmt VisitStmt_(const PrefetchNode *op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); - TensorKey key{op->func, op->value_index}; + + const auto& key = op->buffer; auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; CHECK(!e.released) @@ -340,7 +343,7 @@ class StorageFlattener : public StmtExprMutator { for (int i = op->bounds.size() - 1; i > starts; --i) { args.push_back(op->bounds[i]->min); } - auto &func_name = op->func->func_name(); + auto &func_name = op->buffer->name; vars.push_back(Var( "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); @@ -358,7 +361,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr address = CallNode::make( DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); PrimExpr prefetch = CallNode::make( - op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -367,6 +370,26 @@ class StorageFlattener : public StmtExprMutator { return stmt; } + PrimExpr VisitExpr_(const CallNode* op) final { + CHECK(op->call_type != CallNode::Halide) + << "Cannot handle Halide calls " + << " please run SchedulePostProcToPrimFunc first"; + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const ProvideNode* op) final { + LOG(FATAL) << "Cannot handle Provide " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + + Stmt VisitStmt_(const RealizeNode* op) final { + LOG(FATAL) << "Cannot handle Realize " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + + private: // The specific tensor data layout is not determined before // StorageFlatten pass. 
We use buffer_bind_scope @@ -406,14 +429,16 @@ class StorageFlattener : public StmtExprMutator { Array arr = Downcast > (op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); - const te::TensorNode* tensor = arr[1].as(); + const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); - CHECK(buffer && tensor); + CHECK(buffer && target); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - CHECK(buf_map_.count(key)) - << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; - const BufferEntry& be = buf_map_.at(key); + auto key = GetRef(target); + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find buffer of " << key; + const BufferEntry& be = it->second; CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); Array begins, extents; @@ -426,7 +451,7 @@ class StorageFlattener : public StmtExprMutator { } else { for (size_t i = 0; i < tuple->args.size(); i += 2) { begins.push_back(tuple->args[i]); - auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]); + auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]); extents.push_back(new_extent); } } @@ -451,6 +476,7 @@ class StorageFlattener : public StmtExprMutator { } return body; } + // The buffer entry in the flatten map struct DimAlignInfo { int align_factor{0}; @@ -509,9 +535,10 @@ class StorageFlattener : public StmtExprMutator { // Variable remap std::unordered_map var_remap_; // Buffer map - std::unordered_map buf_map_; + std::unordered_map buf_map_; // Dimension alignment - std::unordered_map > dim_align_; + std::unordered_map, + ObjectHash, ObjectEqual> dim_align_; // Storage scope std::unordered_map storage_scope_; // The current thread scope. @@ -520,7 +547,7 @@ class StorageFlattener : public StmtExprMutator { std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. // However - IRVisitorWithAnalyzer* bounded_analyzer_; + IRVisitorWithAnalyzer* bound_analyzer_; // The size of cacheline int cache_line_size_; // The current stage is an OpenGL shader. 
@@ -529,15 +556,37 @@ class StorageFlattener : public StmtExprMutator { bool create_bound_attributes_{false}; }; -Stmt StorageFlatten(Stmt stmt, Map extern_buffer, - int cache_line_size, bool create_bound_attributes) { - IRVisitorWithAnalyzer bounded_analyzer; - bounded_analyzer(stmt); - stmt = - StorageFlattener(extern_buffer, cache_line_size, - create_bound_attributes, &bounded_analyzer)(std::move(stmt)); - return stmt; +PrimFunc StorageFlatten(PrimFunc func, + int cache_line_size, + bool create_bound_attributes) { + auto fptr = func.CopyOnWrite(); + + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = StorageFlattener(fptr->buffer_map, + cache_line_size, + create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; } + +namespace transform { + +// TODO(tvm-team): consolidate configs to the PassContext +Pass StorageFlatten(int cache_line_size, + bool create_bound_attributes) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return StorageFlatten( + std::move(f), cache_line_size, create_bound_attributes); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten") +.set_body_typed(StorageFlatten); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 0d769aabf247..10337218dc87 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -22,21 +22,25 @@ def test_domain_touched(): j = te.var('j') n = tvm.runtime.convert(100) m = te.var('m') - a = te.placeholder((n, m), name = 'a') - b = te.placeholder((n, m), name = 'b') + + a = tvm.tir.decl_buffer((n, m), name='a') + b = tvm.tir.decl_buffer((n, m), name='b') + + ir = tvm.tir.For( i, 0, n, 0, 0, tvm.tir.For(j, 0, m, 0, 0, - tvm.tir.Provide( - a.op, - 0, - tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + - tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0), + tvm.tir.BufferStore( + a, + tvm.tir.BufferLoad(b, [i - 1, j + 1]) + + tvm.tir.BufferLoad(a, [i - 1, j - 1]), [i, j] ) ) ) + a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) + assert a_domain_r[0].min.value == -1 assert a_domain_r[0].extent.value == 100 assert a_domain_r[1].min.value == -1 diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index 442c4fed7b2f..b1d754605a46 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -48,9 +48,9 @@ def test_split_uneven_unique_likely(): x, y = c.op.axis sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) - stmt = tvm.lower(sch, [a, b, c], simple_mode=True) + stmt = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) - assert str(stmt.body.body.body).count("likely") == 1 + if __name__ == "__main__": test_lower_rfactor() diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index b525d018340d..5b4a1c92a7e4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -365,7 +365,7 @@ def foo(a): a = te.placeholder((8, 4), 'float32') c = foo(a) s = te.create_schedule(c.op) - ir = tvm.lower(s, [a, c], simple_mode=True) + ir = tvm.lower(s, [a, c]) func, ins, outs = run_and_check(foo, [a], target='cuda') 
run_and_check(func, ins, outs=outs, target='cuda') @@ -517,7 +517,7 @@ def upstream(a): c = te.compute((20, ), lambda x: a[x] + b[x]) d = upstream(c) sch = te.create_schedule([c.op, d.op]) - ir = tvm.lower(sch, [a, b, d], simple_mode=True) + ir = tvm.lower(sch, [a, b, d]) func = tvm.build(sch, [a, b, d]) assert(func) @@ -730,7 +730,7 @@ def outer_product(a, b): joo, joi = sch[c].split(jo, 4) sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -751,7 +751,7 @@ def outer_product(a, b): # Test fuse sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index c9b422f7f0a4..9e4d45e9efaa 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -283,7 +283,7 @@ def intrin_func(ins, outs, sp): # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, C], simple_mode=True) + stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 assert str(stmt.body.body.value.args[3]) == "(i*i)" diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 3e521ab07023..2a0c6c1f40af 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -28,6 +28,9 @@ def test_schedule0(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule1(): @@ -43,6 +46,10 @@ def test_schedule1(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) + def test_schedule2(): m = te.var('m') @@ -57,6 +64,9 @@ def test_schedule2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A2], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule_scan(): @@ -77,6 +87,7 @@ def test_schedule_scan(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) + def test_inline_multi_reduce(): def argmax_comp(x, y): idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) @@ -510,19 +521,19 @@ def collect_visit(stmt, f): return ret # local vs. threadIdx s = schedule(tx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # local vs. 
vthread s = schedule(vx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # shared vs. blockIdx s = schedule(by, "shared") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) @@ -548,7 +559,7 @@ def test_local_stage_predicate2(): s[AA].compute_at(s[C], ooc) oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32) s[AA].bind(iaa, thread_x) - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body def collect_visit(stmt, f): ret = [] diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 55edd1c9958b..45280866af38 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -128,7 +128,7 @@ def intrin_func(ins, outs): lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body, tvm.tir.Evaluate) def test_tensor_compute2(): @@ -171,7 +171,7 @@ def intrin_func(ins, outs): lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body.body[0], tvm.tir.Evaluate) assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate) diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index b3625082f6ed..b0de91b435ea 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -24,29 +24,6 @@ other_devices = ["llvm", "ext_dev"] -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String("test")) - mod = tvm.IRModule({"test": f}) - return mod - - # All computations are bound. # So VerifyMemory pass is expected to succeed. # @@ -61,7 +38,7 @@ def test_verify_memory_all_bind(): s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices + other_devices: binded_mod = tvm.tir.transform.Apply( @@ -81,7 +58,7 @@ def test_verify_memory_not_bind(): # B is not bound to threads. s = te.create_schedule(B.op) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( @@ -111,7 +88,7 @@ def test_verify_memory_partially_bind(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B, C, D]) + mod = tvm. 
lower(s, [A, B, C, D]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 7a03e48e2270..4af93fd58c0a 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -194,9 +194,9 @@ def test_stmt_constructor(): assert x.then_case.value.value == 11 assert x.else_case == nop - x = tvm.tir.Prefetch(None, 1, "float32", []) + b = tvm.tir.decl_buffer((1, 2)) + x = tvm.tir.Prefetch(b, []) assert isinstance(x, tvm.tir.Prefetch) - assert x.value_index == 1 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 9106be843b48..090acda00365 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -28,7 +28,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - print(body) assert isinstance(body, tvm.tir.AttrStmt) body = body.body assert isinstance(body, tvm.tir.Allocate) @@ -59,14 +58,13 @@ def test_if(): assert body.else_case.index.value == 0 def test_prefetch(): - A = te.placeholder((10, 20), name="A") + A = tvm.tir.decl_buffer((10, 20), name="A") ib = tvm.tir.ir_builder.create() n = te.size_var("n") with ib.for_range(0, n, name="i") as i: ib.emit( - tvm.tir.Prefetch( - A.op, A.value_index, A.dtype, + tvm.tir.Prefetch(A, [tvm.ir.Range.make_by_min_extent(i+1, 2), tvm.ir.Range.make_by_min_extent(0, 20)])) body = ib.get() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 9f4ccadde94d..468ab1dbad6a 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -301,6 +301,10 @@ def test_buffer_load_store(): s = tvm.tir.BufferStore(b, 0.1, [0]) assert isinstance(s, tvm.tir.BufferStore) + s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)], + True, tvm.tir.Evaluate(0)) + assert isinstance(s, tvm.tir.BufferRealize) + def test_intimm_cond(): x = tvm.runtime.convert(1) diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 7ec2e48b4fe4..9d1641366d7d 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -26,9 +26,10 @@ def test_copy2d(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert dst.strides[0] == l assert dst.strides[1].value == 1 @@ -36,7 +37,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tuple(src.shape) == (m, l) return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -51,9 +51,11 @@ def test_copy_pad(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - 
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert pad_before[0].value == 1 @@ -63,7 +65,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert pad_value.value == 1.0 return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -75,9 +76,11 @@ def test_single_point_test(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 @@ -85,7 +88,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -105,11 +107,12 @@ def test_copy_pad_split(): s[Apad].pragma(s[Apad].op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) + def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) @@ -121,12 +124,10 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body - if __name__ == "__main__": test_copy2d() test_copy_pad() diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index fb76597577b6..760cf477f959 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -28,18 +28,16 @@ def test_makeapi(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Cb = tvm.tir.decl_buffer(C.shape, C.dtype, name='C') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) + mod = 
tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({ + "target": tvm.target.create("llvm"), + "global_symbol": "main", + }))(mod) num_unpacked_args = 2 - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({ - "global_symbol": "main", - "target": tvm.target.create("llvm") - })) f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] assert(len(f.params) == 7) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index dbf22679c1a0..6179bbbfbd07 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits): raise ValueError("args must be Tensor, Buffer or Var") bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - return lower_stmt(arg_list, stmt, target_bits) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body def test_basic(): diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py similarity index 82% rename from tests/python/unittest/test_tir_pass_storage_flatten.py rename to tests/python/unittest/test_tir_transform_storage_flatten.py index 1eaadb35009d..e2bfeb009a11 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -30,11 +30,14 @@ def test_flatten2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [Ab, A2b], stmt, {A: Ab, A2: A2b}) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def test_flatten_prefetch(): A = te.placeholder((25, 100, 4), name = 'A') @@ -42,8 +45,14 @@ def test_flatten_prefetch(): i = te.size_var('i') j = te.size_var('j') region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] - stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64) + stmt = tvm.tir.Prefetch(_A, region) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [_A], stmt, {A: _A}) + + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body stmt = tvm.tir.ir_pass.Simplify(stmt) assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) @@ -62,12 +71,15 @@ def test_flatten_storage_align(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = 
tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body stmt = tvm.tir.ir_pass.Simplify(stmt) assert(stmt.body.extents[0].value == 17 * 8) + def test_flatten_double_buffer(): dtype = 'int64' n = 100 @@ -87,7 +99,13 @@ def test_flatten_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64) + + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C], stmt)) + + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body + stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) @@ -105,7 +123,7 @@ def count_sync(op): assert count[0] == 4 if __name__ == "__main__": - test_flatten_storage_align() test_flatten2() - test_flatten_prefetch() + test_flatten_storage_align() test_flatten_double_buffer() + test_flatten_prefetch() diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index e4e1b3102d4a..85f856db366b 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -30,11 +30,11 @@ def test_storage_share(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -166,11 +166,11 @@ def test_inplace_rule(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -201,11 +201,10 @@ def test_storage_combine(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -238,11 +237,9 @@ def test_storage_share_gpu(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A') - Bb = 
tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64) - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C') - Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -398,17 +393,11 @@ def test_inplace_rule3(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0') - B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1') - B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2') - B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3') - B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4') - B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5') - - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [B0, B1, B2, B3, B4, B5, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt)) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -547,7 +536,7 @@ def compute(a, b): c = te.compute(shape, lambda i, j: compute(a, b)[i, j]) c = te.compute(shape, lambda i, j: 1 + c[i, j]) s = te.create_schedule(c.op) - stmt = tvm.lower(s, [a, b, c], simple_mode=True) + stmt = tvm.lower(s, [a, b, c])["main"].body def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 268435456 diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 9257f6cd3320..783b66983c48 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -34,15 +34,15 @@ def test_thread_storage_sync(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = 
tvm.tir.transform.StorageFlatten(64)(mod._move())

     cuda_target = tvm.target.create("cuda")
-    mod = tvm.IRModule.from_expr(
-        tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
-            "global_symbol": "test", "target": cuda_target}))
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
+        "global_symbol": "test", "target": cuda_target}))(mod._move())
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
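For quick reference, the Buffer-based constructors that the updated constructor, ir_builder, and node tests above rely on can also be exercised standalone; a minimal sketch (the shapes and regions are arbitrary illustrative values):

    import tvm

    buf = tvm.tir.decl_buffer((10, 20), "float32", name="A")

    # Prefetch now takes the Buffer and the region directly,
    # instead of (op, value_index, dtype, region).
    region = [tvm.ir.Range.make_by_min_extent(0, 2),
              tvm.ir.Range.make_by_min_extent(0, 20)]
    prefetch = tvm.tir.Prefetch(buf, region)

    # BufferRealize annotates the bounds a buffer needs within its body.
    realize = tvm.tir.BufferRealize(buf, [tvm.ir.Range(0, 1)],
                                    True, tvm.tir.Evaluate(0))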