From 4a893e47ce7fe477202850d0865241feff3a0ce8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 15:12:36 -0700 Subject: [PATCH 1/4] Fix pipeline validation --- .../transforms/inject_software_pipeline.cc | 69 +++++++++-- ..._tir_transform_inject_software_pipeline.py | 113 ++++++++++++++++++ 2 files changed, 174 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index b607ba485a6a..aafbeea5aad5 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -570,6 +570,40 @@ class PipelineRewriter : public StmtExprMutator { Array ordered_stmts_; }; +/*! + * \brief Build the dependency graph among a array of blocks. + * \param[in] seq The SeqStmt + * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the + * destination. + * \param[out] dep_dst2src Optional, a map to store dependency edges from the + * destination to the source. + */ +void BuildDependencyGraph( + const Array& blocks, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + + for (const Block& block : blocks) { + for (const BufferRegion& read : block->reads) { + auto it = buffer_writers.find(read->buffer->data); + if (it != buffer_writers.end()) { + for (const Block& writer : it->second) { + if (dep_src2dst != nullptr) { + (*dep_src2dst)[writer].push_back(block); + } + if (dep_dst2src != nullptr) { + (*dep_dst2src)[block].push_back(writer); + } + } + } + } + for (const BufferRegion& write : block->writes) { + buffer_writers[write->buffer->data].push_back(block); + } + } +} + class PipelineInjector : private StmtExprMutator { public: static Stmt Inject(const PrimFunc& func) { @@ -587,24 +621,43 @@ class PipelineInjector : private StmtExprMutator { /*! * \brief Check the pipeline satisfies the following conditions: - * 1) No conflicting order: The order of each statement should be unique. - * 2) No reordering with the same stage: Statements in the same stage are not allowed to be - * reordered. + * 1. No conflicting order: The order of each statement should be unique. + * 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for + * dependency (e.g. read-after-write) from statement A to statement B, it requires: + * case 1: stage(A) < stage(B) + * case 2: stage(A) == stage(B) and order(A) < order(B) */ void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { std::unordered_set used_orders; std::unordered_map stage_max_order; + std::unordered_map order_to_block; + std::unordered_map block_to_stage; for (const Block& block : original_order) { const auto& stmt_info = pipeline_info.at(block); - int stage = stmt_info.stage; int order = stmt_info.order; CHECK(!used_orders.count(order)) << "ValueError: Two statements in the software pipeline cannot have the same order"; used_orders.insert(order); - CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order) - << "ValueError: Statements in the same stage of the software pipeline must have " - "increasing order."; - stage_max_order[stage] = order; + } + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; + BuildDependencyGraph(original_order, &dep_src2dst, nullptr); + + for (const auto& pair : dep_src2dst) { + const Block& src = pair.first; + const auto& src_info = pipeline_info.at(src); + const Array& dsts = pair.second; + for (const Block& dst : dsts) { + const auto& dst_info = pipeline_info.at(dst); + CHECK_LE(src_info.stage, dst_info.stage) + << "ValueError: statement " << dst << " in stage " << dst_info.stage + << " cannot depends on statement " << src << " in a later stage " << src_info.stage; + if (src_info.stage == dst_info.stage) { + CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer " + "access dependency in the same stage of the " + "software pipeline cannot be reordered"; + } + } } } diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1432be4efbe1..f71d02fdad9b 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -132,6 +132,115 @@ def transformed_simple_compute( C[tx, 15] = B[1, tx, 0] + T.float32(1) +@T.prim_func +def dag_interleaving( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 0, 0, 0, 1], + "software_pipeline_order": [0, 2, 1, 3, 4], + }, + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + AL = T.alloc_buffer((1, 1), dtype="float32", scope="local") + BL = T.alloc_buffer((1, 1), dtype="float32", scope="local") + with T.block(): + T.reads(A[tx, i]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[0, 0]) + AL[0, 0] = AS[tx, 0] + with T.block(): + T.reads(B[tx, i]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, i] + T.float32(2) + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[0, 0]) + BL[0, 0] = BS[tx, 0] + with T.block(): + T.reads(AL[0, 0], BL[0, 0]) + T.writes(C[tx, i]) + C[tx, i] = AL[0, 0] * BL[0, 0] + + +@T.prim_func +def transformed_dag_interleaving( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16], B[tx, 0:16]) + T.writes(C[tx, 0:16]) + AS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") + BS = T.alloc_buffer([16, 1], dtype="float32", scope="shared") + AL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") + BL = T.alloc_buffer([2, 1, 1], dtype="float32", scope="local") + with T.block(): + T.reads(A[tx, 0], B[tx, 0], AS[tx, 0], BS[tx, 0]) + T.writes(AS[tx, 0], BS[tx, 0], AL[0, 0, 0], BL[0, 0, 0]) + with T.block(): + T.reads(A[tx, 0]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[0, 0, 0]) + AL[0, 0, 0] = AS[tx, 0] + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[0, 0, 0]) + BL[0, 0, 0] = BS[tx, 0] + with T.block(): + T.reads( + A[tx, 1:16], B[tx, 1:16], AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0] + ) + T.writes(AS[tx, 0], BS[tx, 0], AL[0:2, 0, 0], BL[0:2, 0, 0], C[tx, 0:15]) + for i in T.serial(15): + with T.block(): + T.reads(A[tx, i + 1]) + T.writes(AS[tx, 0]) + AS[tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads(B[tx, i + 1]) + T.writes(BS[tx, 0]) + BS[tx, 0] = B[tx, i + 1] + T.float32(2) + with T.block(): + T.reads(AS[tx, 0]) + T.writes(AL[(i + 1) % 2, 0, 0]) + AL[(i + 1) % 2, 0, 0] = AS[tx, 0] + with T.block(): + T.reads(BS[tx, 0]) + T.writes(BL[(i + 1) % 2, 0, 0]) + BL[(i + 1) % 2, 0, 0] = BS[tx, 0] + with T.block(): + T.reads(AL[i % 2, 0, 0], BL[i % 2, 0, 0]) + T.writes(C[tx, i]) + C[tx, i] = AL[i % 2, 0, 0] * BL[i % 2, 0, 0] + with T.block(): + T.reads(AL[1, 0, 0], BL[1, 0, 0]) + T.writes(C[tx, 15]) + C[tx, 15] = AL[1, 0, 0] * BL[1, 0, 0] + + @T.prim_func def nested_pipeline_simple( A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"] @@ -792,6 +901,10 @@ def test_trivial_pipeline(): _check(trivial_pipeline, transformed_trivial_pipeline) +def test_dag_interleaving(): + _check(dag_interleaving, transformed_dag_interleaving) + + def test_nest_pipeline_simple(): _check(nested_pipeline_simple, transformed_nested_pipeline_simple) From 735b6194b6f3869e981ceb31b2f6723cdf5d6fd9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 17:24:30 -0700 Subject: [PATCH 2/4] fix predicate --- .../transforms/inject_software_pipeline.cc | 5 +- ..._tir_transform_inject_software_pipeline.py | 89 +++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index aafbeea5aad5..bdd9fc3ce473 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -534,7 +534,10 @@ class PipelineRewriter : public StmtExprMutator { subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); } else { // normalize loop range - subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min)); + PrimExpr delta = start - pipeline_loop_->min; + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + delta); + Var loop_iter = Downcast(new_loop_var); + inbound = Substitute(inbound, Map{{loop_iter, loop_iter + delta}}); } new_block = Downcast(Substitute(new_block, subst_map)); stmts.push_back(BlockRealize({}, inbound, new_block)); diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index f71d02fdad9b..85c39117968e 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -27,6 +27,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) + print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], transformed, True) @@ -132,6 +133,90 @@ def transformed_simple_compute( C[tx, 15] = B[1, tx, 0] + T.float32(1) +@T.prim_func +def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1, 2], + "software_pipeline_order": [0, 1, 2], + }, + ): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_three_stage_compute( + A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"] +) -> None: + for tx in T.thread_binding(16, thread="threadIdx.x"): + with T.block(): + T.reads(A[tx, 0:16]) + T.writes(D[tx, 0:16]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, 0:2], B[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0]) + for i in T.unroll(2): + with T.block(): + T.reads(A[tx, i]) + T.writes(B[0:2, tx, 0]) + B[i, tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.where(1 <= i) + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) + for i in T.serial(14): + with T.block(): + T.reads(A[tx, i + 2]) + T.writes(B[0:2, tx, 0]) + B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2) + with T.block(): + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads(B[0:2, tx, 0], C[0:2, tx, 0]) + T.writes(C[0:2, tx, 0], D[tx, 14:16]) + for i in T.unroll(2): + with T.block(): + T.where(i < 1) + T.reads(B[0:2, tx, 0]) + T.writes(C[0:2, tx, 0]) + C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[0:2, tx, 0]) + T.writes(D[tx, i + 14]) + D[tx, i + 14] = C[i, tx, 0] + T.float32(1) + + @T.prim_func def dag_interleaving( A: T.Buffer[(16, 16), "float32"], @@ -901,6 +986,10 @@ def test_trivial_pipeline(): _check(trivial_pipeline, transformed_trivial_pipeline) +def test_three_stage_compute(): + _check(three_stage_compute, transformed_three_stage_compute) + + def test_dag_interleaving(): _check(dag_interleaving, transformed_dag_interleaving) From 486c577a0f5c929efd357372f8458f11747f1276 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 17:35:47 -0700 Subject: [PATCH 3/4] Update test_tir_transform_inject_software_pipeline.py --- .../unittest/test_tir_transform_inject_software_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 85c39117968e..ff7e79c02352 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -27,7 +27,6 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) mod = tvm.tir.transform.Simplify()(mod) - print(mod["main"].script()) tvm.ir.assert_structural_equal(mod["main"], transformed, True) From 55a869867fd9bf51642d0ac76fd37311d181414a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Apr 2022 17:36:26 -0700 Subject: [PATCH 4/4] Update inject_software_pipeline.cc --- src/tir/transforms/inject_software_pipeline.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index bdd9fc3ce473..7402d6426bc2 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -575,7 +575,7 @@ class PipelineRewriter : public StmtExprMutator { /*! * \brief Build the dependency graph among a array of blocks. - * \param[in] seq The SeqStmt + * \param[in] blocks The array of blocks. * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the * destination. * \param[out] dep_dst2src Optional, a map to store dependency edges from the