Skip to content

Commit

Permalink
[TVMScript] Expose IRModule::attrs as I.module_attrs
Browse files Browse the repository at this point in the history
This is an upstreaming of the non-relax portions of
apache#14132, including a unit test
specically to validate `I.module_attrs`.
  • Loading branch information
jinhongyii authored and tqchen committed Apr 6, 2023
1 parent 11c13ac commit ff5118f
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
* \sa tvm::support::With
*/
static IRBuilder Current();
/*! \brief See if the current thread-local scope has an IRBuilder. */
static bool IsInScope();
/*!
* \brief Give a string name to the `obj`
* \tparam TObjectRef The type of the object to name.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/script/ir_builder/ir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
* \note Only defined functions are in the map, while declared functions are not included.
*/
Map<GlobalVar, BaseFunc> functions;
/*! \brief IRModule's attributes. */
Map<String, ObjectRef> attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
v->Visit("attrs", &attrs);
}

static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

attrs = None if not attrs else attrs
if attrs is not None:
attrs = ast.literal_eval(str(attrs))
attrs = tvm.ir.make_node("DictAttrs", **attrs)
self.__init_handle_by_constructor__(
_ffi_api.IRModule,
functions,
type_definitions,
attrs,
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/script/ir_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ def current() -> "IRBuilder":
"""
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member

@staticmethod
def is_in_scope() -> bool:
"""See if the current thread-local scope has an IRBuilder.
Returns
-------
bool
Whether the current thread-local scope has an IRBuilder
"""
return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member

def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/script/ir_builder/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
from .ir import decl_function, def_function, ir_module
from .ir import (
decl_function,
def_function,
ir_module,
module_attrs,
)
14 changes: 14 additions & 0 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""

from typing import Dict

from tvm.runtime import Object as tvm_Object

from tvm.ir import BaseFunc, GlobalVar

from . import _ffi_api
Expand Down Expand Up @@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None:
The given function implementation
"""
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member


def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
"""Specify the attrs of the ir_module frame.
Parameters
----------
attrs: Dict[str, Object]
The module attrs.
"""
return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member
4 changes: 2 additions & 2 deletions python/tvm/script/parser/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""The ir module parser"""

from ...ir_builder.ir import * # pylint: disable=redefined-builtin
from . import parser as _parser
from .entry import ir_module

__all__ = ["ir_module"]
__all__ = ["ir_module", "module_attrs"]
11 changes: 9 additions & 2 deletions python/tvm/script/parser/ir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:

with self.var_table.with_frame():
with I.ir_module():
with self.with_dispatch_token("ir"):
for stmt in node.body:
if not isinstance(stmt, doc.FunctionDef):
self.visit(stmt)
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
self.visit_body(node.body)
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit(stmt)


@dispatch.register(token="ir", type_name="Assign")
Expand All @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:


@dispatch.register(token="ir", type_name="Expr")
def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
def _visit_expr(self: Parser, node: doc.Expr) -> None:
"""The expression visiting method for ir module.
Parameters
Expand All @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
node : doc.ClassDef
The doc AST expression node.
"""
self.eval_expr(node.value)


@dispatch.register(token="default", type_name="Assign")
Expand Down
6 changes: 2 additions & 4 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
TVM_REGISTER_NODE_TYPE(IRModuleNode);

TVM_REGISTER_GLOBAL("ir.IRModule")
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types) {
return IRModule(funcs, types, {});
});
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });

TVM_REGISTER_GLOBAL("ir.Module_Add")
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
Expand Down
6 changes: 6 additions & 0 deletions src/script/ir_builder/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
return stack->back();
}

bool IRBuilder::IsInScope() {
std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
return !stack->empty();
}

namespace details {

Namer::FType& Namer::vtable() {
Expand Down Expand Up @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
.set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
Expand Down
3 changes: 2 additions & 1 deletion src/script/ir_builder/ir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() {
}
IRBuilder builder = IRBuilder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = tvm::IRModule(func_map);
auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs);
}

TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
Expand Down
12 changes: 12 additions & 0 deletions src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) {
}
}

void ModuleAttrs(Map<String, ObjectRef> attrs) {
if (IRBuilder::IsInScope()) {
// TODO(hongyi): add comments to explain why we need to check if the module frame is in scope
IRModuleFrame frame = FindModuleFrame("I.ModuleAttr");
if (!frame->attrs.empty()) {
LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
}
frame->attrs = attrs;
}
}

TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);

} // namespace ir
} // namespace ir_builder
Expand Down
5 changes: 5 additions & 0 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
(*f)->stmts.push_back(
ExprStmtDoc(IR(d, "module_attrs") //
->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
}
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3725,6 +3725,19 @@ def tir_packed_call(A: T.Buffer(16)):
return tvm.tir.transform.LowerTVMBuiltin()(Module)


def ir_module_with_attrs():
@I.ir_module
class Module:
I.module_attrs({"attr": 10})

@T.prim_func
def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")):
for i in range(16):
B[i] = A[i]

return Module


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -3791,6 +3804,7 @@ def tir_packed_call(A: T.Buffer(16)):
if_then_else_var,
tvm_shfl_builtins,
tvm_struct_set_generated_in_cpp,
ir_module_with_attrs,
)


Expand Down

0 comments on commit ff5118f

Please sign in to comment.