diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index e74b9ea26484..b76d41326ff1 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); +/*! + * \brief Check whether a block has an affine binding under the given high_exclusive sref node, + * throw an exception if the block does not have such a partial affine binding. + * \param self The schedule state + * \param block The block to be checked + * \param high_exclusive The highest (exclusive) sref node under which the binding is checked; + * NullOpt means the binding must be affine with respect to the whole sref tree + * \throw ScheduleError If the input block does not have a partial affine binding under high_exclusive + */ +void CheckPartialAffineBinding(const ScheduleState& self, Block block, + const Optional<StmtSRef>& high_exclusive); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 435870471f29..4a7ac401dd60 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va return true; } -void CheckAffineBinding(const ScheduleState& self, Block block) { +void CheckPartialAffineBinding(const ScheduleState& self, Block block, +                               const Optional<StmtSRef>& high_exclusive) { class NotAffineBindingError : public ScheduleError { public: - explicit NotAffineBindingError(IRModule mod, Block block) - : mod_(std::move(mod)), block_(std::move(block)) {} + explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive) + : mod_(std::move(mod)), block_(std::move(block)) { + if (high_exclusive.defined()) { + high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>(); + } + } String FastErrorString() const final { - return "ScheduleError: The block is required to have an affine binding"; + std::ostringstream ss; + 
if (high_exclusive_loop_) { + ss << "ScheduleError: The block is required to have a partial affine binding under " + << high_exclusive_loop_->loop_var; + } else { + ss << "ScheduleError: The block is required to have an affine binding"; + } + return ss.str(); } String DetailRenderTemplate() const final { - return "The block {0} is required to have an affine binding"; + std::ostringstream ss; + if (high_exclusive_loop_) { + ss << "The block {0} is required to have a partial affine binding under " + << high_exclusive_loop_->loop_var; + } else { + ss << "The block {0} is required to have an affine binding"; + } + return ss.str(); } IRModule mod() const final { return mod_; } Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; + const ForNode* high_exclusive_loop_{nullptr}; }; - if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) { - throw NotAffineBindingError(self->mod, std::move(block)); + StmtSRef block_sref = self->stmt2ref.at(block.get()); + if (self->IsAffineBlockBinding(block_sref)) { + // check block cached state for global affineness + return; + } + if (block_sref->parent && high_exclusive.defined()) { + // if it is not a global affine binding, check affineness under high_exclusive + arith::Analyzer analyzer; + Map<Var, Range> dom_map = + LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive); + if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) { + return; + } } + throw NotAffineBindingError(self->mod, std::move(block), high_exclusive); +} + +void CheckAffineBinding(const ScheduleState& self, Block block) { + CheckPartialAffineBinding(self, std::move(block), NullOpt); } Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index fa2a4469b8c9..d64a72ed3401 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ 
b/src/tir/schedule/primitive/loop_transformation.cc @@ -134,16 +134,18 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { class BlockPropertyError : public ScheduleError { public: /*! - * \brief Check that all the blocks under the specific stmt have affine bindings and only have - * data-parallel or reduction block iters + * \brief Check that all the blocks under the specific stmt have affine bindings + * with respect to the top loop sref and only have data-parallel or reduction block iters * \param self The state of the schedule * \param sref The sref to the specific stmt */ - static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, + static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top, const StmtSRefNode* sref) { class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { public: - explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {} + explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state, + const StmtSRefNode* top) + : state_(state), top_(top) {} private: void VisitStmt_(const BlockNode* op) final { @@ -151,13 +153,16 @@ class BlockPropertyError : public ScheduleError { if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { throw BlockPropertyError(state_->mod, GetRef<Block>(op)); } - CheckAffineBinding(state_, GetRef<Block>(op)); + Optional<StmtSRef> high_exclusive = + top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt); + CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive); } } const ScheduleState& state_; + const StmtSRefNode* top_; }; - BlockIterTypeAndAffineBindingChecker checker(self); + BlockIterTypeAndAffineBindingChecker checker(self, top); checker(GetRef<Stmt>(sref->stmt)); } @@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { // Step 3. Collect all loops in the chain and check the loops are single-branch std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom); // Step 4. 
Check the block below has all its block_var to be data-parallel or reduction, - // and the block has an affine binding. - BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom); + // and the block has an affine binding wrt top of the loop range. + BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom); // Step 5. Replace the original loops with the reordered loops and check that outer loop is // not dependent on inner loop For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs); diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index f62a316f8013..462099e6fe15 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -213,6 +213,95 @@ def test_reorder_with_opaque_access(): verify_trace_roundtrip(sch=sch, mod=opaque_access) +def test_reorder_with_partial_affineness(): + @T.prim_func + def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + # example to write first axis multiple times + for v0, v1, v2 in T.grid(6, 4, 4): + with T.block("block"): + i = T.axis.spatial(14, v0 * 2 + v1) + j = T.axis.spatial(4, v2) + B[i, j] = A[i, j] + 1.0 + + @T.prim_func + def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]): + # example to write first axis multiple times + for v0, v2, v1 in T.grid(6, 4, 4): + with T.block("block"): + i = T.axis.spatial(14, v0 * 2 + v1) + j = T.axis.spatial(4, v2) + B[i, j] = A[i, j] + 1.0 + + sch = tir.Schedule(non_affine_func, debug_mask="all") + v0, v1, v2 = sch.get_loops(sch.get_block("block")) + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder(v0, v2, v1) + + sch.reorder(v2, v1) + tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=non_affine_func) + + +def test_reorder_with_cascade_tiled_ops(): + @T.prim_func 
+ def cascade_pool_ops( + x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"] + ) -> None: + y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32") + for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3): + with T.block("pool_0"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw]) + with T.init(): + y1[ax0, ax1, ax2, ax3] = 0.0 + y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1] + for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3): + with T.block("pool_1"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw]) + with T.init(): + y2[ax0, ax1, ax2, ax3] = 0.0 + y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1] + + @T.prim_func + def cascade_pool_ops_tile_reordered( + x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"] + ) -> None: + y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32") + for n, c, h_o in T.grid(1, 16, 27): + for w, h_i, kh, kw in T.grid(110, 6, 3, 3): + with T.block("pool_0"): + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(16, c) + ax2 = T.axis.spatial(110, h_o * 4 + h_i) + ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw]) + with T.init(): + y1[ax0, ax1, ax2, ax3] = 0.0 + y1[ax0, ax1, ax2, ax3] = ( + y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1] + ) + for h_i, w, kh, kw in T.grid(4, 108, 3, 3): + with T.block("pool_1"): + ax0 = T.axis.spatial(1, 0) + ax1 = T.axis.spatial(16, c) + ax2 = T.axis.spatial(108, h_o * 4 + h_i) + ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw]) + with T.init(): + y2[ax0, ax1, ax2, ax3] = 0.0 + y2[ax0, ax1, ax2, ax3] = ( + y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1] + ) + + sch = tvm.tir.schedule.Schedule(cascade_pool_ops) + pool_0 = sch.get_block("pool_0") + pool_1 = sch.get_block("pool_1") + _, _, h, w, _, _ = sch.get_loops(pool_1) + ho, _ = sch.split(h, factors=[None, 4]) + sch.compute_at(pool_0, 
ho) + _, _, _, h_i, w, _, _ = sch.get_loops(pool_0) + sch.reorder(w, h_i) + tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True) + verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops) + + def test_reorder_with_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B")