Skip to content

Commit

Permalink
[TVMScript] Implicit root block syntax sugar for TVMScript printer (#…
Browse files Browse the repository at this point in the history
…13819)

This PR implements the syntax sugar of implicit root block for new TVMScript printer. This syntax sugar will skip the `T.block("root")`, when the root block realize is simple and we shall reconstruct that root block via `tvm::tir::ScriptComplete` when roundtripping. For example, it will change
```python
@T.prim_func
def root_block_explicitly():
  with T.block("root"):
    a = T.alloc_buffer([128, 128])
    for i, j in T.grid(128, 128):
      with T.block():
        T.evaluate(0)
```
into
```python
@T.prim_func
def main():
  a = T.alloc_buffer((128, 128))
  for i, j in T.grid(128, 128):
    with T.block(""):
      T.evaluate(0)
```
  • Loading branch information
cyx-6 authored Jan 21, 2023
1 parent d907de3 commit ac9fb98
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 13 deletions.
35 changes: 34 additions & 1 deletion src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,40 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
// Step 4. Handle `func->body`
AsDocBody(func->body, p->Attr("body"), frame->get(), d);
Optional<tir::Block> implicit_root_block = [&]() -> Optional<tir::Block> {
const tir::BlockRealizeNode* root_block_realize = func->body.as<tir::BlockRealizeNode>();
if (root_block_realize && !root_block_realize->iter_values.size() &&
tir::is_one(root_block_realize->predicate)) {
tir::Block root_block = root_block_realize->block;
if (!root_block->annotations.size() && !root_block->match_buffers.size() &&
!root_block->reads.size() && !root_block->writes.size() &&
!root_block->init.defined()) {
const tir::BlockRealizeNode* block_realize =
root_block->body.as<tir::BlockRealizeNode>();
if (root_block->alloc_buffers.size() ||
(block_realize && block_realize->block->iter_vars.size()) ||
(!block_realize && tir::ContainsNode<tir::BlockRealizeNode>(root_block->body))) {
return root_block;
}
}
}
return NullOpt;
}();
if (implicit_root_block) {
tir::Block root_block = implicit_root_block.value();
ObjectPath root_block_p = p->Attr("body")->Attr("body");
// Handle root block `alloc_buffer`
for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
tir::Buffer buffer = root_block->alloc_buffers[i];
ObjectPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayIndex(i);
IdDoc lhs = DefineBuffer(buffer, *frame, d);
ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d);
(*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
}
AsDocBody(root_block->body, root_block_p->Attr("body"), frame->get(), d);
} else {
AsDocBody(func->body, p->Attr("body"), frame->get(), d);
}
Optional<ExprDoc> ret_type = NullOpt;
if (func->ret_type.defined()) {
const auto* as_tuple = func->ret_type.as<TupleTypeNode>();
Expand Down
52 changes: 40 additions & 12 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,21 +717,49 @@ def block_with_remap_explicitly():

expected_output = """@T.prim_func
def main():
with T.block("root"):
T.reads()
T.writes()
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
v1, v2 = T.axis.remap("SR", [i1, i2])
v3 = T.axis.spatial(128, i3 - 1)
v4, v5 = T.axis.remap("RS", [i4, i5])
T.reads()
T.writes()
T.evaluate(0)"""
for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
with T.block("update"):
v0 = T.axis.spatial(128, i0 + 1)
v1, v2 = T.axis.remap("SR", [i1, i2])
v3 = T.axis.spatial(128, i3 - 1)
v4, v5 = T.axis.remap("RS", [i4, i5])
T.reads()
T.writes()
T.evaluate(0)"""
_assert_print(block_with_remap_explicitly, expected_output)
_assert_print(block_with_remap_implicitly, expected_output)


def test_root_block():
from tvm.script import tir as T

@T.prim_func
def root_block_implicitly():
a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block():
T.evaluate(0)

@T.prim_func
def root_block_explicitly():
with T.block("root"):
a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block():
T.evaluate(0)

expected_output = """@T.prim_func
def main():
a = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block(""):
T.reads()
T.writes()
T.evaluate(0)
"""
_assert_print(root_block_implicitly, expected_output)
_assert_print(root_block_explicitly, expected_output)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ac9fb98

Please sign in to comment.