From dafee2ecde2989bbdd40ce2955514c17cfdab2e1 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sat, 8 Apr 2023 22:36:57 -0400 Subject: [PATCH] [Dynamic] M2 for S3: Compute Inline (#173) --- .../analysis/block_access_region_detector.cc | 94 ++--- src/tir/schedule/primitive/compute_inline.cc | 138 ++++---- tests/python/relax/test_dyn_compute_inline.py | 330 ++++++++++++++++++ .../test_tir_schedule_compute_inline.py | 92 +++++ 4 files changed, 542 insertions(+), 112 deletions(-) create mode 100644 tests/python/relax/test_dyn_compute_inline.py diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 057cec475d..4f013d3344 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -37,8 +37,10 @@ namespace tir { */ class BlockReadWriteDetector : public StmtExprVisitor { public: - explicit BlockReadWriteDetector(const Map& buffer_var_map) - : buffer_var_map_(buffer_var_map) {} + explicit BlockReadWriteDetector(const Array& alloc_buffers, + const Map& buffer_var_map) + : buffer_var_map_(buffer_var_map), + alloc_buffers_(alloc_buffers.begin(), alloc_buffers.end()) {} /*! \brief Return read regions of the block */ Array CollectReads( @@ -78,6 +80,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { std::unordered_map match_buffers_; /*!\ brief Internal analyzer. */ arith::Analyzer ana_; + /*! \brief The alloc buffers of the current block*/ + std::unordered_set alloc_buffers_; /*! 
* \brief Update read/write buffers and regions with provided buffer and region @@ -145,11 +149,13 @@ Array BlockReadWriteDetector::CollectOpaques() { void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef(op)); } void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { - std::vector relaxed_region; - for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + if (!alloc_buffers_.count(op->buffer)) { + std::vector relaxed_region; + for (const PrimExpr& index : op->indices) { + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + } + Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); } - Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); } @@ -182,20 +188,22 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { auto it = buffer_var_map_.find(GetRef(buffer_var)); if (it != buffer_var_map_.end()) { const Buffer& buffer = (*it).second; - const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); - const Region& region = buffer_region->region; - std::vector int_set; - int_set.reserve(region.size()); - for (const Range& range : region) { - int_set.push_back(arith::EvalSet(range, dom_map_)); - } - // read access, write access or opaque access - if ((access_mask->value & 1) && (access_mask->value & 2)) { - Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); - } else if (access_mask->value & 1) { - Update(&read_buffers_, &read_regions_, buffer, int_set); - } else if (access_mask->value & 2) { - Update(&writes_buffers_, &write_regions_, buffer, int_set); + if (!alloc_buffers_.count(buffer)) { + const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); + const Region& region = buffer_region->region; + std::vector int_set; + int_set.reserve(region.size()); + for (const Range& range : region) { + int_set.push_back(arith::EvalSet(range, 
dom_map_)); + } + // read access, write access or opaque access + if ((access_mask->value & 1) && (access_mask->value & 2)) { + Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); + } else if (access_mask->value & 1) { + Update(&read_buffers_, &read_regions_, buffer, int_set); + } else if (access_mask->value & 2) { + Update(&writes_buffers_, &write_regions_, buffer, int_set); + } } } } else { @@ -221,11 +229,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { } void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { - std::vector relaxed_region; - for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + if (!alloc_buffers_.count(op->buffer)) { + std::vector relaxed_region; + for (const PrimExpr& index : op->indices) { + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + } + Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); } - Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); } @@ -236,24 +246,28 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i]; } for (const auto& read : op->block->reads) { - std::vector relaxed_region; - for (const auto& range : read->region) { - relaxed_region.push_back( - arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( - Substitute(range->min, vmap), Substitute(range->extent, vmap))), - dom_map_)); + if (!alloc_buffers_.count(read->buffer)) { + std::vector relaxed_region; + for (const auto& range : read->region) { + relaxed_region.push_back( + arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( + Substitute(range->min, vmap), Substitute(range->extent, vmap))), + dom_map_)); + } + Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region); } - Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region); } for 
(const auto& write : op->block->writes) { - std::vector relaxed_region; - for (const auto& range : write->region) { - relaxed_region.push_back( - arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( - Substitute(range->min, vmap), Substitute(range->extent, vmap))), - dom_map_)); + if (!alloc_buffers_.count(write->buffer)) { + std::vector relaxed_region; + for (const auto& range : write->region) { + relaxed_region.push_back( + arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( + Substitute(range->min, vmap), Substitute(range->extent, vmap))), + dom_map_)); + } + Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); } - Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); } } @@ -349,7 +363,7 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { Array> GetBlockAccessRegion(const Block& block, const Map& buffer_var_map) { - BlockReadWriteDetector detector(buffer_var_map); + BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map); detector(block); Array writes = detector.CollectWrites(); std::unordered_set excluded_buffers; @@ -366,7 +380,7 @@ Array> GetBlockAccessRegion(const Block& block, Array> GetBlockReadWriteRegion(const Block& block, const Map& buffer_var_map) { - BlockReadWriteDetector detector(buffer_var_map); + BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map); detector(block); Array opaques = detector.CollectOpaques(); std::unordered_set excluded_buffers; diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index b64351186a..4849c926ed 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -22,9 +22,8 @@ namespace tvm { namespace tir { static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of - 'A[i, j, k, ...] 
= f(i, j, k, ...)', -where the indices on the left are distinct atomic variables, -and there should be no variables other than the index variables)"; + 'A[f(i, j, k, ...)] = g(i, j, k, ...)', +where the store indices mapping f on the left are bijective affine.)"; static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of `B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`, @@ -284,31 +283,6 @@ class BaseInliner : public StmtExprMutator { return std::move(tgt_block); } - /*! - * \brief Count the number of undefined variables that are not used - * as buffer objects. - * - * This is used to determine whether inlining or reverse inlining is - * possible. The only undefined variables present should be the - * load/store indices, or buffer access based on those indices. - * - * \param stmt The statement in which to count undefined variables - */ - static int GetNumUndefinedNonpointerVars(const Stmt& stmt) { - auto undefined_vars = UndefinedVars(stmt, {}); - // Buffer pointers and the inlined indices are allowed, but no - // other variables may appear in the inlined block. - int num_nonpointer_vars = 0; - for (const auto& var : undefined_vars) { - bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() && - var->type_annotation.as(); - if (!is_pointer) { - num_nonpointer_vars++; - } - } - return num_nonpointer_vars; - } - private: /*! * \brief Add the buffers in the block signature to the `buffer_var_map_`, @@ -406,7 +380,7 @@ class BaseInliner : public StmtExprMutator { /*! \brief Maps a buffer's data field to itself */ Map buffer_var_map_; /*! \brief The indices used for indexing the buffer to be inlined */ - std::vector idx_vars_; + std::vector idx_vars_; /*! 
\brief The mapping to substitute index variables to PrimExprs */ std::unordered_map idx_sub_; @@ -443,10 +417,62 @@ class ComputeInliner : public BaseInliner { return false; } - int n_vars = GetNumUndefinedNonpointerVars(GetRef(inlined_store_)); - if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) { + // Fast path on trivial case: + // check that the store indices are the same as the block iters. + store_value_ = inlined_store_->value; + size_t num_iters = producer_block->iter_vars.size(); + size_t buffer_ndim = inlined_store_->indices.size(); + if (num_iters == buffer_ndim) { + std::vector idx_vars; + idx_vars.reserve(num_iters); + for (size_t i = 0; i < num_iters; ++i) { + const IterVar& iter = producer_block->iter_vars[i]; + const PrimExpr& e = inlined_store_->indices[i]; + if (e.same_as(iter->var) || + (analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) && + analyzer_.CanProveEqual(iter->dom->extent, 1))) { + idx_vars.push_back(iter->var); + } else { + break; + } + } + if (idx_vars.size() == num_iters) { + // match success + idx_vars_ = std::move(idx_vars); + return true; + } + } + + // If the mapping for the store indices is non-trivial, + // check for a bijective mapping from the producer iter vars to the store indices + Map producer_iter_doms; + for (const auto& iter : producer_block->iter_vars) { + producer_iter_doms.Set(iter->var, iter->dom); + } + auto res = arith::DetectIterMap( + /*indices=*/inlined_store_->indices, + /*input_iters=*/producer_iter_doms, + /*predicate=*/true, + /*check_level=*/arith::IterMapLevel::Bijective, + /*analyzer=*/&analyzer_, + /*simplify_trivial_iterators=*/false); + if (res->indices.empty()) { + // Failure: indices of BufferStore are not bijective affine return false; } + idx_vars_.resize(buffer_ndim); + for (size_t i = 0; i < idx_vars_.size(); ++i) { + idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype()); + } + auto inverse_iter_map = arith::InverseAffineIterMap( + res->indices,
Array(idx_vars_.begin(), idx_vars_.end())); + for (const auto& iter : producer_block->iter_vars) { + if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) { + // fallback mapping for constant iters + inverse_iter_map.Set(iter->var, iter->dom->min); + } + } + store_value_ = Substitute(store_value_, inverse_iter_map); return true; } @@ -464,45 +490,7 @@ class ComputeInliner : public BaseInliner { PrimExpr ReplaceInlinedBuffer(BufferLoad load) { SetIndexSubstitution(load->indices); - return Substitute(inlined_store_->value, idx_sub_); - } - - /*! - * \brief Check if the indices are atomic distinct variables and the access is n-dimensional. - * If so, set `self->idx_vars_` properly. - * \param indices The indices to be extracted - * \param expected_ndim The expected ndim of the access - * \return A boolean flag indicating if the check is successful - */ - bool UpdateAndCheckIndexVars(const Array& indices, int expected_ndim) { - int n = indices.size(); - if (n != expected_ndim) { - // Failure: dimension mismatch - return false; - } - std::vector result; - result.reserve(n); - for (const PrimExpr& i : indices) { - if (const auto* var = i.as()) { - result.push_back(var); - } else { - // Failure: indexing expression is not a variable - return false; - } - } - using DistinctSet = std::unordered_set; - int n_distinct = DistinctSet(result.begin(), result.end()).size(); - if (n != n_distinct) { - // Failure: indexing variables are not distinct - return false; - } - if (idx_vars_.empty()) { - idx_vars_ = std::move(result); - } else if (!support::ArrayWithSameContent(idx_vars_, result)) { - // Failure: indexing variables are not consitent in different BufferLoads - return false; - } - return true; + return Substitute(store_value_, idx_sub_); } /*! 
@@ -512,11 +500,17 @@ class ComputeInliner : public BaseInliner { void SetIndexSubstitution(const Array& indices) { ICHECK_EQ(indices.size(), idx_vars_.size()); int n = idx_vars_.size(); - idx_sub_.reserve(n); for (int i = 0; i < n; ++i) { - idx_sub_[idx_vars_[i]] = indices[i]; + idx_sub_[idx_vars_[i].get()] = indices[i]; } } + + /*! \brief The arithmetic analyzer */ + arith::Analyzer analyzer_; + /*! \brief The store value used for inlining. If the producer + store indices are trivial, it is wrt the producer block iter vars, + otherwise it is wrt the placeholder vars of the store indices. */ + PrimExpr store_value_; }; /*! diff --git a/tests/python/relax/test_dyn_compute_inline.py b/tests/python/relax/test_dyn_compute_inline.py new file mode 100644 index 0000000000..74564e8882 --- /dev/null +++ b/tests/python/relax/test_dyn_compute_inline.py @@ -0,0 +1,330 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+import numpy as np + +import tvm +import tvm.testing +from tvm import relax, te, tir, meta_schedule as ms +from tvm.script import relax as R, tir as T, ir as I + + +def test_matmul(): + @T.prim_func + def main(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_rxplaceholder, (n, 4)) + B = T.match_buffer(var_rxplaceholder_1, (4, n)) + C = T.match_buffer(var_matmul, (n, n)) + # with T.block("root"): + A_pad = T.alloc_buffer(((n + 31) // 32 * 32, 4)) + B_pad = T.alloc_buffer((4, (n + 31) // 32 * 32)) + C_pad = T.alloc_buffer(((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0, i1 in T.grid((n + 31) // 32 * 32, 4): + with T.block("A_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v0, v1]) + T.writes(A_pad[v0, v1]) + A_pad[v0, v1] = T.if_then_else(v0 < n, A[v0, v1], T.float32(0)) + for i0, i1 in T.grid(4, (n + 31) // 32 * 32): + with T.block("B_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(B[v0, v1]) + T.writes(B_pad[v0, v1]) + B_pad[v0, v1] = T.if_then_else(v1 < n, B[v0, v1], T.float32(0)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A_pad[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B_pad[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C_pad[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + A_shared = T.alloc_buffer((32, 4), scope="shared") + for ax0, ax1 in T.grid(32, 4): + with T.block("A_shared"): + v0 = T.axis.spatial(32, ax0) + v1 = T.axis.spatial(4, ax1) + T.reads(A_pad[v_i0_o * 32 + v0, v1]) + T.writes(A_shared[v0, v1]) + A_shared[v0, v1] = A_pad[v_i0_o * 32 + v0, v1] + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + T.reads( + A_shared[v_i0_o * 32 + v_i0_i, v_k_i], + B_pad[v_k_i, v_i1_o * 32 + v_i1_i], + ) + 
T.writes(C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A_shared[v_i0_o * 32 + v_i0_i, v_k_i] + * B_pad[v_k_i, v_i1_o * 32 + v_i1_i] + ) + for i0, i1 in T.grid(n, n): + with T.block("C_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(C_pad[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_pad[v0, v1] + + @T.prim_func + def expected(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + A = T.match_buffer(var_rxplaceholder, (n, 4)) + B = T.match_buffer(var_rxplaceholder_1, (4, n)) + C = T.match_buffer(var_matmul, (n, n)) + # with T.block("root"): + B_pad = T.alloc_buffer((4, (n + 31) // 32 * 32)) + C_pad = T.alloc_buffer(((n + 31) // 32 * 32, (n + 31) // 32 * 32)) + for i0, i1 in T.grid(4, (n + 31) // 32 * 32): + with T.block("B_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(B[v0, v1]) + T.writes(B_pad[v0, v1]) + B_pad[v0, v1] = T.if_then_else(v1 < n, B[v0, v1], T.float32(0)) + for i0_0, i1_0 in T.grid((n + 31) // 32, (n + 31) // 32): + with T.block("matmul_o"): + v_i0_o, v_i1_o = T.axis.remap("SS", [i0_0, i1_0]) + T.reads( + A[v_i0_o * 32 : v_i0_o * 32 + 32, 0:4], + B_pad[0:4, v_i1_o * 32 : v_i1_o * 32 + 32], + ) + T.writes(C_pad[v_i0_o * 32 : v_i0_o * 32 + 32, v_i1_o * 32 : v_i1_o * 32 + 32]) + A_shared = T.alloc_buffer((32, 4), scope="shared") + for ax0, ax1 in T.grid(32, 4): + with T.block("A_shared"): + v0 = T.axis.spatial(32, ax0) + v1 = T.axis.spatial(4, ax1) + T.reads(A[v_i0_o * 32 + v0, v1]) + T.writes(A_shared[v0, v1]) + A_shared[v0, v1] = T.if_then_else( + v_i0_o * 32 + v0 < n, A[v_i0_o * 32 + v0, v1], T.float32(0) + ) + for i0_1, i1_1, k in T.grid(32, 32, 4): + with T.block("matmul"): + v_i0_i, v_i1_i, v_k_i = T.axis.remap("SSR", [i0_1, i1_1, k]) + 
T.reads( + A_shared[v_i0_o * 32 + v_i0_i, v_k_i], + B_pad[v_k_i, v_i1_o * 32 + v_i1_i], + ) + T.writes(C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i]) + with T.init(): + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = T.float32(0) + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] = ( + C_pad[v_i0_o * 32 + v_i0_i, v_i1_o * 32 + v_i1_i] + + A_shared[v_i0_o * 32 + v_i0_i, v_k_i] + * B_pad[v_k_i, v_i1_o * 32 + v_i1_i] + ) + for i0, i1 in T.grid(n, n): + with T.block("C_pad"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + T.reads(C_pad[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_pad[v0, v1] + + sch = tir.Schedule(main, debug_mask="all") + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="A_pad", func_name="main") + sch.compute_inline(b1) + tvm.ir.assert_structural_equal(sch.mod["main"], expected) + + +def test_norm_s3(): + @I.ir_module + class ModAfterS2: + @T.prim_func + def main( + var_A: T.handle, + var_weight: T.Buffer((T.int64(4096),), "float32"), + var_rms_norm: T.handle, + ): + T.func_attr( + {"op_pattern": 4, "tir.noalias": T.bool(True), "tir_var_upper_bound": {"n": 2048}} + ) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096))) + rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) + sq_sum = T.alloc_buffer((T.int64(1), n)) + + A_pad = T.alloc_buffer( + [T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)], + dtype="float32", + ) + sq_sum_pad = T.alloc_buffer( + [T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32)], dtype="float32" + ) + + # pad A + for bsz, i, k in T.grid( + T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096) + ): + with T.block("A_pad"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(A[v_bsz, v_i, v_k]) + T.writes(A_pad[v_bsz, v_i, v_k]) + A_pad[v_bsz, v_i, v_k] = T.if_then_else( + v_i < n, A[v_bsz, v_i, v_k], T.float32(0) + ) + + # compute on padded buffers + for i_0 in range((n + T.int64(31)) // 
T.int64(32)): + with T.block("compute_o"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) + T.reads( + A_pad[ + v_bsz, + v_i_o * T.int64(32) : v_i_o * T.int64(32) + T.int64(32), + T.int64(0) : T.int64(4096), + ] + ) + T.writes( + sq_sum_pad[v_bsz, v_i_o * T.int64(32) : v_i_o * T.int64(32) + T.int64(32)] + ) + sq_sum_pad_local = T.alloc_buffer([T.int64(32)], dtype="float32", scope="local") + for bsz, i_1, k in T.grid(T.int64(1), T.int64(32), T.int64(4096)): + with T.block("compute"): + v_i_i, v_k_i = T.axis.remap("SR", [i_1, k]) + T.reads(A_pad[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) + T.writes(sq_sum_pad_local[v_i_i]) + with T.init(): + sq_sum_pad_local[v_i_i] = T.float32(0) + sq_sum_pad_local[v_i_i] = ( + sq_sum_pad_local[v_i_i] + + A_pad[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i] + * A_pad[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i] + ) + for bsz, i_1 in T.grid(T.int64(1), T.int64(32)): + with T.block("compute_cache_write"): + v_i_i = T.axis.remap("S", [i_1]) + T.reads(sq_sum_pad_local[v_i_i]) + T.writes(sq_sum_pad[v_bsz, v_i_o * T.int64(32) + v_i_i]) + sq_sum_pad[v_bsz, v_i_o * T.int64(32) + v_i_i] = sq_sum_pad_local[v_i_i] + + # write back to sq_sum + for bsz, i in T.grid(T.int64(1), n): + with T.block("sq_sum_pad"): + v_bsz, v_i = T.axis.remap("SS", [bsz, i]) + T.reads(sq_sum_pad[v_bsz, v_i]) + T.writes(sq_sum[v_bsz, v_i]) + sq_sum[v_bsz, v_i] = sq_sum_pad[v_bsz, v_i] + + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rms_norm"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(var_weight[v_k], A[v_bsz, v_i, v_k], sq_sum[v_bsz, v_i]) + T.writes(rms_norm[v_bsz, v_i, v_k]) + rms_norm[v_bsz, v_i, v_k] = var_weight[v_k] * ( + A[v_bsz, v_i, v_k] + / T.sqrt( + sq_sum[v_bsz, v_i] * T.float32(0.000244140625) + + T.float32(9.9999999999999995e-07) + ) + ) + + @I.ir_module + class ModAfterS3: + @T.prim_func + def main( + var_A: T.handle, + var_weight: 
T.Buffer((T.int64(4096),), "float32"), + var_rms_norm: T.handle, + ): + T.func_attr( + {"op_pattern": 4, "tir.noalias": T.bool(True), "tir_var_upper_bound": {"n": 2048}} + ) + n = T.int64() + A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096))) + rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) + sq_sum = T.alloc_buffer((T.int64(1), n)) + + sq_sum_pad = T.alloc_buffer( + [T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32)], dtype="float32" + ) + + # compute on padded buffers + for i_0 in range((n + T.int64(31)) // T.int64(32)): + with T.block("compute_o"): + v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) + v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) + T.reads( + A[ + v_bsz, + v_i_o * T.int64(32) : v_i_o * T.int64(32) + T.int64(32), + T.int64(0) : T.int64(4096), + ] + ) + T.writes( + sq_sum_pad[v_bsz, v_i_o * T.int64(32) : v_i_o * T.int64(32) + T.int64(32)] + ) + sq_sum_pad_local = T.alloc_buffer([T.int64(32)], dtype="float32", scope="local") + for bsz, i_1, k in T.grid(T.int64(1), T.int64(32), T.int64(4096)): + with T.block("compute"): + v_i_i, v_k_i = T.axis.remap("SR", [i_1, k]) + T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) + T.writes(sq_sum_pad_local[v_i_i]) + with T.init(): + sq_sum_pad_local[v_i_i] = T.float32(0) + sq_sum_pad_local[v_i_i] = sq_sum_pad_local[v_i_i] + T.if_then_else( + v_i_o * T.int64(32) + v_i_i < n, + A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], + T.float32(0), + ) * T.if_then_else( + v_i_o * T.int64(32) + v_i_i < n, + A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], + T.float32(0), + ) + for bsz, i_1 in T.grid(T.int64(1), T.int64(32)): + with T.block("compute_cache_write"): + v_i_i = T.axis.remap("S", [i_1]) + T.reads(sq_sum_pad_local[v_i_i]) + T.writes(sq_sum_pad[v_bsz, v_i_o * T.int64(32) + v_i_i]) + sq_sum_pad[v_bsz, v_i_o * T.int64(32) + v_i_i] = sq_sum_pad_local[v_i_i] + + # write back to sq_sum + for bsz, i in T.grid(T.int64(1), n): + with T.block("sq_sum_pad"): + v_bsz, v_i 
= T.axis.remap("SS", [bsz, i]) + T.reads(sq_sum_pad[v_bsz, v_i]) + T.writes(sq_sum[v_bsz, v_i]) + sq_sum[v_bsz, v_i] = sq_sum_pad[v_bsz, v_i] + + for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): + with T.block("rms_norm"): + v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) + T.reads(var_weight[v_k], A[v_bsz, v_i, v_k], sq_sum[v_bsz, v_i]) + T.writes(rms_norm[v_bsz, v_i, v_k]) + rms_norm[v_bsz, v_i, v_k] = var_weight[v_k] * ( + A[v_bsz, v_i, v_k] + / T.sqrt( + sq_sum[v_bsz, v_i] * T.float32(0.000244140625) + + T.float32(9.9999999999999995e-07) + ) + ) + + sch = tir.Schedule(ModAfterS2, debug_mask="all") + sch.compute_inline(sch.get_block("A_pad")) + tvm.ir.assert_structural_equal(sch.mod["main"], ModAfterS3["main"]) + + +if __name__ == "__main__": + test_matmul() + test_norm_s3() diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 63df2de231..265928f58d 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -894,6 +894,98 @@ def test_compute_inline_multi_consumer(use_block_name): verify_trace_roundtrip(sch=sch, mod=elementwise_multi_producer_consumer) +def test_compute_inline_layout_transformed_store(): + @T.prim_func + def before(X: T.Buffer[(9, 96), "float32"]): + A = T.alloc_buffer([6, 9, 16], "float32") + B = T.alloc_buffer([3, 3, 96], "float32") + for i, j in T.grid(9, 96): + with T.block("producer"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vj // 16, vi, vj % 16] = X[vi, vj] + for i0, i1, j in T.grid(3, 3, 96): + with T.block("consumer"): + vi0, vi1, vj = T.axis.remap("SSS", [i0, i1, j]) + B[vi0, vi1, vj] = A[vj % 16, vi0 * 3 + vi1, vj // 16] + 1.0 + + @T.prim_func + def after(X: T.Buffer[(9, 96), "float32"]): + B = T.alloc_buffer([3, 3, 96], "float32") + for i0, i1, j in T.grid(3, 3, 96): + with T.block("consumer"): + vi0, vi1, vj = T.axis.remap("SSS", [i0, i1, j]) + B[vi0, vi1, vj] = 
X[vi0 * 3 + vi1, vj // 16 + vj % 16 * 16] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + sch.compute_inline(sch.get_block("producer")) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_compute_inline_out_of_bound_consumer(): + """This case intentionally covers a producer region that + does not cover the consumers. Though the out-of-bound region values + are not defined, the current behavior is still to generate the inlined + computation using the rule of the producer. + """ + + @T.prim_func + def before(): + A = T.alloc_buffer([8], "int32") + B = T.alloc_buffer([10], "int32") + for i in range(8): + with T.block("producer"): + vi = T.axis.remap("S", [i]) + A[vi] = vi + for i in range(10): + with T.block("consumer"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + 1 + + @T.prim_func + def after(): + B = T.alloc_buffer([10], "int32") + for i in range(10): + with T.block("consumer"): + vi = T.axis.remap("S", [i]) + B[vi] = vi + 1 + + sch = tir.Schedule(before, debug_mask="all") + sch.compute_inline(sch.get_block("producer")) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_compute_inline_dynamic_shape_producer(): + @T.prim_func + def before(a: T.handle, c: T.handle, m: T.int32, n: T.int32): + A = T.match_buffer(a, (m, n)) + B = T.alloc_buffer((m, n)) + C = T.match_buffer(c, (m, n)) + for i, j in T.grid(m, n): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(m, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def after(a: T.handle, c: T.handle, m: T.int32, n: T.int32): + A = T.match_buffer(a, (m, n)) + C = T.match_buffer(c, (m, n)) + for i, j in T.grid(m, n): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + sch = tir.Schedule(before, debug_mask="all") +
sch.compute_inline(sch.get_block("B")) + tvm.ir.assert_structural_equal(after, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + def test_compute_inline_fail_multi_writer(use_block_name): sch = tir.Schedule(fail_multi_reader_writer, debug_mask="all") block_b = "B" if use_block_name else sch.get_block("B")