Skip to content

Commit

Permalink
[TIR] [Analysis] Expose IsOutputBlock to python (#14352)
Browse files Browse the repository at this point in the history
This patch just exposes an existing analysis API IsOutputBlock to
python. Since many schedule primitives have conditions on output blocks,
this API would be really useful while scheduling
  • Loading branch information
quic-sanirudh authored Mar 21, 2023
1 parent 4819300 commit 50b3ae4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,22 @@ def has_block(sch: Schedule, block_name: str) -> bool:
True if the given block exists in the schedule.
"""
return _ffi_api.HasBlock(sch, block_name) # type: ignore


def is_output_block(sch: Schedule, block: BlockRV) -> bool:
"""Check whether the given block is an output block
Parameters
----------
sch : Schedule
The schedule object of the block
block : BlockRV
The blockRV to be checked
Returns
-------
yes/no : bool
True if the given block is an output block
"""
return _ffi_api.IsOutputBlock(sch, block) # type: ignore
5 changes: 5 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo")
});

TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock);
TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch, BlockRV block) {
auto state = sch->state();
auto block_sref = sch->GetSRef(block);
return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false));
});

} // namespace tir
} // namespace tvm
21 changes: 21 additions & 0 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from tvm.tir.stmt_functor import pre_order_visit
from tvm.meta_schedule.testing import te_workload
from tvm.te import create_prim_func
from tvm.tir.schedule.analysis import is_output_block


def _make_vars(*args: str) -> List[Var]:
Expand Down Expand Up @@ -396,5 +397,25 @@ def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)


def test_is_output_block():
@T.prim_func
def two_elementwise(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128), "float32")
B = T.alloc_buffer((128, 128), "float32")
C = T.match_buffer(c, (128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

sch = tvm.tir.Schedule(two_elementwise)
block_rv = sch.get_block("C")
assert is_output_block(sch, block_rv)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 50b3ae4

Please sign in to comment.