diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index e1c0019d9bf0..15748a99c81a 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -140,3 +140,22 @@ def has_block(sch: Schedule, block_name: str) -> bool:
         True if the given block exists in the schedule.
     """
     return _ffi_api.HasBlock(sch, block_name)  # type: ignore
+
+
+def is_output_block(sch: Schedule, block: BlockRV) -> bool:
+    """Check whether the given block is an output block
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule object of the block
+    block : BlockRV
+        The blockRV to be checked
+
+    Returns
+    -------
+    yes/no : bool
+        True if the given block is an output block
+
+    """
+    return _ffi_api.IsOutputBlock(sch, block)  # type: ignore
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index b35d64f125d8..674abe28a3e0 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2071,6 +2071,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
     });

 TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock);
+TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch, BlockRV block) {
+  auto state = sch->state();
+  auto block_sref = sch->GetSRef(block);
+  return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false));
+});

 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 0002de38794b..cd91a44b6518 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -39,6 +39,7 @@
 from tvm.tir.stmt_functor import pre_order_visit
 from tvm.meta_schedule.testing import te_workload
 from tvm.te import create_prim_func
+from tvm.tir.schedule.analysis import is_output_block


 def _make_vars(*args: str) -> List[Var]:
@@ -396,5 +397,25 @@ def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
     check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)


+def test_is_output_block():
+    @T.prim_func
+    def two_elementwise(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (128, 128), "float32")
+        B = T.alloc_buffer((128, 128), "float32")
+        C = T.match_buffer(c, (128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    sch = tvm.tir.Schedule(two_elementwise)
+    block_rv = sch.get_block("C")
+    assert is_output_block(sch, block_rv)
+
+
 if __name__ == "__main__":
     tvm.testing.main()