diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 630a72cedee5..e6354c40c4f6 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -24,12 +24,13 @@ namespace tir {
 static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
 'A[i, j, k, ...] = f(i, j, k, ...)',
 where the indices on the left are distinct atomic variables,
-and there should not no variables other than the index variables)";
+and there should be no variables other than the index variables)";
 
 static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
- `B[...] = g(i, j, k, A[i, j, k, ...] ...)`,
+ `B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
 where A is the only buffer the block consumes, whose indices are distinct atomic variables,
-and there should not no variables other than the index variables)";
+and there should be no variables other than the index variables, and f is a bijective affine
+mapping)";
 
 class HasInitBlock : public ScheduleError {
  public:
@@ -257,57 +258,6 @@ class BaseInliner : public StmtExprMutator {
     return std::move(tgt_block);
   }
 
-  /*!
-   * \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
-   * If so, set `self->idx_vars_` properly.
-   * \param indices The indices to be extracted
-   * \param expected_ndim The expected ndim of the access
-   * \return A boolean flag indicating if the check is successful
-   */
-  bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
-    int n = indices.size();
-    if (n != expected_ndim) {
-      // Failure: dimension mismatch
-      return false;
-    }
-    std::vector<const VarNode*> result;
-    result.reserve(n);
-    for (const PrimExpr& i : indices) {
-      if (const auto* var = i.as<VarNode>()) {
-        result.push_back(var);
-      } else {
-        // Failure: indexing expression is not a variable
-        return false;
-      }
-    }
-    using DistinctSet = std::unordered_set<const VarNode*>;
-    int n_distinct = DistinctSet(result.begin(), result.end()).size();
-    if (n != n_distinct) {
-      // Failure: indexing variables are not distinct
-      return false;
-    }
-    if (idx_vars_.empty()) {
-      idx_vars_ = std::move(result);
-    } else if (!support::ArrayWithSameContent(idx_vars_, result)) {
-      // Failure: indexing variables are not consitent in different BufferLoads
-      return false;
-    }
-    return true;
-  }
-
-  /*!
-   * \brief Set the mapping of index substitution `self->idx_sub_`
-   * \param indices The expressions that the corresponding index variables are replaced to
-   */
-  void SetIndexSubstitution(const Array<PrimExpr>& indices) {
-    ICHECK_EQ(indices.size(), idx_vars_.size());
-    int n = idx_vars_.size();
-    idx_sub_.reserve(n);
-    for (int i = 0; i < n; ++i) {
-      idx_sub_[idx_vars_[i]] = indices[i];
-    }
-  }
-
   /*!
    * \brief Count the number of undefined variables that are not used
    * as buffer objects.
@@ -490,6 +440,57 @@ class ComputeInliner : public BaseInliner {
     SetIndexSubstitution(load->indices);
     return Substitute(inlined_store_->value, idx_sub_);
   }
+
+  /*!
+   * \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
+   * If so, set `self->idx_vars_` properly.
+   * \param indices The indices to be extracted
+   * \param expected_ndim The expected ndim of the access
+   * \return A boolean flag indicating if the check is successful
+   */
+  bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
+    int n = indices.size();
+    if (n != expected_ndim) {
+      // Failure: dimension mismatch
+      return false;
+    }
+    std::vector<const VarNode*> result;
+    result.reserve(n);
+    for (const PrimExpr& i : indices) {
+      if (const auto* var = i.as<VarNode>()) {
+        result.push_back(var);
+      } else {
+        // Failure: indexing expression is not a variable
+        return false;
+      }
+    }
+    using DistinctSet = std::unordered_set<const VarNode*>;
+    int n_distinct = DistinctSet(result.begin(), result.end()).size();
+    if (n != n_distinct) {
+      // Failure: indexing variables are not distinct
+      return false;
+    }
+    if (idx_vars_.empty()) {
+      idx_vars_ = std::move(result);
+    } else if (!support::ArrayWithSameContent(idx_vars_, result)) {
+      // Failure: indexing variables are not consistent in different BufferLoads
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Set the mapping of index substitution `self->idx_sub_`
+   * \param indices The expressions that the corresponding index variables are replaced to
+   */
+  void SetIndexSubstitution(const Array<PrimExpr>& indices) {
+    ICHECK_EQ(indices.size(), idx_vars_.size());
+    int n = idx_vars_.size();
+    idx_sub_.reserve(n);
+    for (int i = 0; i < n; ++i) {
+      idx_sub_[idx_vars_[i]] = indices[i];
+    }
+  }
 };
 
 /*!
@@ -534,13 +535,33 @@ class ReverseComputeInliner : public BaseInliner {
       // Failure: no BufferLoad from the `inlined_buffer_`
       return false;
     }
-    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
+
+    // Collect block iter domains and update the substitution map
+    Map<Var, Range> consumer_iter_doms;
+    for (const auto& iter_var : consumer_block->iter_vars) {
+      consumer_iter_doms.Set(iter_var->var, iter_var->dom);
+      // Set default mapping for unit iters
+      if (is_const_int(iter_var->dom->extent, 1) && is_const_int(iter_var->dom->min)) {
+        idx_sub_[iter_var->var.get()] = iter_var->dom->min;
+      }
+    }
+
     for (const BufferLoadNode* load : loads) {
-      if (!UpdateAndCheckIndexVars(load->indices, n_vars)) {
-        // Failure: incorrect of inconsistent index vars
+      if (!UpdateAndCheckIndexExprs(load->indices)) {
         return false;
       }
     }
+
+    buffer_load_iter_map_ = arith::DetectIterMap(
+        /*indices=*/buffer_load_indices_,
+        /*input_iters=*/consumer_iter_doms,
+        /*predicate=*/true,
+        /*require_bijective=*/true,
+        /*analyzer=*/&analyzer);
+    if (buffer_load_iter_map_.empty()) {
+      // Failure: indices of BufferLoad are not bijective affine
+      return false;
+    }
     return true;
   }
 
@@ -556,8 +577,20 @@ class ReverseComputeInliner : public BaseInliner {
     return ReplaceInlinedBuffer(std::move(store));
   }
 
+  /*!
+   * \brief Apply the inverse of `buffer_load_iter_map_` to producer indices and update `idx_sub_`
+   * with the result. It will later be used to transform the BufferStore indices of the producer.
+   * \param producer_indices The BufferStore indices of the producer.
+   */
+  void CreateInverseMapping(const Array<PrimExpr> producer_indices) {
+    auto inverse_iter_map = arith::InverseAffineIterMap(buffer_load_iter_map_, producer_indices);
+    for (const auto& pair : inverse_iter_map) {
+      idx_sub_[pair.first.get()] = pair.second;
+    }
+  }
+
   Stmt ReplaceInlinedBuffer(BufferStore producer) {
-    SetIndexSubstitution(producer->indices);
+    CreateInverseMapping(producer->indices);
     producer_rhs_ = producer->value;
     return Substituter(this)(GetRef<Stmt>(inlined_store_));
   }
@@ -588,8 +621,31 @@
     return std::move(extractor.result);
   }
 
+  /*!
+   * \brief Update `buffer_load_indices_` with the given indices. If `buffer_load_indices_` is
+   * already non-empty, check that it is consistent with the given indices.
+   * \param indices The indices
+   * \return A boolean flag indicating if the check is successful
+   */
+  bool UpdateAndCheckIndexExprs(const Array<PrimExpr>& indices) {
+    if (buffer_load_indices_.empty()) {
+      buffer_load_indices_ = indices;
+    } else if (!std::equal(buffer_load_indices_.begin(), buffer_load_indices_.end(),
+                           indices.begin(), indices.end(), ExprDeepEqual())) {
+      // Failure: indices are not consistent in different BufferLoads
+      return false;
+    }
+    return true;
+  }
+
   /*! \brief The RHS value of the producer's BufferStore statement */
   PrimExpr producer_rhs_{nullptr};
+  /*! \brief The indices of the consumer's BufferLoad */
+  Array<PrimExpr> buffer_load_indices_;
+  /*! \brief The IterMap representing the indices of the consumer's BufferLoad */
+  Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
+  /*! \brief The arithmetic analyzer */
+  arith::Analyzer analyzer;
 };
 
 void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index 8894cd4d9f39..057d808ca4ec 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -169,6 +169,112 @@ def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None:
             C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0
 
 
+@T.prim_func
+def elementwise_reverse_affine_load(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k, l in T.grid(8, 32, 8, 8):
+        with T.block("C"):
+            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+            C[vi, vj, vk, vl] = B[
+                ((((vi * 32) + vj) * 8 + vk) * 8 + vl) // 128,
+                ((((vi * 32) + vj) * 8 + vk) * 8 + vl) % 128,
+            ]
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_inlined(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 32, 8, 8), "float32"]
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[
+                (vj + vi * 128) // 2048,
+                (vj + vi * 128) // 64 % 32,
+                ((vj + vi * 128) // 8) % 8,
+                (vj + vi * 128) % 8,
+            ] = (
+                A[vi, vj] * 2.0
+            )
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_unit_iter(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(8, 16, 1), "float32"],
+    D: T.Buffer[(1, 8, 16, 128), "float32"],
+) -> None:
+    C = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            D[0, vi, vj, vk] = C[vi * 16 + vj, vk] + B[vi, vj, 0]
+
+
+@T.prim_func
+def elementwise_reverse_affine_load_unit_iter_inlined(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(8, 16, 1), "float32"],
+    D: T.Buffer[(1, 8, 16, 128), "float32"],
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[0, vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + B[vi // 16, vi % 16, 0]
+
+
+@T.prim_func
+def elementwise_multi_reverse_affine_load(
+    A: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(8, 16, 128), "float32"],
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            C[vi, vj, vk] = B[vi * 16 + vj, vk] + B[vi * 16 + vj, vk]
+
+
+@T.prim_func
+def elementwise_multi_reverse_affine_load_inlined(
+    A: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(8, 16, 128), "float32"],
+) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi // 16, vi % 16, vj] = A[vi, vj] * 2.0 + A[vi, vj] * 2.0
+
+
+@T.prim_func
+def elementwise_reverse_non_affine_load(
+    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(8, 16, 128), "float32"]
+) -> None:
+    B = T.alloc_buffer((128, 128))
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j, k in T.grid(8, 16, 128):
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+            C[vi, vj, vk] = B[vi * 16 + vj, vi * 16 + vj]
+
+
 @T.prim_func
 def opaque_access_load(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (128, 128))
@@ -520,6 +626,39 @@ def test_reverse_compute_multi_reverse_loads():
     verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_loads)
 
 
+def test_reverse_compute_inline_affine_load():
+    sch = tir.Schedule(elementwise_reverse_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_reverse_affine_load_inlined, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load)
+
+
+def test_reverse_compute_inline_multi_affine_load():
+    sch = tir.Schedule(elementwise_multi_reverse_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(elementwise_multi_reverse_affine_load_inlined, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise_multi_reverse_affine_load)
+
+
+def test_reverse_compute_inline_affine_load_unit_iter():
+    sch = tir.Schedule(elementwise_reverse_affine_load_unit_iter, debug_mask="all")
+    block_c = sch.get_block("C")
+    sch.reverse_compute_inline(block_c)
+    tvm.ir.assert_structural_equal(
+        elementwise_reverse_affine_load_unit_iter_inlined, sch.mod["main"]
+    )
+    verify_trace_roundtrip(sch=sch, mod=elementwise_reverse_affine_load_unit_iter)
+
+
+def test_reverse_compute_fail_non_affine_load():
+    sch = tir.Schedule(elementwise_reverse_non_affine_load, debug_mask="all")
+    block_c = sch.get_block("C")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reverse_compute_inline(block_c)
+
+
 def test_reverse_compute_fail_multi_reverse_loads():
     sch = tir.Schedule(elementwise_multi_loads, debug_mask="all")
     block_c = sch.get_block("C")