diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 6af99586d2f9..5c4990a9ba92 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -53,7 +53,6 @@ struct ExprDeepEqual { TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; }; - /*! * \brief Find undefined vars in the statment. * \param stmt The function to be checked. diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index e228ce32adab..f3d447e4524a 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -202,59 +202,6 @@ Stmt RewriteForTensorCore(Stmt stmt, */ bool VerifyCompactBuffer(Stmt stmt); -/*! - * \brief Remove No Op from the Stmt. - * \param stmt The stmt to be trasnformed - * \return Transformed stmt. - */ -Stmt RemoveNoOp(Stmt stmt); - -/*! - * \brief unroll the constant loop marked by unroll. - * This pass also automatically attach pragma unroll tag to loops which meets the standard. - * - * \param stmt The statment to be unrolled. - * \param auto_max_step The maximum step before stop attach automatic unroll - * \param auto_max_depth The maximum depth before stop attach automatic unroll - * \param auto_max_extent The maximum extent of the loop we can unroll, - * this is an legacy option that do not take the loop total steps into account. - * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. - * \return Transformed stmt. - */ -Stmt UnrollLoop(Stmt stmt, - int auto_max_step, - int auto_max_depth, - int auto_max_extent, - bool explicit_unroll); - -/*! - * \brief vectorize the constant loops - * \param stmt The statement to be vectorized. - * \return Transformed stmt. - */ -Stmt VectorizeLoop(Stmt stmt); - -/*! - * \brief convert vectorized loops into serialized loops - * \param stmt The statement to skip vectorization on. - * \return Transformed stmt. - */ -Stmt SkipVectorize(Stmt stmt); - -/*! -* \brief instruments bound checkers. -* \param stmt The statement to be instrumented. -* \return Instrumented stmt. -*/ -Stmt InstrumentBoundCheckers(Stmt stmt); - -/*! - * \brief Inject virtual thread loops into stmt. - * \param stmt The statement to be transformed. - * \return Transformed stmt. - */ -Stmt InjectVirtualThread(Stmt stmt); - /*! * \brief Inject prefetch instructions into stmt. * \param stmt The statement to be transformed. @@ -262,84 +209,6 @@ Stmt InjectVirtualThread(Stmt stmt); */ Stmt InjectPrefetch(Stmt stmt); -/*! - * \brief Inject double buffer into stmt. - * \param stmt The statement to be transformed. - * \param split_loop Loop splitting factor. - * \return Transformed stmt. - */ -Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); - -/*! - * \brief Inject copy intrinsics with optional pad. - * - * \param stmt The statement to be transformed. - * \param pragma_key The pragma key for hint of copy. - * \param fintrin The function with signature - * - * Stmt fintrin(Buffer src, - * Buffer dst, - * Array pad_before, - * Array pad_after, - * Expr pad_value) - * \return Transformed stmt. - */ -Stmt InjectCopyIntrin(Stmt stmt, - const std::string& pragma_key, - const runtime::PackedFunc& fintrin); - -/*! - * \brief Rewrite storage allocation pattern. - * Moves the allocation to outer most possible scope. - * Trying to share space between allocations to make - * a static allocation plan when possible. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt StorageRewrite(Stmt stmt); - -/*! - * \brief partition loops in the stmt - * \param stmt The stmt to do loop partition - * \param split_const_loop flag to enable partition for const loop - * \return Transformed stmt. - */ -Stmt LoopPartition(Stmt stmt, bool split_const_loop); - -/*! - * \brief Detect and insert sync points to co-processor. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt CoProcSync(Stmt stmt); - -/*! - * \brief Lift common attrs with attr_key to outer scope. - * - * \param stmt The stmt to be transformed - * \param attr_key The attribute key to be checked. - * \return Transformed stmt. - */ -Stmt LiftAttrScope(Stmt stmt, std::string attr_key); - -/*! - * \brief Detect and rewrite unsafe select that contains memory access. - * \param stmt The statement to be rewritten. - * \return Transformed stmt. - */ -Stmt RewriteUnsafeSelect(Stmt stmt); - -/*! - * \brief Lower attached storage access information. - * Do this pass after all storage access analysis finish. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt LowerStorageAccessInfo(Stmt stmt); - /*! * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. @@ -356,15 +225,6 @@ Stmt DecorateDeviceScope(Stmt stmt); */ Stmt HoistIfThenElse(Stmt stmt); -/*! - * \brief Narrow down PrimExpr datatype in stmt to target_bits. - * \note Run this pass after StorageFlatten. - * \param stmt The stmt to do datatype rewrite - * \param target_bits the bit of target datatype - * \return Transformed stmt. - */ -Stmt NarrowDataType(Stmt stmt, int target_bits); - /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 23c195563ac2..e593e1bf0fbc 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,6 +58,124 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const std::string& name, const tvm::Array& required); +/*! + * \brief Inject copy intrinsics with optional pad. + * + * \param pragma_key The pragma key for hint of copy. + * \param fintrin The function with signature + * + * Stmt fintrin(Buffer src, + * Buffer dst, + * Array pad_before, + * Array pad_after, + * Expr pad_value) + * \return The pass. + */ +TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, + runtime::PackedFunc fintrin); + +/*! + * \brief Detect and insert sync points to co-processor. + * + * \return The pass. + */ +TVM_DLL Pass CoProcSync(); + +/*! + * \brief Lift common attrs with attr_key to outer scope. + * + * \param attr_key The attribute key to be checked. + * \return The pass. + */ +TVM_DLL Pass LiftAttrScope(std::string attr_key); + +/*! + * \brief partition loops in the stmt. + * + * \param split_const_loop flag to enable partition for const loop + * + * \return The pass. + */ +TVM_DLL Pass LoopPartition(bool split_const_loop); + +/*! + * \brief Lower vectorization loops. + * + * \param enable_vectorize Whether vectorization is enabled. + * + * \return The pass. + */ +TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); + +/*! + * \brief Inject virtual thread loops. + * + * \return The pass. + */ +TVM_DLL Pass InjectVirtualThread(); + +/*! + * \brief Inject double buffer statements. + * + * \param split_loop_factor Loop splitting factor. + * \return The pass. + */ +TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor); + +/*! + * \brief Rewrite storage allocation pattern. + * Moves the allocation to outer most possible scope. + * Trying to share space between allocations to make + * a static allocation plan when possible. + * + * \return The pass. + */ +TVM_DLL Pass StorageRewrite(); + +/*! + * \brief unroll the constant loop marked by unroll. + * This pass also automatically attach pragma unroll tag to loops which meets the standard. + * + * \param auto_max_step The maximum step before stop attach automatic unroll + * \param auto_max_depth The maximum depth before stop attach automatic unroll + * \param auto_max_extent The maximum extent of the loop we can unroll, + * this is an legacy option that do not take the loop total steps into account. + * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. + * \return The pass. + */ +TVM_DLL Pass UnrollLoop(int auto_max_step, + int auto_max_depth, + int auto_max_extent, + bool explicit_unroll); + +/*! + * \brief Remove No Op from the Stmt. + * + * \return The pass. + */ +TVM_DLL Pass RemoveNoOp(); + +/*! + * \brief Detect and rewrite unsafe select that contains memory access. + * + * \return The pass. + */ +TVM_DLL Pass RewriteUnsafeSelect(); + +/*! +* \brief Run arithmetic simplifications on the statements and expressions. +* +* \return The pass. +*/ +TVM_DLL Pass Simplify(); + +/*! +* \brief Instruments bound checkers. +* +* \return The pass. +*/ +TVM_DLL Pass InstrumentBoundCheckers(); + /*! * \brief Transform the high-level PrimFunc to a low-level version * that can be used as an API function. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a429d0775dae..18a8a47ad439 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -179,6 +179,7 @@ def lower(sch, cfg.auto_unroll_max_depth, cfg.auto_unroll_max_extent, cfg.unroll_explicit) + for f in lower_phase2: stmt = f(stmt) @@ -187,11 +188,14 @@ def lower(sch, stmt = ir_pass.RemoveNoOp(stmt) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) + for f in lower_phase3: stmt = f(stmt) + # Instrument BoundCheckers if cfg.instrument_bound_checkers: stmt = ir_pass.InstrumentBoundCheckers(stmt) + if simple_mode: return stmt diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 9f64a93a4860..f83bb11ad51e 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -60,6 +60,203 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") +def InjectCopyIntrin(pragma_key, fintrin): + """Inject virtual thread loops. + + Parameters + ---------- + pragma_key : str + The pragma key for hint of copy. + + fintrin : function + The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) + + +def CoProcSync(): + """Detect and insert sync points to co-processor. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CoProcSync() + + +def LiftAttrScope(attr_key): + """Lift common attrs with attr_key to outer scope. + + Parameters + ---------- + attr_key : str + The attribute key to be checked. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LiftAttrScope(attr_key) + + +def LoopPartition(split_const_loop): + """Inject virtual thread loops. + + Parameters + ---------- + split_const_loop : bool + Flag to enable partition for const loop. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LoopPartition(split_const_loop) + + +def VectorizeLoop(enable_vectorize=True): + """Lower vectorization loops. + + Parameters + ---------- + enable_vectorize : bool + Whether vectorization is enabled. + Will lower to scalar loop when it is turned off. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VectorizeLoop(enable_vectorize) + + +def InjectVirtualThread(): + """Inject virtual thread loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectVirtualThread() + + +def InjectDoubleBuffer(split_loop_factor): + """Inject double buffer statements. + + Parameters + ---------- + split_loop_factor : int + Loop splitting factor. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectDoubleBuffer(split_loop_factor) + + +def StorageRewrite(): + """Rewrite storage allocation pattern. + + Moves the allocation to outer most possible scope. + Trying to share space between allocations to make + a static allocation plan when possible. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageRewrite() + + +def UnrollLoop(auto_max_step, + auto_max_depth, + auto_max_extent, + explicit_unroll): + """Unroll the constant loop marked by unroll. + + This pass also automatically attach pragma unroll tag to loops which meets the standard. + + Parameters + ---------- + auto_max_step : int + The maximum step before stop attach automatic unroll + + auto_max_depth : int + The maximum depth before stop attach automatic unroll + + auto_max_extent : int + The maximum extent of the loop we can unroll. + This is an legacy option that do not take the loop total steps into account. + + explicit_unroll : bool + Whether explicitly unroll the loop, or leave unroll annotation to codegen. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.UnrollLoop( + auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll) + + +def RemoveNoOp(): + """Remove No Op from the Stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RemoveNoOp() + + +def RewriteUnsafeSelect(): + """Detect and rewrite unsafe select that contains memory access. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RewriteUnsafeSelect() + + +def Simplify(): + """Run arithmetic simplifications on the statements and expressions. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.Simplify() + + +def InstrumentBoundCheckers(): + """Instruments bound checkers. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InstrumentBoundCheckers() + + def LowerCustomDatatypes(): """Lower custom datatypes. diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h index adb4f3000a29..f842780bec7c 100644 --- a/src/arith/compute_expr.h +++ b/src/arith/compute_expr.h @@ -25,6 +25,7 @@ #define TVM_ARITH_COMPUTE_EXPR_H_ #include +#include #include #include diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f576c842b25c..e38179e965f5 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -109,64 +109,6 @@ void GetBinds(const Array& args, } } -/*! -* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes. -* \param sch The schedule to build. -* \param args The arguments for the schedule. -* \param binds Buffer assignments. -* \param loop_partition True if the LoopPartition pass should be included. -* \param out_arg_list Returns the arguments for the Stmt. -* \param config The build configuration. -* \return The built Stmt. -*/ -tir::Stmt BuildStmt(te::Schedule sch, - const Array& args, - const std::unordered_map& binds, - bool loop_partition, - Array *out_arg_list, - const BuildConfig& config) { - sch = sch.normalize(); - - // Phase 0 - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - stmt = tir::InjectPrefetch(stmt); - - bool compact = tir::VerifyCompactBuffer(stmt); - Map out_binds; - GetBinds(args, compact, binds, &out_binds, out_arg_list, config); - - // Phase 1 - stmt = tir::StorageFlatten(stmt, out_binds, 64, - config->instrument_bound_checkers); - stmt = tir::CanonicalSimplify(stmt); - if (loop_partition) { - stmt = tir::LoopPartition(stmt, config->partition_const_loop); - } - if (config->disable_vectorize) { - stmt = tir::SkipVectorize(stmt); - } else { - stmt = tir::VectorizeLoop(stmt); - } - stmt = tir::InjectVirtualThread(stmt); - stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); - stmt = tir::StorageRewrite(stmt); - stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, - config->auto_unroll_max_extent, config->unroll_explicit); - - // Phase 2 - stmt = tir::Simplify(stmt); - stmt = tir::RemoveNoOp(stmt); - - if (!(config->disable_select_rewriting)) - stmt = tir::RewriteUnsafeSelect(stmt); - - if (config->instrument_bound_checkers) - stmt = tir::InstrumentBoundCheckers(stmt); - - return stmt; -} - transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -176,7 +118,7 @@ transform::Pass BindTarget(Target target) { template -transform::Pass FilterBy(FCond fcond) { +transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { return f; @@ -184,18 +126,14 @@ transform::Pass FilterBy(FCond fcond) { return tir::PrimFunc(nullptr); } }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {}); + return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } -IRModule lower(te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config) { - Array out_arg_list; - auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); - +IRModule BuildIRModule(const Array& out_arg_list, + tir::Stmt stmt, + const std::string& name, + const BuildConfig& config) { Array params; Map buffer_map; @@ -216,10 +154,64 @@ IRModule lower(te::Schedule sch, if (config->restricted_func) { f = WithAttr(std::move(f), "tir.noalias", Integer(1)); } + return IRModule(Map({{GlobalVar(name), f}})); } +IRModule lower(te::Schedule sch, + const Array& args, + const std::string& name, + const std::unordered_map& binds, + const BuildConfig& config) { + Array out_arg_list; + + sch = sch.normalize(); + + // Phase 0 + auto bounds = te::InferBound(sch); + auto stmt = te::ScheduleOps(sch, bounds, false); + stmt = tir::InjectPrefetch(stmt); + + bool compact = tir::VerifyCompactBuffer(stmt); + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list, config); + + // Phase 1 + stmt = tir::StorageFlatten(stmt, out_binds, 64, + config->instrument_bound_checkers); + + // convert to IRModule. + auto mod = BuildIRModule(out_arg_list, stmt, name, config); + auto pass_list = Array(); + + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop)); + pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize)); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop)); + pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back( + tir::transform::UnrollLoop(config->auto_unroll_max_step, + config->auto_unroll_max_depth, + config->auto_unroll_max_extent, + config->unroll_explicit)); + // Phase 2 + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::RemoveNoOp()); + if (!(config->disable_select_rewriting)) { + pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + } + if (config->instrument_bound_checkers) { + pass_list.push_back(tir::transform::InstrumentBoundCheckers()); + } + // run + auto optimize = transform::Sequential(pass_list); + mod = optimize(std::move(mod)); + return mod; +} + + std::pair split_dev_host_funcs(IRModule mod_mixed, const Target& target, @@ -242,7 +234,7 @@ split_dev_host_funcs(IRModule mod_mixed, mod_mixed = opt_mixed(std::move(mod_mixed)); auto host_pass_list = { - FilterBy([](const tir::PrimFunc& f) { + Filter([](const tir::PrimFunc& f) { return f->GetAttr( tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; @@ -258,7 +250,7 @@ split_dev_host_funcs(IRModule mod_mixed, // device pipeline auto device_pass_list = { - FilterBy([](const tir::PrimFunc& f) { + Filter([](const tir::PrimFunc& f) { return f->GetAttr( tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 3083b6879635..65981b9b62f5 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -114,27 +114,12 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); -REGISTER_PASS(RewriteUnsafeSelect); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); -REGISTER_PASS(VectorizeLoop); -REGISTER_PASS(SkipVectorize); -REGISTER_PASS(UnrollLoop); -REGISTER_PASS(InjectCopyIntrin); -REGISTER_PASS(StorageRewrite); -REGISTER_PASS(CoProcSync); -REGISTER_PASS(LowerStorageAccessInfo); -REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); -REGISTER_PASS(InjectDoubleBuffer); -REGISTER_PASS(LoopPartition); -REGISTER_PASS(RemoveNoOp); -REGISTER_PASS(LiftAttrScope); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); -REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); -REGISTER_PASS(NarrowDataType); } // namespace tir } // namespace tvm diff --git a/src/tir/pass/bound_checker.cc b/src/tir/transforms/bound_checker.cc similarity index 88% rename from src/tir/pass/bound_checker.cc rename to src/tir/transforms/bound_checker.cc index ee24d0f77673..f770bc76941e 100644 --- a/src/tir/pass/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -22,8 +22,11 @@ */ // Instrument checkers for out of the bounds access. +#include +#include #include -#include +#include +#include #include #include #include @@ -173,8 +176,8 @@ class BoundChecker : public StmtExprMutator { } // Try to simplify index and bound. - index = tir::Simplify(index); - upper_bound = tir::Simplify(upper_bound); + index = analyzer_.Simplify(index); + upper_bound = analyzer_.Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. index = CastNode::make(DataType::Int(64), index); @@ -201,6 +204,8 @@ class BoundChecker : public StmtExprMutator { const char *const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. std::unordered_map mem_to_shape_; + // internal analyzer + arith::Analyzer analyzer_; }; Stmt InstrumentBoundCheckers(Stmt stmt) { @@ -209,5 +214,29 @@ Stmt InstrumentBoundCheckers(Stmt stmt) { bound_collector(stmt); return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); } + + +TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers") +.set_body_typed(InstrumentBoundCheckers); + +namespace transform { + +Pass InstrumentBoundCheckers() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + BoundCollector bound_collector; + // At first walk recursively and collect bound attributes. + bound_collector(n->body); + n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") +.set_body_typed(InstrumentBoundCheckers); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc similarity index 97% rename from src/tir/pass/coproc_sync.cc rename to src/tir/transforms/coproc_sync.cc index 38b7798eae11..fc20285a1a22 100644 --- a/src/tir/pass/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -20,13 +20,14 @@ /*! * \file coproc_sync.cc */ +#include +#include #include -#include #include #include #include -#include "ir_util.h" -#include "storage_access.h" +#include "../pass/ir_util.h" +#include "../pass/storage_access.h" namespace tvm { namespace tir { @@ -677,5 +678,24 @@ Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.CoProcSync") +.set_body_typed(CoProcSync); + +namespace transform { + +Pass CoProcSync() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = CoProcSyncInserter().Insert(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.CoProcSync") +.set_body_typed(CoProcSync); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc similarity index 91% rename from src/tir/pass/inject_copy_intrin.cc rename to src/tir/transforms/inject_copy_intrin.cc index 4805caf5ac55..5e40eb2d9025 100644 --- a/src/tir/pass/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -21,10 +21,11 @@ * \brief Replace certain copy with copy intrinsics. * \file copy_intrin_rewrite.cc */ +#include +#include #include #include #include -#include #include "../../arith/pattern_match.h" namespace tvm { @@ -196,5 +197,26 @@ Stmt InjectCopyIntrin(Stmt stmt, return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin") +.set_body_typed(InjectCopyIntrin); + +namespace transform { + +Pass InjectCopyIntrin(std::string pragma_key, + PackedFunc flower_copy_fromto) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = CopyIntrinInjector( + pragma_key, flower_copy_fromto)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin") +.set_body_typed(InjectCopyIntrin); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc similarity index 93% rename from src/tir/pass/inject_double_buffer.cc rename to src/tir/transforms/inject_double_buffer.cc index b9aa5a9e697e..e9422fa2ff3e 100644 --- a/src/tir/pass/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -21,10 +21,12 @@ * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ +#include #include +#include #include #include -#include "ir_util.h" +#include "../pass/ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -273,5 +275,26 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { return DoubleBufferInjector(split_loop).Inject(stmt); } + +TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer") +.set_body_typed(InjectDoubleBuffer); + + +namespace transform { + +Pass InjectDoubleBuffer(int split_loop) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer") +.set_body_typed(InjectDoubleBuffer); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc similarity index 96% rename from src/tir/pass/inject_virtual_thread.cc rename to src/tir/transforms/inject_virtual_thread.cc index e9c403ca5cb5..c70962d8207e 100644 --- a/src/tir/pass/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -20,8 +20,10 @@ /*! * \file inject_virtual_thread.cc */ +#include #include #include +#include #include #include #include "../../arith/compute_expr.h" @@ -500,5 +502,24 @@ Stmt InjectVirtualThread(Stmt stmt) { return ConvertSSA(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread") +.set_body_typed(InjectVirtualThread); + +namespace transform { + +Pass InjectVirtualThread() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body))); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread") +.set_body_typed(InjectVirtualThread); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc similarity index 90% rename from src/tir/pass/lift_attr_scope.cc rename to src/tir/transforms/lift_attr_scope.cc index 9aa037feb460..a1d922394df1 100644 --- a/src/tir/pass/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -23,9 +23,10 @@ * the body contains the same scope. * \file lift_attr_scope.cc */ -#include +#include +#include #include -#include "ir_util.h" +#include "../pass/ir_util.h" namespace tvm { namespace tir { @@ -191,5 +192,24 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope") +.set_body_typed(LiftAttrScope); + +namespace transform { + +Pass LiftAttrScope(std::string attr_key) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope") +.set_body_typed(LiftAttrScope); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/loop_partition.cc b/src/tir/transforms/loop_partition.cc similarity index 96% rename from src/tir/pass/loop_partition.cc rename to src/tir/transforms/loop_partition.cc index e9157e796e38..dbed5f2abd86 100644 --- a/src/tir/pass/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -20,9 +20,11 @@ /*! * \file loop_partition.cc */ +#include #include -#include #include +#include +#include #include #include #include @@ -500,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { - body_begin = tir::Simplify(middle_interval.min()); + body_begin = analyzer_.Simplify(middle_interval.min()); if (!analyzer_.CanProve(body_begin == min)) { PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { @@ -525,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { - post_doubt_begin = tir::Simplify(middle_interval.max() + 1); + post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1); if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative PrimExpr cond = (max - post_doubt_begin + 1 >= 0); @@ -588,7 +590,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, - for_node->for_type, for_node->device_api, body); + for_node->for_type, for_node->device_api, body); } } @@ -610,5 +612,25 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) { return stmt; } + +TVM_REGISTER_GLOBAL("ir_pass.LoopPartition") +.set_body_typed(LoopPartition); + +namespace transform { + +Pass LoopPartition(bool split_const_loop) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = LoopPartition(std::move(n->body), split_const_loop); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LoopPartition") +.set_body_typed(LoopPartition); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index e7f81ed929b9..9fa72303e2d8 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -143,6 +143,8 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo") +.set_body_typed(LowerStorageAccessInfo); namespace transform { diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 1f9d976c407d..4aeaafda48ba 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -395,6 +395,10 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } +TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType") +.set_body_typed(NarrowDataType); + + namespace transform { Pass NarrowDataType(int target_bits) { diff --git a/src/tir/pass/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc similarity index 89% rename from src/tir/pass/remove_no_op.cc rename to src/tir/transforms/remove_no_op.cc index 181a8c483e4e..44c974fc0fb0 100644 --- a/src/tir/pass/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -21,8 +21,11 @@ * \file remove_no_op.cc * \brief Remove no op from the stmt */ +#include #include #include +#include +#include #include #include @@ -147,5 +150,25 @@ class NoOpRemover : public StmtMutator { Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } + +TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp") +.set_body_typed(RemoveNoOp); + +namespace transform { + +Pass RemoveNoOp() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = NoOpRemover()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp") +.set_body_typed(RemoveNoOp); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc similarity index 89% rename from src/tir/pass/rewrite_unsafe_select.cc rename to src/tir/transforms/rewrite_unsafe_select.cc index 501649237090..386b4cc66ed6 100644 --- a/src/tir/pass/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -21,9 +21,10 @@ * \file unsafe_select_rewrite.cc * \brief Rewrite uinsafe select expression. */ +#include #include #include -#include +#include namespace tvm { namespace tir { @@ -132,5 +133,24 @@ Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect") +.set_body_typed(RewriteUnsafeSelect); + +namespace transform { + +Pass RewriteUnsafeSelect() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = UnsafeSelectRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect") +.set_body_typed(RewriteUnsafeSelect); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/arith/stmt_simplify.cc b/src/tir/transforms/simplify.cc similarity index 86% rename from src/arith/stmt_simplify.cc rename to src/tir/transforms/simplify.cc index 6c3dd022565c..ecfa25e28975 100644 --- a/src/arith/stmt_simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -18,17 +18,19 @@ */ /*! - * \file stmt_simplify.cc + * \file simplify.cc * \brief Statement simplifier based on analyzer */ +#include #include #include +#include #include #include #include #include -#include "ir_mutator_with_analyzer.h" +#include "../../arith/ir_mutator_with_analyzer.h" namespace tvm { namespace arith { @@ -125,5 +127,23 @@ PrimExpr Simplify(PrimExpr expr, Map vrange) { Stmt Simplify(Stmt stmt, Map vrange) { return CanonicalSimplify(std::move(stmt), vrange); } + +namespace transform { + +Pass Simplify() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.Simplify") +.set_body_typed(Simplify); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc similarity index 98% rename from src/tir/pass/storage_rewrite.cc rename to src/tir/transforms/storage_rewrite.cc index f3604b640349..c13879c31c64 100644 --- a/src/tir/pass/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -22,16 +22,18 @@ * \brief Memory access pattern analysis and optimization. * Re-write data access to enable memory sharing when possible. */ +#include #include #include #include +#include #include #include #include #include #include #include -#include "ir_util.h" +#include "../pass/ir_util.h" #include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -1039,5 +1041,26 @@ Stmt StorageRewrite(Stmt stmt) { stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); return VectorAllocRewriter()(std::move(stmt)); } + +TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite") +.set_body_typed(StorageRewrite); + +namespace transform { + +Pass StorageRewrite() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); + n->body = VectorAllocRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite") +.set_body_typed(StorageRewrite); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc similarity index 88% rename from src/tir/pass/unroll_loop.cc rename to src/tir/transforms/unroll_loop.cc index 0167dbcec5f2..27c39d4c18aa 100644 --- a/src/tir/pass/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -22,8 +22,11 @@ * \file unroll_loop.cc */ // Unrolls the loop as in Halide pipeline. +#include #include +#include #include +#include #include #include #include @@ -201,13 +204,31 @@ Stmt UnrollLoop(Stmt stmt, } } -Stmt UnrollLoopExplicitly(Stmt stmt) { - const ForNode* op = stmt.as(); - if (!op) { - LOG(FATAL) << "attempted to unroll a non-loop statement"; - } - return LoopUnroller(0, 0, 0, false).Unroll(op); +TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop") +.set_body_typed(UnrollLoop); + +namespace transform { + +Pass UnrollLoop(int auto_max_step, + int auto_max_depth, + int auto_max_extent, + bool explicit_unroll) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = UnrollLoop(std::move(f->body), + auto_max_step, + auto_max_depth, + auto_max_extent, + explicit_unroll); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } +TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop") +.set_body_typed(UnrollLoop); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc similarity index 95% rename from src/tir/pass/vectorize_loop.cc rename to src/tir/transforms/vectorize_loop.cc index b73587db2ab6..cc4361dc3ad1 100644 --- a/src/tir/pass/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -21,9 +21,11 @@ * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. +#include #include -#include +#include #include +#include #include #include #include @@ -539,8 +541,9 @@ class VectorizeSkipper : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { - return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, - op->body); + return ForNode::make(op->loop_var, op->min, op->extent, + ForType::Serial, op->device_api, + op->body); } else { return stmt; } @@ -551,5 +554,32 @@ Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } +TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop") +.set_body_typed(VectorizeLoop); + +TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize") +.set_body_typed(SkipVectorize); + +namespace transform { + +// TODO(tvm-team): Make it as a target property. +Pass VectorizeLoop(bool enable_vectorize) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + if (enable_vectorize) { + n->body = LoopVectorizer()(std::move(n->body)); + } else { + n->body = VectorizeSkipper()(std::move(n->body)); + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop") +.set_body_typed(VectorizeLoop); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_pass_virtual_thread.py b/tests/python/unittest/test_tir_pass_virtual_thread.py deleted file mode 100644 index 2d96696eed88..000000000000 --- a/tests/python/unittest/test_tir_pass_virtual_thread.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - -def test_virtual_thread(): - m = te.var('m') - A = te.placeholder((m, ), name='A') - A1 = te.compute((m,), lambda i: A[i], name='A1') - A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2') - - s = te.create_schedule(A2.op) - vx = te.thread_axis("vthread", name="vx") - xo, xi = s[A2].split(A2.op.axis[0], nparts=2) - s[A2].bind(xo, vx) - xo, xi = s[A2].split(xi, 8) - s[A1].compute_at(s[A2], xo) - - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt) - print(stmt) - -if __name__ == "__main__": - test_virtual_thread() diff --git a/tests/python/unittest/test_tir_pass_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py similarity index 91% rename from tests/python/unittest/test_tir_pass_coproc_sync.py rename to tests/python/unittest/test_tir_transform_coproc_sync.py index b0e2050e2ee9..f6583493d646 100644 --- a/tests/python/unittest/test_tir_pass_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -37,7 +37,10 @@ def meminfo_cache(): ib.scope_attr(cp, "coproc_scope", 1) A[j] = A[j + k * 10] + 2 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + body = stmt.body.body.body blist = tvm.tir.stmt_list(body) assert(blist[1].value.name == "cop.coproc_read_barrier") @@ -65,7 +68,10 @@ def test_coproc_sync2(): ib.scope_attr(cp, "coproc_scope", 2) A[ty] = 1.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + def test_coproc_sync3(): def __check_list(tvm_array, py_list): @@ -91,7 +97,10 @@ def __check_list(tvm_array, py_list): A[0] = 0.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + slist = tvm.tir.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) diff --git a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py similarity index 89% rename from tests/python/unittest/test_tir_pass_inject_copy_intrin.py rename to tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 8c34e344d73e..7ec2e48b4fe4 100644 --- a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -35,7 +35,10 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert src.strides[0] == l assert tuple(src.shape) == (m, l) return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def test_copy_pad(): m = te.var('m') @@ -59,7 +62,10 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert pad_after[1].value == 0 assert pad_value.value == 1.0 return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def test_single_point_test(): A = te.placeholder((1,), name='A') @@ -78,7 +84,10 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1 assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def assert_expr_equal(a, b): assert tvm.tir.ir_pass.Simplify(a - b).value == 0 @@ -111,7 +120,11 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert_expr_equal(pad_after[0], rpad_after) assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + + if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py similarity index 91% rename from tests/python/unittest/test_tir_pass_inject_double_buffer.py rename to tests/python/unittest/test_tir_transform_inject_double_buffer.py index 6b04db30f6d5..4c0573da7616 100644 --- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -36,13 +36,19 @@ def test_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 mod = tvm.IRModule({ "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt) }) + + opt = tvm.transform.Sequential( + [tvm.tir.transform.InjectDoubleBuffer(2), + tvm.tir.transform.Simplify()]) + mod = opt(mod) + stmt = mod["db"].body + + assert isinstance(stmt.body.body, tvm.tir.Allocate) + assert stmt.body.body.extents[0].value == 2 + f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): diff --git a/tests/python/unittest/test_tir_pass_inject_vthread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py similarity index 83% rename from tests/python/unittest/test_tir_pass_inject_vthread.py rename to tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 8fbd8295d238..c0789c654fbf 100644 --- a/tests/python/unittest/test_tir_pass_inject_vthread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -40,9 +40,14 @@ def get_vthread(name): C[i * nthread + tx] = B[i] + 1 return ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread")) + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body + assert stmt.body.body.extents[0].value == 2 - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread")) + + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body + assert len(stmt.body.body.extents) == 3 @@ -67,16 +72,20 @@ def get_vthread(name): A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit(tvm.tir.call_extern("int32", "Run", - abuffer.access_ptr("r"), - bbuffer.access_ptr("r"), - cbuffer.access_ptr("rw"))) + abuffer.access_ptr("r"), + bbuffer.access_ptr("r"), + cbuffer.access_ptr("rw"))) return ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread")) + + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body + assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.body.body.body.body.extents[0].value == 2 assert len(stmt.body.body.body.body.body.body.extents) == 3 + def test_vthread_if_then_else(): nthread = 2 tx = te.thread_axis("vthread") @@ -92,7 +101,10 @@ def test_vthread_if_then_else(): with ib.if_scope(i == 0): B[i] = A[i * nthread + tx] + 2 stmt = ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt) + + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], stmt)))["main"].body + assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_pass_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py similarity index 94% rename from tests/python/unittest/test_tir_pass_bound_checkers.py rename to tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index d6c89b2ab878..47c1f7bf1159 100644 --- a/tests/python/unittest/test_tir_pass_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -18,32 +18,12 @@ import tvm from tvm import te import numpy as np + def collect_visit(stmt, f): ret = [] tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) return ret -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) - stmt = tvm.tir.ir_pass.RemoveNoOp(stmt) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - return stmt @pytest.mark.xfail def test_out_of_bounds_llvm(index_a, index_b): @@ -72,7 +52,6 @@ def test_in_bounds_llvm(): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower (s, [A, B, C], simple_mode=True) - print (stmt) fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd") ctx = tvm.context(tgt, 0) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) @@ -93,7 +72,6 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower (s, [a, b, c], simple_mode=True) - print (stmt) f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec") ctx = tvm.cpu(0) n = nn @@ -192,13 +170,11 @@ def collect_branch_stmt (x): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) - bounds = tvm.te.schedule.InferBound(s) - stmt = lower (s, [A, B, T]) - # num_attributes = num_buffers * num_splits = 2 * 3 - # before instrumentation - assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) - assert_bound_instrumentation(stmt, check_branch_stmt, 0) - stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt) + with tvm.target.build_config(instrument_bound_checkers=True, + partition_const_loop=True): + mod = tvm.driver.lower(s, [A, B, T], name="main") + + stmt = mod["main"].body # after instrumentation assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) assert_bound_instrumentation(stmt, check_branch_stmt, 2) @@ -209,7 +185,8 @@ def collect_branch_stmt (x): def test_in_bounds_const_loop_partition_llvm(): - with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True): + with tvm.target.build_config(instrument_bound_checkers=True, + partition_const_loop=True): n = 21 A = te.placeholder((n, ), name='A') B = te.placeholder((n, ), name='B') diff --git a/tests/python/unittest/test_tir_pass_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py similarity index 88% rename from tests/python/unittest/test_tir_pass_lift_attr_scope.py rename to tests/python/unittest/test_tir_transform_lift_attr_scope.py index 0831565dc155..f5f4030d1b23 100644 --- a/tests/python/unittest/test_tir_pass_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -35,7 +35,10 @@ def test_coproc_lift(): A[j] = A[j] + 3 A[j] = A[j] + 3 body = ib.get() - body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + assert body.body.body.node == cp # only able to lift to the common pattern of the last two fors. @@ -52,7 +55,10 @@ def test_coproc_lift(): A[i] = A[i] + 2 body = ib.get() - body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py similarity index 79% rename from tests/python/unittest/test_tir_pass_loop_partition.py rename to tests/python/unittest/test_tir_transform_loop_partition.py index 1256d8bbd4fc..6ca3f596f196 100644 --- a/tests/python/unittest/test_tir_pass_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -23,26 +23,6 @@ def collect_visit(stmt, f): tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x))) return ret -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - return stmt def test_basic(): n = te.size_var('n') @@ -55,10 +35,16 @@ def test_basic(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) - assert('if' in str(stmt.body.body[1])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + mod = tvm.tir.transform.LoopPartition(False)(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + assert(any( + collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_const_loop(): n = 21 @@ -71,9 +57,12 @@ def test_const_loop(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.LoopPartition(True)(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))) def test_multi_loop(): ib = tvm.tir.ir_builder.create() @@ -87,8 +76,11 @@ def test_multi_loop(): with ib.else_scope(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt)) + mod = tvm.tir.transform.LoopPartition(False)(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) def test_multi_if(): @@ -107,9 +99,14 @@ def test_multi_if(): with ib.else_scope(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.LoopPartition(False)(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_thread_axis(): m = te.size_var('m') @@ -126,9 +123,14 @@ def test_thread_axis(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.LoopPartition(False)(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body. body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_vectorize(): n = te.size_var('n') @@ -147,11 +149,12 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = lower(s, [A, B]) + stmt = tvm.lower(s, [A, B], name="main")["main"].body body = stmt.body.body.body.body assert(x.var.name not in str(body.condition)) assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))) + def test_condition(): ib = tvm.tir.ir_builder.create() m = te.size_var('m') @@ -161,10 +164,14 @@ def test_condition(): ib.emit(tvm.tir.Evaluate( tvm.tir.Select(ib.likely(i*4+j 1, A[i-1], 1.0) - yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value + + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y))) + yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value z = tvm.tir.Select( tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) - zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z))) + zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value + + a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z) - a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z) - aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a))) + aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value assert yy.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else" assert isinstance(aa, tvm.tir.Select) diff --git a/tests/python/unittest/test_arith_stmt_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py similarity index 93% rename from tests/python/unittest/test_arith_stmt_simplify.py rename to tests/python/unittest/test_tir_transform_simplify.py index 45f083342410..bf5398245c50 100644 --- a/tests/python/unittest/test_arith_stmt_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -27,7 +27,9 @@ def test_stmt_simplify(): A[i] = C[i] body = tvm.tir.LetStmt(n, 10, ib.get()) - body = tvm.tir.ir_pass.CanonicalSimplify(body) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body assert isinstance(body.body, tvm.tir.Store) @@ -44,7 +46,9 @@ def test_thread_extent_simplify(): with ib.if_scope(tx + ty < 12): A[tx] = C[tx + ty] body = tvm.tir.LetStmt(n, 10, ib.get()) - body = tvm.tir.ir_pass.CanonicalSimplify(body) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body assert isinstance(body.body.body.body, tvm.tir.Store) diff --git a/tests/python/unittest/test_tir_pass_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py similarity index 89% rename from tests/python/unittest/test_tir_pass_storage_rewrite.py rename to tests/python/unittest/test_tir_transform_storage_rewrite.py index b36d86b47af8..e4e1b3102d4a 100644 --- a/tests/python/unittest/test_tir_pass_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -33,9 +33,12 @@ def test_storage_share(): Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] @@ -72,7 +75,10 @@ def test_alloc_seq(): A[j] = 1.3 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): @@ -129,7 +135,10 @@ def verify(n): body = stmt_generater(dtype_list, length) offset = offset_generater(dtype_list, length) - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + tvm.tir.ir_pass.PostOrderVisit(body, verify) length = 1024 @@ -160,9 +169,12 @@ def test_inplace_rule(): Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] @@ -192,9 +204,12 @@ def test_storage_combine(): Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): @@ -226,9 +241,12 @@ def test_storage_share_gpu(): Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A') Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + alloc_stats = {"global": 0, "shared": 0} def verify(n): @@ -248,7 +266,9 @@ def test_parallel_alloc(): A[j] = A[j] + 2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + assert (isinstance(body.body.body, tvm.tir.Allocate)) ib = tvm.tir.ir_builder.create() @@ -262,7 +282,9 @@ def test_parallel_alloc(): A = ib.allocate("float32", n, name="A", scope="global") A[j] = A[j] + 2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body assert(isinstance(body.body.body.body.body, tvm.tir.Allocate)) @@ -289,9 +311,12 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C') Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] @@ -381,10 +406,13 @@ def test_inplace_rule3(): B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5') Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a, B5a, Bb], stmt)) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works def verify(n): @@ -411,7 +439,10 @@ def test_alloc_seq_type(): A2[j] = A[j] body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): @@ -440,7 +471,10 @@ def test_alloc_seq_type2(): C[j] = 1.2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): @@ -469,7 +503,9 @@ def test_reuse_small_buffer(): E[j] = C[j] body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body num_alloc = [0] @@ -519,14 +555,15 @@ def verify(n): if __name__ == "__main__": + test_storage_share() test_alloc_seq() test_alloc_different_dtypes() test_inplace_rule() - test_storage_share() test_parallel_alloc() test_storage_combine() test_storage_share_gpu() test_inplace_rule2() + test_exceed_mem() test_inplace_rule3() test_alloc_seq_type() diff --git a/tests/python/unittest/test_tir_pass_unroll.py b/tests/python/unittest/test_tir_transform_unroll_loop.py similarity index 84% rename from tests/python/unittest/test_tir_pass_unroll.py rename to tests/python/unittest/test_tir_transform_unroll_loop.py index 165edab55f4e..7854835a21d6 100644 --- a/tests/python/unittest/test_tir_pass_unroll.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -46,7 +46,11 @@ def test_unroll_loop(): wrapped = ib.get() wrapped = tvm.tir.SeqStmt([wrapped, stmt]) assert isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped)) + ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body + + # ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) assert isinstance(ret[0], tvm.tir.For) assert ret[0].for_type == tvm.tir.For.Unrolled assert isinstance(ret[1], tvm.tir.For) @@ -65,7 +69,11 @@ def test_unroll_fake_loop(): Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) + ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body + + # ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) assert isinstance(ret[0], tvm.tir.Store) def test_unroll_single_count_loops(): @@ -78,8 +86,10 @@ def test_unroll_single_count_loops(): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # all parameters to UnrolLoops are default values except for # auto_unroll_max_extent which has been set to 1 (default:0) - after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True) - assert after_unroll_stmt == stmt + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body + + assert ret == stmt if __name__ == "__main__": test_unroll_loop() diff --git a/tests/python/unittest/test_tir_pass_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py similarity index 82% rename from tests/python/unittest/test_tir_pass_vectorize.py rename to tests/python/unittest/test_tir_transform_vectorize.py index 2ade843361c0..d7124b6b7e89 100644 --- a/tests/python/unittest/test_tir_pass_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -28,12 +28,16 @@ def test_vectorize_loop(): stmt = ib.get() assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.index, tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) + def test_vectorize_vector(): dtype = 'int64' n = te.var('n') @@ -44,7 +48,10 @@ def test_vectorize_vector(): A[j] = tvm.tir.const(1, A.dtype) stmt = ib.get() assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.index, tvm.tir.Ramp) @@ -63,13 +70,17 @@ def test_vectorize_with_if(): with ib.if_scope(i < n): A[i] = 2.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.IfThenElse) assert isinstance(stmt.then_case.index, tvm.tir.Ramp) assert isinstance(stmt.then_case.value, tvm.tir.Add) assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.tir.For) + def test_vectorize_with_le_cond(): n = te.var('n') ib = tvm.tir.ir_builder.create() @@ -78,9 +89,13 @@ def test_vectorize_with_le_cond(): with ib.if_scope(i <= n): A[i] = A[i] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) + def test_vectorize_with_ge_cond(): n = te.var('n') ib = tvm.tir.ir_builder.create() @@ -89,9 +104,13 @@ def test_vectorize_with_ge_cond(): with ib.if_scope(i >= n): A[i] = A[i] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) + def test_vectorize_if_then_else(): n = te.var('n') x = te.var('x') @@ -102,7 +121,10 @@ def test_vectorize_if_then_else(): i > 0, A[i] + 1, A[i]) stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) @@ -114,8 +136,12 @@ def test_vectorize_if_then_else(): k > 0, A[k * 4 + i], 0) stmt = ib.get() + assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)