diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 570560c62d8c..162ad84f26f6 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -328,6 +328,12 @@ class ScheduleNode : public runtime::Object { * \param ordered_loop_rvs The loops in the new order */ virtual void Reorder(const Array& ordered_loop_rvs) = 0; + /*! + * \brief Reorder the itervars inside a block. + * \param block_rv The block to be transformed. + * \param new_order The new itervar order. + */ + virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) = 0; /*! * \brief Create a new unit loop on top of the specific block. * \param block_rv The block above which the new loop is created diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 68f0b9454cb1..5527905141fc 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -752,6 +752,19 @@ def after_reorder(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member + @type_checked + def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None: + """Reorder the itervars inside a given block. + + Parameters + ---------- + block : BlockRV + The block to be transformed. + new_order : List[int] + The new block itervar order. + """ + _ffi_api.ScheduleReorderBlockIterVar(self, block, new_order) # type: ignore # pylint: disable=no-member + @type_checked def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV: """Create a new unit loop on top of the specific block or loop. diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 93ea38169d74..912fdcf5e7d8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -480,6 +480,14 @@ void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { this->state_->DebugVerify(); } +void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, + const Array new_order) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); + TVM_TIR_SCHEDULE_END("reorder_block_iter_var", this->error_render_level_); + this->state_->DebugVerify(); +} + LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { LoopRV result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 227288b232d9..4cbbc0d6854a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -104,6 +104,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array Split(const LoopRV& loop_rv, const Array>& factors, bool preserve_unit_iters) override; void Reorder(const Array& ordered_loop_rvs) override; + void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 09185498e143..283856bddc84 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -189,6 +189,15 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, */ TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); +/*! + * \brief Reorder itervars inside a block. + * \param self The state of the schedule. + * \param block_sref The sref of block to be transformed. + * \param new_order The new itervar order. + */ +TVM_DLL void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, + const Array& new_order); + /*! * \brief Create a new unit loop on top of the specific block or loop. * \param sref The block/loop above which the new thread_binding loop is created diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/tir/schedule/primitive/reorder_block_iter_var.cc new file mode 100644 index 000000000000..c7967a3ee904 --- /dev/null +++ b/src/tir/schedule/primitive/reorder_block_iter_var.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief The reorder index is not a valid permutation of + * [0, 1, ..., n-1] where n is the number of block iter vars. + */ +class InvalidReorderIndex : public ScheduleError { + public: + explicit InvalidReorderIndex(IRModule mod, Block block, Array new_order) + : mod_(mod), block_(block), new_order_(new_order) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The specified reorder indices are invalid."; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The user provided block itervar index order " << new_order_ + << " is not a valid permutation of [0, 1, ..., num_block_iter_vars-1] in block {0}."; + return String(os.str()); + } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + Array new_order_; +}; + +class BlockIterVarRewriter : public StmtMutator { + public: + Map block_map; + explicit BlockIterVarRewriter(const BlockNode* block_n, std::vector order) + : order_(std::move(order)), block_to_rewrite(block_n) {} + + private: + std::vector order_; + const BlockNode* block_to_rewrite; + Stmt VisitStmt_(const BlockRealizeNode* op) final { + if (op->block.get() == block_to_rewrite) { + auto block_n = CopyOnWrite(op->block.get()); + Block block = op->block; + Array new_iter_vars; + Array new_iter_values; + for (int idx : order_) { + new_iter_vars.push_back(block->iter_vars[idx]); + new_iter_values.push_back(op->iter_values[idx]); + } + block_n->iter_vars = new_iter_vars; + Block new_block(block_n); + block_map.Set(block, new_block); + auto block_realize_n = CopyOnWrite(op); + block_realize_n->block = new_block; + block_realize_n->iter_values = new_iter_values; + return BlockRealize(block_realize_n); + } else { + return StmtMutator::VisitStmt_(op); + } + } +}; + +void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, + const Array& new_order) { + const BlockNode* block_n = TVM_SREF_TO_BLOCK(block_sref); + std::vector new_order_vec; + for (const Integer& x : new_order) { + new_order_vec.push_back(x->value); + } + // check whether new_order is valid or not; + size_t num_block_itervars = block_n->iter_vars.size(); + std::set ind_set(new_order_vec.begin(), new_order_vec.end()); + bool is_full = new_order_vec.size() == num_block_itervars; + bool is_unique = (ind_set.size() == new_order_vec.size()); + bool is_within_boundary = std::all_of(new_order_vec.begin(), new_order_vec.end(), [&](int x) { + return x >= 0 && x < static_cast(num_block_itervars); + }); + if (!is_full || !is_unique || !is_within_boundary) { + throw InvalidReorderIndex(self->mod, GetRef(block_n), new_order); + } + + // find parent block + const BlockNode* parent_block_n = nullptr; + const StmtSRefNode* p = block_sref.get()->parent; + while (p != nullptr) { + if (p->stmt->IsInstance()) { + parent_block_n = TVM_SREF_TO_BLOCK(GetRef(p)); + break; + } + p = p->parent; + } + const StmtSRef parent_block_sref = GetRef(p); + const Block& parent_block = GetRef(parent_block_n); + + // rewrite block and blockrealize + BlockIterVarRewriter rewriter(block_n, std::move(new_order_vec)); + Block new_parent_block = Downcast(rewriter(parent_block)); + rewriter.block_map.Set(parent_block, new_parent_block); + self->Replace(parent_block_sref, new_parent_block, rewriter.block_map); +} + +struct ReorderBlockIterVarTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReorderBlockIterVar"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array new_order) { + sch->ReorderBlockIterVar(block, new_order); + } + + static String UnpackedAsPython(Array outputs, String block, Array new_order) { + PythonAPICall py("reorder_block_iter_var"); + py.Input("block", block); + py.Input("new_order", new_order); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReorderBlockIterVarTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a0e39b74d31b..0e6b77a6a8d9 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -157,6 +157,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&Sche TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") .set_body_method(&ScheduleNode::Reorder); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") + .set_body_method(&ScheduleNode::ReorderBlockIterVar); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { if (const auto* loop_rv = rv.as()) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 2b6a7f71d4f5..b87106e527fc 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -217,6 +217,15 @@ void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { /*outputs=*/{})); } +void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, + const Array new_order) { + ConcreteScheduleNode::ReorderBlockIterVar(block_rv, new_order); + static const InstructionKind& kind = InstructionKind::Get("ReorderBlockIterVar"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, new_order}, /*attrs=*/{}, + /*outputs=*/{})); +} + LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) { LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 8b9621c749de..0217f191e8a7 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -64,6 +64,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, bool preserve_unit_iters) final; void Reorder(const Array& ordered_loop_rvs) final; + void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; LoopRV AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ diff --git a/tests/python/unittest/test_tir_reorder_block_iter_var.py b/tests/python/unittest/test_tir_reorder_block_iter_var.py new file mode 100644 index 000000000000..99e07aa525f9 --- /dev/null +++ b/tests/python/unittest/test_tir_reorder_block_iter_var.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def matmul( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), +) -> None: + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def matmul_after_reorder_block_iter_var( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), +): + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vk, vj, vi = T.axis.remap("RSS", [k, j, i]) + T.reads(A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_reorder_block_iter_var(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + sch.reorder_block_iter_var(C, [2, 1, 0]) + tvm.ir.assert_structural_equal(matmul_after_reorder_block_iter_var, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=matmul) + + +def test_reorder_block_iter_var_fail_not_full(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [2, 1]) + + +def test_reorder_block_iter_var_fail_not_within_bound(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [-1, 3, 2]) + + +def test_reorder_block_iter_var_fail_not_unique(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [0, 0, 2]) + + +if __name__ == "__main__": + tvm.testing.main()