From d759a72ff863620a4058de0839045ac8680ac3cb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 22 Jan 2022 01:46:39 +0800 Subject: [PATCH] [MetaSchedule] Post Processor: Rewrite Reduction Block (#10013) Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- python/tvm/meta_schedule/postproc/__init__.py | 1 + .../postproc/rewrite_reduction_block.py | 31 +++ .../postproc/rewrite_reduction_block.cc | 157 +++++++++++++ src/meta_schedule/utils.h | 13 + src/tir/schedule/utils.h | 30 +++ ...hedule_postproc_rewrite_reduction_block.py | 222 ++++++++++++++++++ 6 files changed, 454 insertions(+) create mode 100644 python/tvm/meta_schedule/postproc/rewrite_reduction_block.py create mode 100644 src/meta_schedule/postproc/rewrite_reduction_block.cc create mode 100644 tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 2e8fa1e777d6..5b16be16fca7 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,4 +16,5 @@ # under the License. """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc +from .rewrite_reduction_block import RewriteReductionBlock from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py new file mode 100644 index 000000000000..7e15ed493ccb --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites reduction block by moving the init block out.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteReductionBlock") +class RewriteReductionBlock(Postproc): + """A postprocessor that rewrites reduction block by moving the init block out.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc new file mode 100644 index 000000000000..cea1f5b93c9f --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The visitor that finds all the reduction block to be decomposed */ +struct ReductionBlockFinder : private StmtVisitor { + public: + /*! \brief Find all the reduction blocks that should be decomposed */ + static std::vector> Find(const ScheduleState& self) { + std::vector> results; + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + ReductionBlockFinder finder; + finder(prim_func->body); + for (const BlockNode* block : finder.results_) { + results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); + } + } + } + return results; + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { + thread_bound_loop_vars_.insert(loop->loop_var.get()); + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { + results_.push_back(realize->block.get()); + } + StmtVisitor::VisitStmt_(realize); + } + + bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { + if (thread_bound_loop_vars_.empty()) { + return true; + } + auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; + const BlockNode* block = realize->block.get(); + ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + int n = block->iter_vars.size(); + for (int i = 0; i < n; ++i) { + IterVar iter_var = block->iter_vars[i]; + PrimExpr binding = realize->iter_values[i]; + if (iter_var->iter_type == tir::kCommReduce) { + if (UsesVar(binding, f_find)) { + return false; + } + } + } + return true; + } + + /*! \brief The results of the collection */ + std::vector results_; + /*! \brief Loop variables that are bound to threads */ + std::unordered_set thread_bound_loop_vars_; +}; + +/*! + * \brief Find the innermost loop that the `init` of the input block could be decomposed to + * \param block_sref The StmtSRef of the block to be decomposed + * \return The index of the innermost loop where the `init` of the input block could be decomposed, + * or -1 if the `init` does not need to be decomposed. + */ +int FindDecomposePoint(const StmtSRef& block_sref) { + Array loop_srefs = GetLoops(block_sref); + int n = loop_srefs.size(); + for (int i = 0; i < n; ++i) { + if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { + return i; + } + } + return -1; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! \brief Rewrite reduction block by moving the init block out */ +class RewriteReductionBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); +}; + +bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { + for (;;) { + std::vector> results = + tir::ReductionBlockFinder::Find(sch->state()); + int rewritten = 0; + for (const auto& kv : results) { + const tir::StmtSRef& block_sref = kv.first; + const String& global_var_name = kv.second; + int decompose_point = tir::FindDecomposePoint(block_sref); + if (decompose_point == -1) { + continue; + } + tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + ++rewritten; + } + if (rewritten == 0) { + break; + } + } + return true; +} + +Postproc Postproc::RewriteReductionBlock() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") + .set_body_typed(Postproc::RewriteReductionBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 5b497695400a..bd76ca794a9a 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -246,6 +246,19 @@ inline std::string Concat(const Array& strs, const std::string& delim) { return os.str(); } +/*! + * \brief Get the BlockRV from a block StmtSRef + * \param sch The schedule + * \param block_sref The block StmtSRef + * \param global_var_name The global variable name + * \return The BlockRV + */ +inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, + const String& global_var_name) { + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return sch->GetBlock(block->name_hint, global_var_name); +} + /*! * \brief A helper data structure that replays a trace and collects failure counts * for each postprocessor diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index be6d5a18a47f..ebd2284cbe3c 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -192,6 +192,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } +/*! + * \brief Get the thread scope bound to the specific loop + * \param loop The loop to be inspected + * \return The thread scope bound to the loop + */ +inline runtime::ThreadScope GetThreadScope(const ForNode* loop) { + if (loop->kind == ForKind::kThreadBinding) { + return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); + } + return runtime::ThreadScope{-1, -1}; +} + +/*! + * \brief Check if the thread scope is blockIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is blockIdx + */ +inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 0; // The rank of blockIdx is 0 +} + +/*! + * \brief Check if the thread scope is threadIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is threadIdx + */ +inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 1 && thread_scope.dim_index >= 0; +} + /******** Integer set ********/ /*! diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py new file mode 100644 index 000000000000..263448aa1be6 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteReductionBlock +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteReductionBlock(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Matmul_before_rewrite: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class Matmul_after_rewrite: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(2, 2, 16, 2): + with T.block("C_init"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3_init * 2 + i1_4_init) + T.reads([]) + T.writes([C_local[i, j]]) + C_local[i, j] = T.float32(0) + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C_update"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class Softmax_cross_thread_reduction: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 8): + for ax1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_0 * 32 + ax1_1) + T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) + T.writes(T_softmax_maxelem_shared[i0_1]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e+38) + T_softmax_maxelem_shared[i0_1] = T.max(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) + for ax0, ax1_0 in T.grid(1, 8): + for ax1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_2 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_0 * 32 + ax1_1) + T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]) + T.writes(T_softmax_expsum_shared[i0_2]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32") + for i1_0 in T.serial(8): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_3 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 32 + i1_1) + T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]) + T.writes(T_softmax_norm[i0_3, i1]) + T.block_attr({"axis":1}) + T_softmax_norm[i0_3, i1] = T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32") / T_softmax_expsum_shared[i0_3] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_tiled_matmul(): + mod = Matmul_before_rewrite + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, Matmul_after_rewrite) + + +def test_rewrite_softmax(): + mod = Softmax_cross_thread_reduction + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + # The module should not be rewritten + tvm.ir.assert_structural_equal(sch.mod, Softmax_cross_thread_reduction) + + +if __name__ == "__main__": + test_rewrite_tiled_matmul() + test_rewrite_softmax()