diff --git a/.gitignore b/.gitignore index b9357018a64c..506e54d93067 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index af72d3b1a1df..f130cb438113 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -451,7 +451,7 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { } } -void RebaseNonZeroMinLoop(const Schedule& sch) { +void RebaseNonZeroMinLoop(ScheduleNode* sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { if (s->attach_type == kInlinedAlready) continue; @@ -614,10 +614,85 @@ void InjectInline(ScheduleNode* sch) { } } +void LegalizeInvalidAttach(ScheduleNode* sch) { + // Legalize the compute_at location if the target iterator of compute_at is split or fused. + // Case 1: If the target of compute_at is split, + // we will move the compute_at location to the inner iterator. + // Case 2: If the target of compute_at is fused, + // we will move the compute_at location to the newly fused iterator. + // Note that case 2 can only happen if the target of compute_at + // is the innermost operand of fuse operation. + + // Map an old invalid attach point to its new valid attach point + std::unordered_map replace_map; + + for (Stage stage : sch->stages) { + for (Stage s = stage; s.defined();) { + // The following logic is simiar to the `CreateAttachPath` in `src/te/schedule/graph.h`, + // because we follow the validation check in that function to legalize the attach. + Stage spec = s.GetAttachSpec(); + if (spec->attach_type != kScope) { + break; + } + bool start_attach = false; + IterVar attach_ivar = spec->attach_ivar; + s = spec->attach_stage; + CHECK(attach_ivar.defined()); + CHECK(s.defined()); + + for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { + IterVar iv = s->leaf_iter_vars[i - 1]; + if (!start_attach && iv.same_as(attach_ivar)) { + start_attach = true; + break; + } + } + + if (!start_attach) { + IterVar new_attach_ivar = attach_ivar; + bool updated = true; + // recursively update the relations + while (updated) { + updated = false; + for (const auto& rel : s->relations) { + if (const FuseNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->inner)) { + new_attach_ivar = r->fused; + updated = true; + } + } else if (const SplitNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->parent)) { + new_attach_ivar = r->inner; + updated = true; + } + } + } + replace_map[attach_ivar] = new_attach_ivar; + } + } + } + } + + // remap the parent relation + for (Stage s : sch->stages) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } + for (Stage s : sch->groups) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } +} + Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); - RebaseNonZeroMinLoop(sn); + RebaseNonZeroMinLoop(sn.operator->()); + LegalizeInvalidAttach(sn.operator->()); return sn; } diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 2c851cc39789..c00ee70586ef 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -289,6 +289,25 @@ def intrin_func(ins, outs, sp): assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" +def test_legalize_invalid_attach(): + A = te.compute((10, 10), lambda i, j: 1.0, name='A') + B = te.compute((10, 10), lambda i, j: A[i][j], name='B') + + # Case 1: Split an axis which is the target of a compute_at + s = te.create_schedule([B.op]) + s[A].compute_at(s[B], B.op.axis[1]) + s[B].split(B.op.axis[1], 2) + + stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body + assert isinstance(stmt.body.body, tvm.tir.stmt.For) + + # Case 2: Fuse an axis which is the target of a compute_at + s = te.create_schedule([B.op]) + s[A].compute_at(s[B], B.op.axis[1]) + s[B].fuse(B.op.axis[0], B.op.axis[1]) + stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body + assert isinstance(stmt, tvm.tir.stmt.For) + if __name__ == "__main__": test_singleton() test_pragma() @@ -305,3 +324,4 @@ def intrin_func(ins, outs, sp): test_fuse_with_out_of_order_axis_with_reorder() test_vectorize() test_vectorize_commreduce() + test_legalize_invalid_attach()