From 1a23678c1949b82d7b22b06a1d80f776ca3d0aca Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 19 Apr 2020 15:26:51 -0700 Subject: [PATCH] [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (#5372) * [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. te::Tensor is an useful object for tensor expression, but brings un-necessary reverse dependency in TIR nodes such as Provide and Realize. This PR is a first step to remove this dependency. We will use Buffer in all the places where the te::Tensor was used. The rough correspondence are: - Provide -> BufferStore - Realize -> BufferRealize - HalideCall -> BufferLoad. After this change, we can not use IRModule of PrimFuncs cleanly to represent TIR at any point of the optimizations. Buffer will serve as the abstraction for the TIR data models to represent the intermediate storages and their constraints. We still keep Realize/HalideCall and Provide as TIR nodes for now to make the change minimum. Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize the temporary IR generated by TE(which contains these nodes) to the TIR. The TIR optimizations are now mostly migrated to to the pass manager. Followup PRs are needed to migrate the remaining few passes. * Fix dev tutorial --- include/tvm/arith/bound.h | 14 +- include/tvm/runtime/memory.h | 2 +- include/tvm/te/schedule_pass.h | 21 ++ include/tvm/tir/expr.h | 15 +- include/tvm/tir/ir_pass.h | 23 -- include/tvm/tir/stmt.h | 119 ++++++++-- include/tvm/tir/stmt_functor.h | 5 + include/tvm/tir/transform.h | 21 ++ python/tvm/autotvm/feature.py | 9 +- python/tvm/driver/build_module.py | 102 +++++---- python/tvm/ir/transform.py | 8 +- python/tvm/tir/__init__.py | 3 +- python/tvm/tir/stmt.py | 37 +++- python/tvm/tir/transform/transform.py | 32 +++ src/arith/domain_touched.cc | 38 ++-- src/driver/driver_api.cc | 53 ++--- src/te/operation/op_util.cc | 4 +- .../schedule/schedule_postproc_to_primfunc.cc | 194 ++++++++++++++++ src/tir/ir/expr.cc | 13 ++ src/tir/ir/stmt.cc | 87 ++++++-- src/tir/ir/stmt_functor.cc | 34 ++- src/tir/pass/ffi_api.cc | 10 - .../{pass => transforms}/inject_prefetch.cc | 30 ++- .../{pass => transforms}/storage_flatten.cc | 209 +++++++++++------- .../unittest/test_arith_domain_touched.py | 18 +- tests/python/unittest/test_te_build_lower.py | 4 +- .../python/unittest/test_te_hybrid_script.py | 8 +- tests/python/unittest/test_te_schedule.py | 2 +- tests/python/unittest/test_te_schedule_ops.py | 19 +- tests/python/unittest/test_te_tensor.py | 4 +- .../test_tir_analysis_verify_memory.py | 29 +-- tests/python/unittest/test_tir_constructor.py | 4 +- tests/python/unittest/test_tir_ir_builder.py | 6 +- tests/python/unittest/test_tir_nodes.py | 4 + .../test_tir_transform_inject_copy_intrin.py | 39 ++-- .../test_tir_transform_make_packed_api.py | 18 +- .../test_tir_transform_narrow_datatype.py | 7 +- ... => test_tir_transform_storage_flatten.py} | 40 +++- .../test_tir_transform_storage_rewrite.py | 57 ++--- .../test_tir_transform_thread_sync.py | 12 +- tutorials/dev/low_level_custom_pass.py | 12 +- 41 files changed, 935 insertions(+), 431 deletions(-) create mode 100644 src/te/schedule/schedule_postproc_to_primfunc.cc rename src/tir/{pass => transforms}/inject_prefetch.cc (79%) rename src/tir/{pass => transforms}/storage_flatten.cc (77%) rename tests/python/unittest/{test_tir_pass_storage_flatten.py => test_tir_transform_storage_flatten.py} (82%) diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 6165a2ab546f..b1cb779b4227 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, /*! * \brief Infer a regular domain that covers all the calls or provides within the given statement. * \param body The given statement. - * \param tensor The name of the calls or provides. - * \param consider_calls If calls (read) are considered. - * \param consider_provides If provides (write) are considered. + * \param buffer The buffer to check the access info. + * \param consider_loads If loads are considered. + * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(Stmt body, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides); +Domain DomainTouched(const Stmt& body, + const tir::Buffer& buffer, + bool consider_loads, + bool consider_stores); } // namespace arith } // namespace tvm diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 121dbdde37a6..b9b420ad02b6 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -70,7 +70,7 @@ class ObjAllocatorBase { static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), - std::forward(args)...); + std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index b3ecbf8c08e1..e64ea21f5f7e 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -29,6 +29,7 @@ #define TVM_TE_SCHEDULE_PASS_H_ #include +#include namespace tvm { namespace te { @@ -54,6 +55,26 @@ Map InferBound(const Schedule& sch); */ Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); +/*! + * \brief Postprocessing the Stmt generated by ScheduleOps to create + * a PrimFunc that can then be used for further TIR optimizations. + * + * Perform this translation before running any TIR optimizations. + * + * List of actions taken by the function: + * - Remove occurences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. + * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + * + * \param arg_list Array of Tensor/Var/Buffer arguments to the function. + * \param body The body of the function. + * \param bindings potential Tensor to Buffer bindings for the Tensors in the body. + */ +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, + Stmt body, + Optional> bindings); + /*! * \brief To automatically inline the element-wise operations. * diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6764178cc23c..bf0d4f985a92 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -694,7 +694,10 @@ class CallNode : public PrimExprNode { ExternCPlusPlus = 1, /*! \brief Extern "C" without side-effect. */ PureExtern = 2, - /*! \brief Halide-style call, evaluates func(args). */ + /*! + * \brief Halide-style call, evaluates func(args). + * \note Deprecated, move to BufferLoad in the future. + */ Halide = 3, /*! \brief Intrinsic functions. */ Intrinsic = 4, @@ -707,9 +710,15 @@ class CallNode : public PrimExprNode { Array args; /*! \brief Type of calls. */ CallType call_type; - /*! \brief The function to be called. */ + /*! + * \brief The function to be called. + * \note Deprecated, move to BufferLoad in the future. + */ FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ + /*! + * \brief The output value index if func's value is a tuple. + * \note Deprecated, move to BufferLoad in the future. + */ int value_index{0}; void VisitAttrs(AttrVisitor* v) { diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index f3d447e4524a..e6e2de6f24f6 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -164,22 +164,6 @@ Stmt Inline(Stmt stmt, Array args, PrimExpr body); -/*! - * \brief Flatten the multi-dimensional read/write - * to single dimensional Load/Store - * - * \param stmt The stmt to be trasnformed. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \param cache_line_size The size of CPU cache line. - * \param create_bound_attribute Whether to create bound attributes. - * \return Transformed stmt. - */ -Stmt StorageFlatten(Stmt stmt, - Map extern_buffer, - int cache_line_size, - bool create_bound_attribute = false); - /*! * \brief Try to modify the AST to support TensorCore * @@ -202,13 +186,6 @@ Stmt RewriteForTensorCore(Stmt stmt, */ bool VerifyCompactBuffer(Stmt stmt); -/*! - * \brief Inject prefetch instructions into stmt. - * \param stmt The statement to be transformed. - * \return Transformed stmt. - */ -Stmt InjectPrefetch(Stmt stmt); - /*! * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5bc492fcefb8..20c2d009b93c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -248,7 +248,6 @@ class StoreNode : public StmtNode { * \endcode * \sa BufferLoad */ -class BufferStore; class BufferStoreNode : public StmtNode { public: /*! \brief The buffer variable. */ @@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); }; +/*! + * \brief Managed reference to BufferStoreNode. + * \sa BufferStoreNode + */ class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, @@ -289,8 +292,80 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; +/*! + * \brief Annotate the region where the buffer need to + * be read and write in the body. + * We only need to allocate the space for the corresponding region. + * + * \note There should be at most one BufferRealize for each buffer. + * BufferRealize is not necessary for external buffers, + * since they are assumed to be fully allocated. + * + * \sa BufferLoad, BufferStore + */ +class BufferRealizeNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Buffer buffer; + /*! \brief Bounds to be realized */ + Array bounds; + /*! \brief Only realize if condition holds. */ + PrimExpr condition; + /*! \brief The body of realization. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { + return + equal(buffer, other->buffer) && + equal(bounds, other->bounds) && + equal(condition, other->condition) && + equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(bounds); + hash_reduce(condition); + hash_reduce(body); + } + + BufferRealizeNode() = default; + BufferRealizeNode(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) + : buffer(buffer), bounds(bounds), + condition(condition), body(body) {} + + static constexpr const char* _type_key = "BufferRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); +}; + +/*! + * \brief Managed reference to BufferRealizeNode. + * \sa BufferRealizeNode + */ +class BufferRealize : public Stmt { + public: + TVM_DLL explicit BufferRealize(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); +}; + /*! * \brief Store value into mult-dimensional array defined by func. + * + * \note Deprecated, move to BufferStore in the future. */ class ProvideNode : public StmtNode { public: @@ -430,6 +505,8 @@ class FreeNode : public StmtNode { /*! * \brief Annotate the bounds where func need to be written and read in body. * We will need to allocate space for the corresponding regions. + * + * \note Deprecated, move to BufferRealize in the future. */ class RealizeNode : public StmtNode { public: @@ -747,50 +824,50 @@ class ForNode : public StmtNode { }; /*! - * \brief A prefetch hint of func. + * \brief A prefetch hint for abuffer */ class PrefetchNode : public StmtNode { public: /*! \brief The function to be prefetched. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; + Buffer buffer; /*! \brief Bounds to be prefetched. */ - Region bounds; + Array bounds; void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); + v->Visit("buffer", &buffer); v->Visit("bounds", &bounds); } bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && + equal(buffer, other->buffer) && equal(bounds, other->bounds); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); - hash_reduce(dtype); + hash_reduce(buffer); hash_reduce(bounds); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds); + PrefetchNode() = default; + PrefetchNode(Buffer buffer, Array bounds) + : buffer(buffer), bounds(bounds) {} static constexpr const char* _type_key = "Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); }; +/*! + * \brief Managed reference to PrefetchNode. + * \sa PrefetchNode + */ +class Prefetch : public Stmt { + public: + TVM_DLL explicit Prefetch(Buffer buffer, Array bounds); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); +}; + /*! * \brief Auxiliary data structure used in IR Pass to indicate a tensor. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index f93e9080a377..a87ff9737d0c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -92,6 +92,7 @@ class StmtFunctor { virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -121,6 +122,8 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); + IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); + IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); return vtable; } }; @@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const ProvideNode* op) override; @@ -248,6 +252,7 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const ProvideNode* op) override; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e593e1bf0fbc..09ea09731f51 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const std::string& name, const tvm::Array& required); + +/*! + * \brief Inject prefetch instructions into stmt. + * + * \return The pass. + */ +TVM_DLL Pass InjectPrefetch(); + +// TODO(tvm-team): consolidate configs to the PassContext +/*! + * \brief Flatten the multi-dimensional read/write + * to single dimensional Load/Store + * + * \param cache_line_size The size of CPU cache line. + * \param create_bound_attribute Whether to create bound attributes. + * + * \return The Pass + */ +TVM_DLL Pass StorageFlatten(int cache_line_size, + bool create_bound_attribute = false); + /*! * \brief Inject copy intrinsics with optional pad. * diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index c576ffd76e56..0c0591ccf2a1 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm import target as _target -from tvm.tir import ir_pass from tvm.te import schedule from tvm.driver import build_module @@ -46,10 +45,12 @@ def ana_lower(sch, args, # Phase 0 bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds, True) - stmt = ir_pass.StorageFlatten(stmt, binds, 64) - stmt = ir_pass.CanonicalSimplify(stmt) + func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func._move()) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode - return stmt + return mod["main"].body try: _get_buffer_curve_sample_flatten = tvm._ffi.get_global_func( diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 18a8a47ad439..eea372733ca7 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None): def form_body(sch): - """According to the given schedule, form the raw body + """According to the given schedule, form a function. + Parameters ---------- sch : tvm.te.schedule.Schedule @@ -99,13 +100,31 @@ def form_body(sch): sch = sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) - stmt = ir_pass.InjectPrefetch(stmt) return stmt +def _wrap_as_prim_func_pass(flist, name): + """Wrap flist as a function pass. + + This is an temporary adapter before we fully + migrate to the new pass manager. + """ + def _transform(func, *_): + stmt = func.body + for f in flist: + stmt = f(stmt) + # create a new function with updated body. + return tvm.tir.PrimFunc(func.params, + stmt, + func.ret_type, + func.buffer_map, + func.attrs) + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name) + + def lower(sch, args, - name="default_function", + name="main", binds=None, simple_mode=False): """Lowering step before build into target. @@ -154,56 +173,57 @@ def lower(sch, compact = ir_pass.VerifyCompactBuffer(stmt) binds, arg_list = get_binds(args, compact, binds) + stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) + + # Start the new style pass manager. + func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + func = func.with_attr("global_symbol", name) + if cfg.restricted_func: + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) # Phase 1 - stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) - stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.NarrowDataType(stmt, 32) - stmt = ir_pass.CanonicalSimplify(stmt) - for f in lower_phase1: - stmt = f(stmt) + pass_list = [ + tvm.tir.transform.InjectPrefetch(), + tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers), + tvm.tir.transform.NarrowDataType(32), + tvm.tir.transform.Simplify(), + _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"), + ] # Phase 2 if not simple_mode: - stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - if cfg.disable_vectorize: - stmt = ir_pass.SkipVectorize(stmt) - else: - stmt = ir_pass.VectorizeLoop(stmt) - stmt = ir_pass.InjectVirtualThread(stmt) - stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) - stmt = ir_pass.StorageRewrite(stmt) - stmt = ir_pass.UnrollLoop( - stmt, - cfg.auto_unroll_max_step, - cfg.auto_unroll_max_depth, - cfg.auto_unroll_max_extent, - cfg.unroll_explicit) - - for f in lower_phase2: - stmt = f(stmt) + pass_list += [(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))] + + pass_list += [ + tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize), + tvm.tir.transform.InjectVirtualThread(), + tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.UnrollLoop( + cfg.auto_unroll_max_step, + cfg.auto_unroll_max_depth, + cfg.auto_unroll_max_extent, + cfg.unroll_explicit), + _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"), + ] # Phase 3 - stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.RemoveNoOp(stmt) - if not cfg.disable_select_rewriting: - stmt = ir_pass.RewriteUnsafeSelect(stmt) + pass_list += [ + tvm.tir.transform.Simplify(), + tvm.tir.transform.RemoveNoOp(), + ] - for f in lower_phase3: - stmt = f(stmt) + if not cfg.disable_select_rewriting: + pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] + pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")] # Instrument BoundCheckers if cfg.instrument_bound_checkers: - stmt = ir_pass.InstrumentBoundCheckers(stmt) + pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] - if simple_mode: - return stmt - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String(name)) - if cfg.restricted_func: - f = f.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: f}) + optimize = tvm.transform.Sequential(pass_list) + mod = optimize(mod) return mod diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 614f9690903a..af0be45b624e 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -157,11 +157,6 @@ class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. - Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well. @@ -173,6 +168,9 @@ class Sequential(Pass): opt_level : Optional[int] The optimization level of this sequential pass. + The opt_level of a default sequential pass is set to 0. + Note that some of the passes within the Sequantial may still not be executed + if their opt_level is higher than the provided opt_level. name : Optional[str] The name of the sequential pass. diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index d2238ad754ac..ddfb6a5f69c1 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -28,7 +28,8 @@ from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For -from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt +from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate, AttrStmt +from .stmt import Free, Realize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .function import PrimFunc diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index c5b2a7957319..eee5b0b002e0 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -160,6 +160,29 @@ def __init__(self, buffer, value, indices): _ffi_api.BufferStore, buffer, value, indices) +@tvm._ffi.register_object +class BufferRealize(Stmt): + """Buffer realize node. + + Parameters + ---------- + buffer : Buffer + The buffer. + + bounds : List[Range] + The value we to be stored. + + condition : PrimExpr + The realize condition. + + body : Stmt + The body of the statement. + """ + def __init__(self, buffer, bounds, condition, body): + self.__init_handle_by_constructor__( + _ffi_api.BufferRealize, buffer, bounds, condition, body) + + @tvm._ffi.register_object class Provide(Stmt): """Provide node. @@ -348,21 +371,15 @@ class Prefetch(Stmt): Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index - - dtype : str - The data type to be prefetched. + buffer : Buffer + The buffer to be prefetched. bounds : list of Range The bounds to be prefetched. """ - def __init__(self, func, value_index, dtype, bounds): + def __init__(self, buffer, bounds): self.__init_handle_by_constructor__( - _ffi_api.Prefetch, func, value_index, dtype, bounds) + _ffi_api.Prefetch, buffer, bounds) def stmt_seq(*args): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f83bb11ad51e..bb39c1f69131 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -60,6 +60,38 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") +def InjectPrefetch(): + """Inject prefetch instructions into stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectPrefetch() + + +def StorageFlatten(cache_line_size, create_bound_attribute=False): + """Flatten the multi-dimensional read/write to 1D. + + + Parameters + ---------- + cache_line_size: int + The size of CPU cache line. + + create_bound_attribute: + Whether to create bound attributes. + + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) + + def InjectCopyIntrin(pragma_key, fintrin): """Inject virtual thread loops. diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index bda70fb67cba..2467e758c3e4 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -36,10 +36,14 @@ namespace arith { using namespace tir; // Find Read region of the tensor in the stmt. -class FuncTouchedDomain final : public StmtExprVisitor { +class BufferTouchedDomain final : public StmtExprVisitor { public: - FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool consider_provides) - : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} + BufferTouchedDomain(const Buffer &buffer, + bool consider_loads, + bool consider_stores) + : buffer_(buffer), + consider_loads_(consider_loads), + consider_stores_(consider_stores) {} Domain Find(const Stmt& stmt) { operator()(stmt); @@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { - if (consider_calls_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitExpr_(const BufferLoadNode* op) final { + if (consider_loads_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const ProvideNode* op) final { - if (consider_provides_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitStmt_(const BufferStoreNode* op) final { + if (consider_stores_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitStmt_(op); } @@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - const te::Tensor &tensor_; - bool consider_calls_, consider_provides_; + const Buffer &buffer_; + bool consider_loads_, consider_stores_; std::vector > bounds_; std::unordered_map dom_map_; }; -Domain DomainTouched(Stmt stmt, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides) { - return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt); +Domain DomainTouched(const Stmt& stmt, + const Buffer& buffer, + bool consider_loads, + bool consider_stores) { + return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } TVM_REGISTER_GLOBAL("arith.DomainTouched") diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e38179e965f5..c3802b1a63f4 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) { } -IRModule BuildIRModule(const Array& out_arg_list, - tir::Stmt stmt, - const std::string& name, - const BuildConfig& config) { - Array params; - Map buffer_map; - - for (auto var : out_arg_list) { - if (auto* n = var.as()) { - params.push_back(GetRef(n)); - } else { - tir::Buffer buffer = Downcast(var); - tir::Var bptr(buffer->name, DataType::Handle()); - params.push_back(bptr); - buffer_map.Set(bptr, buffer); - } - } - - auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - if (config->restricted_func) { - f = WithAttr(std::move(f), "tir.noalias", Integer(1)); - } - - return IRModule(Map({{GlobalVar(name), f}})); -} - - IRModule lower(te::Schedule sch, const Array& args, const std::string& name, @@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch, sch = sch.normalize(); - // Phase 0 + // Before TIR transformation. auto bounds = te::InferBound(sch); auto stmt = te::ScheduleOps(sch, bounds, false); - stmt = tir::InjectPrefetch(stmt); - bool compact = tir::VerifyCompactBuffer(stmt); + Map out_binds; GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); - // Phase 1 - stmt = tir::StorageFlatten(stmt, out_binds, 64, - config->instrument_bound_checkers); + // build the function + tir::PrimFunc f = te::SchedulePostProcToPrimFunc( + out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + if (config->restricted_func) { + f = WithAttr(std::move(f), "tir.noalias", Integer(1)); + } - // convert to IRModule. - auto mod = BuildIRModule(out_arg_list, stmt, name, config); + auto mod = IRModule(Map({{GlobalVar(name), f}})); auto pass_list = Array(); + // Phase 0 + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back( + tir::transform::StorageFlatten(64, config->instrument_bound_checkers)); + // Phase 1 + pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop)); pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize)); diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 4ecfe9472901..d022134065c6 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -132,8 +132,8 @@ MakeLoopNest(const Stage& stage, for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { nest[i + 1].emplace_back( AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } } else if (bind_iv->thread_tag == "vthread" || diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc new file mode 100644 index 000000000000..bb52be4c0202 --- /dev/null +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file schedule_postproc_to_primfunc.cc + * + * \brief Translate the function body generated by ScheduleOps + * with te related dialects that incorporates Tensor + * into the Stmts to a PrimFunc. + * + * Perform this translation before running any TIR optimizations. + * + * Rationale: The body generated by ScheduleOps is not + * a formal PrimFunc and cannot be used for further optimization. + * This function canonicalize that body and creates a formal PrimFunc. + * + * List of actions taken by the function: + * - Remove occurences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. + * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace te { + +// create a buffer for tensor. +Buffer CreateBufferFor(const Tensor& tensor) { + std::string name = tensor->op->name; + if (tensor->op->num_outputs() != 1) { + name += ".v" + std::to_string(tensor->value_index); + } + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + return buffer; +} + +// A remapper that maps tensor to buffer +class TensorToBufferMapper : public StmtExprMutator { + public: + explicit TensorToBufferMapper(std::unordered_map buffer_map) + : buffer_map_(buffer_map) { + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + // TODO(tvm-team): remove realize_scope, turn the info into + // Buffer's scope field in this pass. + if (op->attr_key == tir::attr::realize_scope || + op->attr_key == tir::attr::double_buffer_scope) { + Stmt body = op->body; + Operation operation = Downcast(op->node); + for (int i = operation->num_outputs(); i != 0; --i) { + Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); + body = AttrStmtNode::make( + buffer, op->attr_key, op->value, body); + } + return body; + } else if (op->attr_key == tir::attr::buffer_bind_scope) { + Array tuple = Downcast >(op->node); + Tensor tensor = Downcast(tuple[1]); + return AttrStmtNode::make( + Array{tuple[0], GetOrAllocBuffer(tensor)}, + op->attr_key, op->value, op->body); + } else if (op->attr_key == tir::attr::buffer_dim_align|| + op->attr_key == tir::attr::prefetch_scope) { + Tensor tensor = Downcast(op->node); + Buffer buffer = GetOrAllocBuffer(tensor); + return AttrStmtNode::make( + buffer, op->attr_key, op->value, op->body); + } else { + return ret; + } + } + + Stmt VisitStmt_(const RealizeNode* op) final { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetOrAllocBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferRealize(buffer, op->bounds, op->condition, op->body); + } + + Stmt VisitStmt_(const ProvideNode* op) final { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferStore(buffer, op->value, op->args); + } + + PrimExpr VisitExpr_(const CallNode* op) final { + auto ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + + if (op->call_type == CallNode::Halide) { + Tensor tensor = Downcast(op->func).output(op->value_index); + Buffer buffer = GetBuffer(tensor); + return tir::BufferLoad(buffer, op->args); + } else { + return ret; + } + } + + private: + Buffer GetOrAllocBuffer(const Tensor& tensor) { + return GetBuffer(tensor, true); + } + + Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + auto it = buffer_map_.find(tensor); + if (it != buffer_map_.end()) return it->second; + CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; + + auto buffer = CreateBufferFor(tensor); + buffer_map_[tensor] = buffer; + return buffer; + } + + // maps tensor to buffer. + std::unordered_map buffer_map_; +}; + + +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, + Stmt body, + Optional> extern_buffer_opt) { + std::unordered_map extern_buffer; + + if (extern_buffer_opt.defined()) { + auto v = extern_buffer_opt.value(); + extern_buffer = std::unordered_map(v.begin(), v.end()); + } + + Array params; + Map buffer_map; + + for (auto var : arg_list) { + if (auto* n = var.as()) { + params.push_back(GetRef(n)); + } else if (auto* n = var.as()) { + te::Tensor tensor = GetRef(n); + CHECK(!extern_buffer.count(tensor)); + + tir::Buffer buffer = CreateBufferFor(tensor); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + extern_buffer[tensor] = buffer; + } else { + tir::Buffer buffer = Downcast(var); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + } + } + + body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); + return tir::PrimFunc(params, body, VoidType(), buffer_map); +} + +TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") +.set_body_typed(SchedulePostProcToPrimFunc); + +} // namespace te +} // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 65d424e31212..03925ec3783a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -645,6 +645,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1f6a7dd027ea..f8e82ea787be 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize") .set_body_typed(RealizeNode::make); -Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - return Stmt(node); +Prefetch::Prefetch(Buffer buffer, Array bounds) { + data_ = make_object(buffer, bounds); } TVM_REGISTER_GLOBAL("tir.Prefetch") -.set_body_typed(PrefetchNode::make); +.set_body_typed([](Buffer buffer, Array bounds) { + return Prefetch(buffer, bounds); +}); SeqStmt::SeqStmt(Array seq) { @@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore") TVM_REGISTER_NODE_TYPE(BufferStoreNode); + +BufferRealize::BufferRealize(Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) { + data_ = make_object( + buffer, bounds, condition, body); +} + +TVM_REGISTER_GLOBAL("tir.BufferRealize") +.set_body_typed([](Buffer buffer, + Array bounds, + PrimExpr condition, + Stmt body) { + return BufferRealize(buffer, bounds, condition, body); +}); + +TVM_REGISTER_NODE_TYPE(BufferRealizeNode); + // Printers TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -432,6 +441,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -458,6 +482,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << "prefetch " << op->func->func_name() << "("; + p->stream << "prefetch " << op->buffer << "("; for (size_t i = 0; i < op->bounds.size(); ++i) { p->stream << "["; p->Print(op->bounds[i]->min); @@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (i < op->bounds.size() - 1) p->stream << ", "; } p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ed3c2c75ef47..5e584ebf55f4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { + this->VisitExpr(op->value); VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); + this->VisitExpr(op->condition); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); @@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { + PrimExpr value = this->VisitExpr(op->value); Array indices = Internal::Mutate(this, op->indices); - if (indices.same_as(op->indices)) { + + if (value.same_as(op->value) && + indices.same_as(op->indices)) { return GetRef(op); } else { auto n = CopyOnWrite(op); + n->value = std::move(value); n->indices = std::move(indices); return Stmt(n); } } +Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { + Region bounds = Internal::Mutate(this, op->bounds); + PrimExpr condition = this->VisitExpr(op->condition); + Stmt body = this->VisitStmt(op->body); + + if (bounds.same_as(op->bounds) && + condition.same_as(op->condition) && + body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + n->condition = std::move(condition); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Array args = Internal::Mutate(this, op->args); PrimExpr value = this->VisitExpr(op->value); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 65981b9b62f5..4d7ed5dc3c85 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute") } }); -TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() <= 3) { - *ret = StorageFlatten(args[0], args[1], args[2]); - } else { - *ret = StorageFlatten(args[0], args[1], args[2], args[3]); - } - }); - TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") .set_body_typed ([](const Stmt& stmt, @@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); -REGISTER_PASS(InjectPrefetch); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(VerifyCompactBuffer); diff --git a/src/tir/pass/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc similarity index 79% rename from src/tir/pass/inject_prefetch.cc rename to src/tir/transforms/inject_prefetch.cc index 894ff3864864..e9dae0a5dfc9 100644 --- a/src/tir/pass/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -21,9 +21,12 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR +#include #include +#include #include -#include +#include +#include #include #include @@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator { Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); if (op && op->attr_key == attr::prefetch_scope) { - te::Tensor ts = Downcast(op->node); + Buffer buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U); - Domain domain = DomainTouched(op->body, ts, true, false); + Domain domain = DomainTouched(op->body, buffer, true, false); Region region; auto iter_var = loop_nest_.back().get(); @@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator { for (Range r : domain) { if (!r.defined()) { - LOG(WARNING) << "Cannot decide prefetch region for " << ts; + LOG(WARNING) << "Cannot decide prefetch region for " << buffer; return op->body; } Range res(EvalSet(r, vectorized_).cover_range(none)); @@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator { vectorized_.erase(iter_var); - Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region); + Stmt prefetch = Prefetch(buffer, region); return SeqStmt({prefetch, op->body}); } return ret; @@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } + +namespace transform { + +Pass InjectPrefetch() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = PrefetchInjector()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch") +.set_body_typed(InjectPrefetch); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc similarity index 77% rename from src/tir/pass/storage_flatten.cc rename to src/tir/transforms/storage_flatten.cc index f9533fa4820a..99d437d9c24e 100644 --- a/src/tir/pass/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -19,22 +19,24 @@ /*! * \file storage_flatten.cc + * \brief Flattens storage from multi-dimensional array to 1D buffer access */ -// Flattens storage from multi-dimensional array to 1D -// buffer access as in Halide pipeline. +// The pass definition originates from Halide pipeline. + +#include #include #include #include #include #include #include -#include +#include #include #include #include #include -#include "ir_util.h" -#include "arg_binder.h" +#include "../pass/ir_util.h" +#include "../pass/arg_binder.h" #include "../../arith/compute_expr.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" @@ -49,16 +51,17 @@ using intrinsic::tvm_address_of; class StorageFlattener : public StmtExprMutator { public: - explicit StorageFlattener(Map extern_buffer, - int cache_line_size, bool create_bound_attributes, - IRVisitorWithAnalyzer* bounded_analyzer) - : bounded_analyzer_(bounded_analyzer), + explicit StorageFlattener(const Map& extern_buffer_map, + int cache_line_size, + bool create_bound_attributes, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { - for (auto kv : extern_buffer) { + for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; e.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; + buf_map_[kv.second] = e; } cache_line_size_ = cache_line_size; } @@ -82,17 +85,14 @@ class StorageFlattener : public StmtExprMutator { storage_scope_[op->node.get()] = op->value.as()->value; return this->VisitStmt(op->body); } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { - auto func = Downcast(op->node); + op->node->IsInstance()) { + auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); - for (int i = 0; i < func->num_outputs(); ++i) { - TensorKey key{func, i}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - body = AttrStmtNode::make( - it->second.buffer->data, op->attr_key, op->value, body); - } + auto it = buf_map_.find(buffer); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << buffer; + body = AttrStmtNode::make( + it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -104,11 +104,10 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); } else if (op->attr_key == attr::buffer_dim_align) { - auto tensor = Downcast(op->node); + auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - auto& vinfo = dim_align_[key]; + auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { vinfo.resize(dim + 1); @@ -122,18 +121,21 @@ class StorageFlattener : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const ProvideNode* op) final { - if (create_bound_attributes_) - shape_collector_.clear(); + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (create_bound_attributes_) shape_collector_.clear(); Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - TensorKey key{op->func, op->value_index}; + op = stmt.as(); + + const auto& key = op->buffer; + auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; CHECK(!e.released) << "Read a buffer that is already out of scope"; + if (is_opengl_) { return EvaluateNode::make(CallNode::make( DataType(), @@ -141,7 +143,7 @@ class StorageFlattener : public StmtExprMutator { {e.buffer->data, op->value}, CallNode::Intrinsic)); } else { - Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value); + Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back( std::make_pair(e.buffer->data, e.buffer->shape)); @@ -158,8 +160,9 @@ class StorageFlattener : public StmtExprMutator { } } - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const BufferRealizeNode* op) final { + const auto& key = op->buffer; + if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); return this->VisitStmt(op->body); @@ -172,10 +175,9 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->func.get()); + auto it = storage_scope_.find(op->buffer.get()); CHECK(it != storage_scope_.end()) - << "Cannot find storage scope of " << op->func - << " value_index=" << op->value_index; + << "Cannot find storage scope of " << op->buffer; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { @@ -188,13 +190,14 @@ class StorageFlattener : public StmtExprMutator { } // use small alignment for small arrays + auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); - int align = GetTempAllocaAlignment(op->dtype, const_size); + int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); if (info.defined()) { - align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits(); - CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits) + align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); + CHECK_LE(const_size * dtype.bits(), info->max_num_bits) << "Allocation exceed bound of memory tag " << skey.to_string(); } } @@ -210,7 +213,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = bound_analyzer_->Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -219,9 +222,9 @@ class StorageFlattener : public StmtExprMutator { } e.buffer = BufferNode::make( - Var(key.GetName(), DataType::Handle()), - op->dtype, shape, strides, PrimExpr(), - key.GetName(), skey.to_string(), + Var(op->buffer->data->name_hint, DataType::Handle()), + op->buffer->dtype, shape, strides, PrimExpr(), + op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; @@ -285,36 +288,36 @@ class StorageFlattener : public StmtExprMutator { } } - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr && op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + op = expr.as(); - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { + const auto& key = op->buffer; + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; + CHECK(!e.released) + << "Read a buffer that is already out of scope"; + + if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back( std::make_pair(e.buffer->data, e.buffer->shape)); - } - return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype); - } else { - return expr; } + return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); } + Stmt VisitStmt_(const PrefetchNode *op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); - TensorKey key{op->func, op->value_index}; + + const auto& key = op->buffer; auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; CHECK(!e.released) @@ -340,7 +343,7 @@ class StorageFlattener : public StmtExprMutator { for (int i = op->bounds.size() - 1; i > starts; --i) { args.push_back(op->bounds[i]->min); } - auto &func_name = op->func->func_name(); + auto &func_name = op->buffer->name; vars.push_back(Var( "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); @@ -358,7 +361,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr address = CallNode::make( DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); PrimExpr prefetch = CallNode::make( - op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -367,6 +370,26 @@ class StorageFlattener : public StmtExprMutator { return stmt; } + PrimExpr VisitExpr_(const CallNode* op) final { + CHECK(op->call_type != CallNode::Halide) + << "Cannot handle Halide calls " + << " please run SchedulePostProcToPrimFunc first"; + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const ProvideNode* op) final { + LOG(FATAL) << "Cannot handle Provide " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + + Stmt VisitStmt_(const RealizeNode* op) final { + LOG(FATAL) << "Cannot handle Realize " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + + private: // The specific tensor data layout is not determined before // StorageFlatten pass. We use buffer_bind_scope @@ -406,14 +429,16 @@ class StorageFlattener : public StmtExprMutator { Array arr = Downcast > (op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); - const te::TensorNode* tensor = arr[1].as(); + const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); - CHECK(buffer && tensor); + CHECK(buffer && target); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - CHECK(buf_map_.count(key)) - << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; - const BufferEntry& be = buf_map_.at(key); + auto key = GetRef(target); + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find buffer of " << key; + const BufferEntry& be = it->second; CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); Array begins, extents; @@ -426,7 +451,7 @@ class StorageFlattener : public StmtExprMutator { } else { for (size_t i = 0; i < tuple->args.size(); i += 2) { begins.push_back(tuple->args[i]); - auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]); + auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]); extents.push_back(new_extent); } } @@ -451,6 +476,7 @@ class StorageFlattener : public StmtExprMutator { } return body; } + // The buffer entry in the flatten map struct DimAlignInfo { int align_factor{0}; @@ -509,9 +535,10 @@ class StorageFlattener : public StmtExprMutator { // Variable remap std::unordered_map var_remap_; // Buffer map - std::unordered_map buf_map_; + std::unordered_map buf_map_; // Dimension alignment - std::unordered_map > dim_align_; + std::unordered_map, + ObjectHash, ObjectEqual> dim_align_; // Storage scope std::unordered_map storage_scope_; // The current thread scope. @@ -520,7 +547,7 @@ class StorageFlattener : public StmtExprMutator { std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. // However - IRVisitorWithAnalyzer* bounded_analyzer_; + IRVisitorWithAnalyzer* bound_analyzer_; // The size of cacheline int cache_line_size_; // The current stage is an OpenGL shader. @@ -529,15 +556,37 @@ class StorageFlattener : public StmtExprMutator { bool create_bound_attributes_{false}; }; -Stmt StorageFlatten(Stmt stmt, Map extern_buffer, - int cache_line_size, bool create_bound_attributes) { - IRVisitorWithAnalyzer bounded_analyzer; - bounded_analyzer(stmt); - stmt = - StorageFlattener(extern_buffer, cache_line_size, - create_bound_attributes, &bounded_analyzer)(std::move(stmt)); - return stmt; +PrimFunc StorageFlatten(PrimFunc func, + int cache_line_size, + bool create_bound_attributes) { + auto fptr = func.CopyOnWrite(); + + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = StorageFlattener(fptr->buffer_map, + cache_line_size, + create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; } + +namespace transform { + +// TODO(tvm-team): consolidate configs to the PassContext +Pass StorageFlatten(int cache_line_size, + bool create_bound_attributes) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return StorageFlatten( + std::move(f), cache_line_size, create_bound_attributes); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten") +.set_body_typed(StorageFlatten); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 0d769aabf247..10337218dc87 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -22,21 +22,25 @@ def test_domain_touched(): j = te.var('j') n = tvm.runtime.convert(100) m = te.var('m') - a = te.placeholder((n, m), name = 'a') - b = te.placeholder((n, m), name = 'b') + + a = tvm.tir.decl_buffer((n, m), name='a') + b = tvm.tir.decl_buffer((n, m), name='b') + + ir = tvm.tir.For( i, 0, n, 0, 0, tvm.tir.For(j, 0, m, 0, 0, - tvm.tir.Provide( - a.op, - 0, - tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + - tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0), + tvm.tir.BufferStore( + a, + tvm.tir.BufferLoad(b, [i - 1, j + 1]) + + tvm.tir.BufferLoad(a, [i - 1, j - 1]), [i, j] ) ) ) + a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) + assert a_domain_r[0].min.value == -1 assert a_domain_r[0].extent.value == 100 assert a_domain_r[1].min.value == -1 diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index 442c4fed7b2f..b1d754605a46 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -48,9 +48,9 @@ def test_split_uneven_unique_likely(): x, y = c.op.axis sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) - stmt = tvm.lower(sch, [a, b, c], simple_mode=True) + stmt = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) - assert str(stmt.body.body.body).count("likely") == 1 + if __name__ == "__main__": test_lower_rfactor() diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index b525d018340d..5b4a1c92a7e4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -365,7 +365,7 @@ def foo(a): a = te.placeholder((8, 4), 'float32') c = foo(a) s = te.create_schedule(c.op) - ir = tvm.lower(s, [a, c], simple_mode=True) + ir = tvm.lower(s, [a, c]) func, ins, outs = run_and_check(foo, [a], target='cuda') run_and_check(func, ins, outs=outs, target='cuda') @@ -517,7 +517,7 @@ def upstream(a): c = te.compute((20, ), lambda x: a[x] + b[x]) d = upstream(c) sch = te.create_schedule([c.op, d.op]) - ir = tvm.lower(sch, [a, b, d], simple_mode=True) + ir = tvm.lower(sch, [a, b, d]) func = tvm.build(sch, [a, b, d]) assert(func) @@ -730,7 +730,7 @@ def outer_product(a, b): joo, joi = sch[c].split(jo, 4) sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -751,7 +751,7 @@ def outer_product(a, b): # Test fuse sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index c9b422f7f0a4..9e4d45e9efaa 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -283,7 +283,7 @@ def intrin_func(ins, outs, sp): # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, C], simple_mode=True) + stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 assert str(stmt.body.body.value.args[3]) == "(i*i)" diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 3e521ab07023..2a0c6c1f40af 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -28,6 +28,9 @@ def test_schedule0(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule1(): @@ -43,6 +46,10 @@ def test_schedule1(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) + def test_schedule2(): m = te.var('m') @@ -57,6 +64,9 @@ def test_schedule2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A2], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule_scan(): @@ -77,6 +87,7 @@ def test_schedule_scan(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) + def test_inline_multi_reduce(): def argmax_comp(x, y): idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) @@ -510,19 +521,19 @@ def collect_visit(stmt, f): return ret # local vs. threadIdx s = schedule(tx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # local vs. vthread s = schedule(vx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # shared vs. blockIdx s = schedule(by, "shared") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) @@ -548,7 +559,7 @@ def test_local_stage_predicate2(): s[AA].compute_at(s[C], ooc) oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32) s[AA].bind(iaa, thread_x) - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body def collect_visit(stmt, f): ret = [] diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 55edd1c9958b..45280866af38 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -128,7 +128,7 @@ def intrin_func(ins, outs): lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body, tvm.tir.Evaluate) def test_tensor_compute2(): @@ -171,7 +171,7 @@ def intrin_func(ins, outs): lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body.body[0], tvm.tir.Evaluate) assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate) diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index b3625082f6ed..b0de91b435ea 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -24,29 +24,6 @@ other_devices = ["llvm", "ext_dev"] -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String("test")) - mod = tvm.IRModule({"test": f}) - return mod - - # All computations are bound. # So VerifyMemory pass is expected to succeed. # @@ -61,7 +38,7 @@ def test_verify_memory_all_bind(): s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices + other_devices: binded_mod = tvm.tir.transform.Apply( @@ -81,7 +58,7 @@ def test_verify_memory_not_bind(): # B is not bound to threads. s = te.create_schedule(B.op) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( @@ -111,7 +88,7 @@ def test_verify_memory_partially_bind(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B, C, D]) + mod = tvm. lower(s, [A, B, C, D]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 7a03e48e2270..4af93fd58c0a 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -194,9 +194,9 @@ def test_stmt_constructor(): assert x.then_case.value.value == 11 assert x.else_case == nop - x = tvm.tir.Prefetch(None, 1, "float32", []) + b = tvm.tir.decl_buffer((1, 2)) + x = tvm.tir.Prefetch(b, []) assert isinstance(x, tvm.tir.Prefetch) - assert x.value_index == 1 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 9106be843b48..090acda00365 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -28,7 +28,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - print(body) assert isinstance(body, tvm.tir.AttrStmt) body = body.body assert isinstance(body, tvm.tir.Allocate) @@ -59,14 +58,13 @@ def test_if(): assert body.else_case.index.value == 0 def test_prefetch(): - A = te.placeholder((10, 20), name="A") + A = tvm.tir.decl_buffer((10, 20), name="A") ib = tvm.tir.ir_builder.create() n = te.size_var("n") with ib.for_range(0, n, name="i") as i: ib.emit( - tvm.tir.Prefetch( - A.op, A.value_index, A.dtype, + tvm.tir.Prefetch(A, [tvm.ir.Range.make_by_min_extent(i+1, 2), tvm.ir.Range.make_by_min_extent(0, 20)])) body = ib.get() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 9f4ccadde94d..468ab1dbad6a 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -301,6 +301,10 @@ def test_buffer_load_store(): s = tvm.tir.BufferStore(b, 0.1, [0]) assert isinstance(s, tvm.tir.BufferStore) + s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)], + True, tvm.tir.Evaluate(0)) + assert isinstance(s, tvm.tir.BufferRealize) + def test_intimm_cond(): x = tvm.runtime.convert(1) diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 7ec2e48b4fe4..9d1641366d7d 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -26,9 +26,10 @@ def test_copy2d(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert dst.strides[0] == l assert dst.strides[1].value == 1 @@ -36,7 +37,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tuple(src.shape) == (m, l) return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -51,9 +51,11 @@ def test_copy_pad(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert pad_before[0].value == 1 @@ -63,7 +65,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert pad_value.value == 1.0 return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -75,9 +76,11 @@ def test_single_point_test(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 @@ -85,7 +88,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body @@ -105,11 +107,12 @@ def test_copy_pad_split(): s[Apad].pragma(s[Apad].op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) + def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) @@ -121,12 +124,10 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body - if __name__ == "__main__": test_copy2d() test_copy_pad() diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index fb76597577b6..760cf477f959 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -28,18 +28,16 @@ def test_makeapi(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Cb = tvm.tir.decl_buffer(C.shape, C.dtype, name='C') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({ + "target": tvm.target.create("llvm"), + "global_symbol": "main", + }))(mod) num_unpacked_args = 2 - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({ - "global_symbol": "main", - "target": tvm.target.create("llvm") - })) f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] assert(len(f.params) == 7) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index dbf22679c1a0..6179bbbfbd07 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits): raise ValueError("args must be Tensor, Buffer or Var") bounds = te.schedule.InferBound(sch) stmt = te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False) - return lower_stmt(arg_list, stmt, target_bits) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body def test_basic(): diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py similarity index 82% rename from tests/python/unittest/test_tir_pass_storage_flatten.py rename to tests/python/unittest/test_tir_transform_storage_flatten.py index 1eaadb35009d..e2bfeb009a11 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -30,11 +30,14 @@ def test_flatten2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [Ab, A2b], stmt, {A: Ab, A2: A2b}) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def test_flatten_prefetch(): A = te.placeholder((25, 100, 4), name = 'A') @@ -42,8 +45,14 @@ def test_flatten_prefetch(): i = te.size_var('i') j = te.size_var('j') region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] - stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64) + stmt = tvm.tir.Prefetch(_A, region) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [_A], stmt, {A: _A}) + + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body stmt = tvm.tir.ir_pass.Simplify(stmt) assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) @@ -62,12 +71,15 @@ def test_flatten_storage_align(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body stmt = tvm.tir.ir_pass.Simplify(stmt) assert(stmt.body.extents[0].value == 17 * 8) + def test_flatten_double_buffer(): dtype = 'int64' n = 100 @@ -87,7 +99,13 @@ def test_flatten_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64) + + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C], stmt)) + + mod = tvm.tir.transform.StorageFlatten(64)(mod) + stmt = mod["main"].body + stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) @@ -105,7 +123,7 @@ def count_sync(op): assert count[0] == 4 if __name__ == "__main__": - test_flatten_storage_align() test_flatten2() - test_flatten_prefetch() + test_flatten_storage_align() test_flatten_double_buffer() + test_flatten_prefetch() diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index e4e1b3102d4a..85f856db366b 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -30,11 +30,11 @@ def test_storage_share(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -166,11 +166,11 @@ def test_inplace_rule(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -201,11 +201,10 @@ def test_storage_combine(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -238,11 +237,9 @@ def test_storage_share_gpu(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A') - Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64) - - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C') - Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt)) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -398,17 +393,11 @@ def test_inplace_rule3(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0') - B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1') - B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2') - B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3') - B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4') - B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5') - - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [B0, B1, B2, B3, B4, B5, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt)) mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageRewrite()(mod) stmt = mod["main"].body @@ -547,7 +536,7 @@ def compute(a, b): c = te.compute(shape, lambda i, j: compute(a, b)[i, j]) c = te.compute(shape, lambda i, j: 1 + c[i, j]) s = te.create_schedule(c.op) - stmt = tvm.lower(s, [a, b, c], simple_mode=True) + stmt = tvm.lower(s, [a, b, c])["main"].body def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 268435456 diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 9257f6cd3320..783b66983c48 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -34,15 +34,15 @@ def test_thread_storage_sync(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) cuda_target = tvm.target.create("cuda") - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({ - "global_symbol": "test", "target": cuda_target})) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr({ + "global_symbol": "test", "target": cuda_target}))(mod._move()) fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] mod = tvm.IRModule.from_expr(fdevice) diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 25ca279bf339..d35913b1cd83 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -40,8 +40,6 @@ take a look at ``python/tvm/build_module.py`` to get some basics. """ - -from __future__ import absolute_import, print_function import tvm from tvm import te import numpy as np @@ -57,7 +55,7 @@ c = te.compute((n, ), lambda i: a[i] + b[i], name='c') sch = te.create_schedule(c.op) -ir = tvm.lower(sch, [a, b, c], simple_mode=True) +ir = tvm.lower(sch, [a, b, c]) print(ir) ###################################################################### @@ -137,12 +135,8 @@ def vectorize(stmt): # Glue to Lowering # ---------------- # So far, we are done with writing this IR transformation pass. What we need to do next is to glue -# this pass to TVM's lower pass. We can first call this function directly as a sanity check. +# this pass to TVM's lower pass. # - -print(vectorize(ir)) - -##################################################################### # In TVM, there is a property called ``BuildConfig``. You can use this property to customize your # own lowering options. In this case, we inject the pass written above into the TVM standard lowering # pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different @@ -160,7 +154,7 @@ def vectorize(stmt): # with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg: - print(tvm.lower(sch, [a, b, c], simple_mode=True)) + print(tvm.lower(sch, [a, b, c])) ##################################################################### # Quick View