From 6006d25f4332a8b024d75cfea5baa325c61798e9 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 2 Apr 2023 20:23:54 -0700 Subject: [PATCH] [TensorIR][Schedule] New primitive `reorder_block_itervar` (#14448) # Motivation Currently the `reorder` primitive only changes the loops, and block iterable variables order would not be changed. `transform_block_layout` can change the block iterable variables, but it requires the loops outside the given block to have no branches, which limited its usage. This schedule primitive changes the block iterable variable order directly, with API like: ```python def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None: """Reorder the itervars inside a given block. Parameters ---------- block : BlockRV The block to be transformed. new_order : List[int] The new block itervar order. """ ``` where the `new_order` is a permutation of [0, 1, ..., n-1] if n is the number of itervars in the block. # Example Suppose we need to change the block itervar order in block "C": ```python @T.prim_func def matmul(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: for i, j, k in T.grid(128, 128, 128): with T.block("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] ``` after applying: ```python sch = tir.Schedule(matmul, debug_mask="all") C = sch.get_block("C") sch.reorder_block_iter_var(C, [2, 1, 0]) ``` the block itervar order would be changed to `vk, vj, vi`. ```python @T.prim_func def matmul_after_reorder_block_iter_var(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]): for i, j, k in T.grid(128, 128, 128): with T.block("C"): vk, vj, vi = T.axis.remap("RSS", [k, j, i]) T.reads(A[vi, vk], B[vj, vk]) T.writes(C[vi, vj]) with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] ``` --- include/tvm/tir/schedule/schedule.h | 6 + python/tvm/tir/schedule/schedule.py | 13 ++ src/tir/schedule/concrete_schedule.cc | 8 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 9 ++ .../primitive/reorder_block_iter_var.cc | 148 ++++++++++++++++++ src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 9 ++ src/tir/schedule/traced_schedule.h | 1 + .../test_tir_reorder_block_iter_var.py | 86 ++++++++++ 10 files changed, 283 insertions(+) create mode 100644 src/tir/schedule/primitive/reorder_block_iter_var.cc create mode 100644 tests/python/unittest/test_tir_reorder_block_iter_var.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 570560c62d8c..162ad84f26f6 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -328,6 +328,12 @@ class ScheduleNode : public runtime::Object { * \param ordered_loop_rvs The loops in the new order */ virtual void Reorder(const Array& ordered_loop_rvs) = 0; + /*! + * \brief Reorder the itervars inside a block. + * \param block_rv The block to be transformed. + * \param new_order The new itervar order. + */ + virtual void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) = 0; /*! * \brief Create a new unit loop on top of the specific block. * \param block_rv The block above which the new loop is created diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 68f0b9454cb1..5527905141fc 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -752,6 +752,19 @@ def after_reorder(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member + @type_checked + def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None: + """Reorder the itervars inside a given block. + + Parameters + ---------- + block : BlockRV + The block to be transformed. + new_order : List[int] + The new block itervar order. + """ + _ffi_api.ScheduleReorderBlockIterVar(self, block, new_order) # type: ignore # pylint: disable=no-member + @type_checked def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV: """Create a new unit loop on top of the specific block or loop. diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 93ea38169d74..912fdcf5e7d8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -480,6 +480,14 @@ void ConcreteScheduleNode::Reorder(const Array& ordered_loop_rvs) { this->state_->DebugVerify(); } +void ConcreteScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, + const Array new_order) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReorderBlockIterVar(state_, GetSRef(block_rv), new_order); + TVM_TIR_SCHEDULE_END("reorder_block_iter_var", this->error_render_level_); + this->state_->DebugVerify(); +} + LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { LoopRV result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 227288b232d9..4cbbc0d6854a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -104,6 +104,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array Split(const LoopRV& loop_rv, const Array>& factors, bool preserve_unit_iters) override; void Reorder(const Array& ordered_loop_rvs) override; + void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) override; LoopRV AddUnitLoop(const BlockRV& block_rv) override; LoopRV AddUnitLoop(const LoopRV& loop_rv) override; /******** Schedule: Manipulate ForKind ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 09185498e143..283856bddc84 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -189,6 +189,15 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, */ TVM_DLL void Reorder(ScheduleState self, const Array& ordered_loop_srefs); +/*! + * \brief Reorder itervars inside a block. + * \param self The state of the schedule. + * \param block_sref The sref of block to be transformed. + * \param new_order The new itervar order. + */ +TVM_DLL void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, + const Array& new_order); + /*! * \brief Create a new unit loop on top of the specific block or loop. * \param sref The block/loop above which the new thread_binding loop is created diff --git a/src/tir/schedule/primitive/reorder_block_iter_var.cc b/src/tir/schedule/primitive/reorder_block_iter_var.cc new file mode 100644 index 000000000000..c7967a3ee904 --- /dev/null +++ b/src/tir/schedule/primitive/reorder_block_iter_var.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief The reorder index is not a valid permutation of + * [0, 1, ..., n-1] where n is the number of block iter vars. + */ +class InvalidReorderIndex : public ScheduleError { + public: + explicit InvalidReorderIndex(IRModule mod, Block block, Array new_order) + : mod_(mod), block_(block), new_order_(new_order) {} + IRModule mod() const final { return mod_; } + String FastErrorString() const final { + return "ScheduleError: The specified reorder indices are invalid."; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The user provided block itervar index order " << new_order_ + << " is not a valid permutation of [0, 1, ..., num_block_iter_vars-1] in block {0}."; + return String(os.str()); + } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + Array new_order_; +}; + +class BlockIterVarRewriter : public StmtMutator { + public: + Map block_map; + explicit BlockIterVarRewriter(const BlockNode* block_n, std::vector order) + : order_(std::move(order)), block_to_rewrite(block_n) {} + + private: + std::vector order_; + const BlockNode* block_to_rewrite; + Stmt VisitStmt_(const BlockRealizeNode* op) final { + if (op->block.get() == block_to_rewrite) { + auto block_n = CopyOnWrite(op->block.get()); + Block block = op->block; + Array new_iter_vars; + Array new_iter_values; + for (int idx : order_) { + new_iter_vars.push_back(block->iter_vars[idx]); + new_iter_values.push_back(op->iter_values[idx]); + } + block_n->iter_vars = new_iter_vars; + Block new_block(block_n); + block_map.Set(block, new_block); + auto block_realize_n = CopyOnWrite(op); + block_realize_n->block = new_block; + block_realize_n->iter_values = new_iter_values; + return BlockRealize(block_realize_n); + } else { + return StmtMutator::VisitStmt_(op); + } + } +}; + +void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, + const Array& new_order) { + const BlockNode* block_n = TVM_SREF_TO_BLOCK(block_sref); + std::vector new_order_vec; + for (const Integer& x : new_order) { + new_order_vec.push_back(x->value); + } + // check whether new_order is valid or not; + size_t num_block_itervars = block_n->iter_vars.size(); + std::set ind_set(new_order_vec.begin(), new_order_vec.end()); + bool is_full = new_order_vec.size() == num_block_itervars; + bool is_unique = (ind_set.size() == new_order_vec.size()); + bool is_within_boundary = std::all_of(new_order_vec.begin(), new_order_vec.end(), [&](int x) { + return x >= 0 && x < static_cast(num_block_itervars); + }); + if (!is_full || !is_unique || !is_within_boundary) { + throw InvalidReorderIndex(self->mod, GetRef(block_n), new_order); + } + + // find parent block + const BlockNode* parent_block_n = nullptr; + const StmtSRefNode* p = block_sref.get()->parent; + while (p != nullptr) { + if (p->stmt->IsInstance()) { + parent_block_n = TVM_SREF_TO_BLOCK(GetRef(p)); + break; + } + p = p->parent; + } + const StmtSRef parent_block_sref = GetRef(p); + const Block& parent_block = GetRef(parent_block_n); + + // rewrite block and blockrealize + BlockIterVarRewriter rewriter(block_n, std::move(new_order_vec)); + Block new_parent_block = Downcast(rewriter(parent_block)); + rewriter.block_map.Set(parent_block, new_parent_block); + self->Replace(parent_block_sref, new_parent_block, rewriter.block_map); +} + +struct ReorderBlockIterVarTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReorderBlockIterVar"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array new_order) { + sch->ReorderBlockIterVar(block, new_order); + } + + static String UnpackedAsPython(Array outputs, String block, Array new_order) { + PythonAPICall py("reorder_block_iter_var"); + py.Input("block", block); + py.Input("new_order", new_order); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReorderBlockIterVarTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a0e39b74d31b..0e6b77a6a8d9 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -157,6 +157,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&Sche TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder") .set_body_method(&ScheduleNode::Reorder); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") + .set_body_method(&ScheduleNode::ReorderBlockIterVar); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { if (const auto* loop_rv = rv.as()) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 2b6a7f71d4f5..b87106e527fc 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -217,6 +217,15 @@ void TracedScheduleNode::Reorder(const Array& ordered_loop_rvs) { /*outputs=*/{})); } +void TracedScheduleNode::ReorderBlockIterVar(const BlockRV& block_rv, + const Array new_order) { + ConcreteScheduleNode::ReorderBlockIterVar(block_rv, new_order); + static const InstructionKind& kind = InstructionKind::Get("ReorderBlockIterVar"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, new_order}, /*attrs=*/{}, + /*outputs=*/{})); +} + LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) { LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 8b9621c749de..0217f191e8a7 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -64,6 +64,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array Split(const LoopRV& loop_rv, const Array>& factor_rvs, bool preserve_unit_iters) final; void Reorder(const Array& ordered_loop_rvs) final; + void ReorderBlockIterVar(const BlockRV& block_rv, const Array new_order) final; LoopRV AddUnitLoop(const BlockRV& block_rv) final; LoopRV AddUnitLoop(const LoopRV& loop_rv) final; /******** Schedule: Manipulate ForKind ********/ diff --git a/tests/python/unittest/test_tir_reorder_block_iter_var.py b/tests/python/unittest/test_tir_reorder_block_iter_var.py new file mode 100644 index 000000000000..99e07aa525f9 --- /dev/null +++ b/tests/python/unittest/test_tir_reorder_block_iter_var.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def matmul( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), +) -> None: + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def matmul_after_reorder_block_iter_var( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), +): + for i, j, k in T.grid(128, 128, 128): + with T.block("C"): + vk, vj, vi = T.axis.remap("RSS", [k, j, i]) + T.reads(A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_reorder_block_iter_var(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + sch.reorder_block_iter_var(C, [2, 1, 0]) + tvm.ir.assert_structural_equal(matmul_after_reorder_block_iter_var, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=matmul) + + +def test_reorder_block_iter_var_fail_not_full(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [2, 1]) + + +def test_reorder_block_iter_var_fail_not_within_bound(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [-1, 3, 2]) + + +def test_reorder_block_iter_var_fail_not_unique(): + sch = tir.Schedule(matmul, debug_mask="all") + C = sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reorder_block_iter_var(C, [0, 0, 2]) + + +if __name__ == "__main__": + tvm.testing.main()