Skip to content

Commit

Permalink
register the dispatch for runtime::Module
Browse files Browse the repository at this point in the history
  • Loading branch information
sunggg committed Mar 20, 2023
1 parent 873ab3d commit 9d64727
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts));
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<runtime::Module>("",
[](runtime::Module rtmod, ObjectPath p, IRDocsifier d) -> Doc {
std::ostringstream oss;
oss << rtmod << ", " << rtmod.get();
return LiteralDoc::Str(String(oss.str()), NullOpt);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<DictAttrs>("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc {
return d->AsDoc(attrs->dict, p->Attr("dict"));
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,5 +529,49 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32
)


def test_runtime_module_in_irmodule_attrs():
@I.ir_module
class TestModule:
@T.prim_func
def tir_func(
x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32")
):
T.evaluate(0)

@R.function
def foo(x: R.Tensor((128,), "float32")) -> R.Tensor((128,), "float32"):
cls = TestModule
gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128,), dtype="float32"))
return gv0

exec = relax.build(TestModule, "llvm")
NewTestModule = TestModule.with_attr("test", exec.mod)
# empty module alias
module_str = NewTestModule.script(module_alias="")
_assert_print(
module_str,
"""
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
I.module_attrs({"test": "Module(type_key= relax.Executable),
""".rstrip()
+ f" {exec.mod.handle.value:#x}".rstrip()
+ """"})
@T.prim_func
def tir_func(x: T.Buffer((T.int64(128),), "float32"), y: T.Buffer((T.int64(128),), "float32")):
T.evaluate(0)
@R.function
def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32"):
gv0 = R.call_tir(Module.tir_func, (x,), out_sinfo=R.Tensor((128,), dtype="float32"))
return gv0
""",
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 9d64727

Please sign in to comment.