diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 73cd283cb2de..17e442354542 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2105,12 +2105,28 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, ICHECK(desc_loops.size() == static_cast(n_desc_vars)); ICHECK(block_loops.size() == iter_types_block.size()); + // We assume that the orders of iter_vars in the target and the desc block are consistent. + // Based on that assumption, the following logic supports arbitrary permutations of a loop order, + // such as + + // for k: + // for i: + // for j: + // C[i, j] += A[i, k] * B[k, j] + + // or + + // for i: + // for j: + // for k: + // C[i, j] += A[i, k] * B[k, j] + int next_block_ind = block_loops.size() - 1; for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { - // Step 4.2. Find the corresponding loop of the i-th block var of desc + // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; const tir::ForNode* desc_loop = nullptr; - IterVarType iter_type_desc; + IterVarType iter_type_desc = iter_types_desc[i_desc]; for (int i = 0, n = desc_loops.size(); i < n; ++i) { // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); @@ -2127,29 +2143,32 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const IntImmNode* int_desc_extent = desc_loop->extent.as(); - const tir::ForNode* block_loop = nullptr; - + // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type PrimExpr block_bind; - for (int i_block = next_block_ind; i_block >= 0; --i_block) { - if (iter_types_block[i_block] == iter_type_desc) { - next_block_ind = i_block - 1; - block_bind = block->iter_values[i_block]; + for (int i = next_block_ind; i >= 0; --i) { + if (iter_types_block[i] == iter_type_desc) { + next_block_ind = i - 1; + block_bind = block->iter_values[i]; break; } } + if (!block_bind.defined()) return NullOpt; + + // Step 3.3. Find the corresponding loop of the target block for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); if (!UsesVar(r, [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { - block_loop = block_loops[i]; - const IntImmNode* int_block_extent = block_loop->extent.as(); + const IntImmNode* int_block_extent = block_loops[i]->extent.as(); + // Check divisibility if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { return NullOpt; } - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loops[i]]; auto it = ret->loop_map.find(block_loop_sref); if (it == ret->loop_map.end()) { ret->loop_map.Set(block_loop_sref, GetRef(desc_loop));