From da58202e0397d2ae1ebae6c22648125c090df2ce Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Mon, 20 Mar 2023 16:24:42 +0530 Subject: [PATCH] [TIR] [Analysis] Expose IsOutputBlock to python This patch just exposes an existing analysis API IsOutputBlock to python. Since many schedule primitives have conditions on output blocks, this API would be really useful while scheduling --- python/tvm/tir/schedule/analysis.py | 19 +++++++++++++++++ src/tir/schedule/analysis/analysis.cc | 5 +++++ .../unittest/test_tir_schedule_analysis.py | 21 +++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index e1c0019d9bf0..15748a99c81a 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -140,3 +140,22 @@ def has_block(sch: Schedule, block_name: str) -> bool: True if the given block exists in the schedule. """ return _ffi_api.HasBlock(sch, block_name) # type: ignore + + +def is_output_block(sch: Schedule, block: BlockRV) -> bool: + """Check whether the given block is an output block + + Parameters + ---------- + sch : Schedule + The schedule object of the block + block : BlockRV + The blockRV to be checked + + Returns + ------- + yes/no : bool + True if the given block is an output block + + """ + return _ffi_api.IsOutputBlock(sch, block) # type: ignore diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b35d64f125d8..674abe28a3e0 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2071,6 +2071,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") }); TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); +TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch, BlockRV block) { + auto state = sch->state(); + auto block_sref = sch->GetSRef(block); + return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); +}); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 0002de38794b..cd91a44b6518 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -39,6 +39,7 @@ from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload from tvm.te import create_prim_func +from tvm.tir.schedule.analysis import is_output_block def _make_vars(*args: str) -> List[Var]: @@ -396,5 +397,25 @@ def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected): check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected) +def test_is_output_block(): + @T.prim_func + def two_elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tvm.tir.Schedule(two_elementwise) + block_rv = sch.get_block("C") + assert is_output_block(sch, block_rv) + + if __name__ == "__main__": tvm.testing.main()