diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7eb..a00ea5768e23 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index dacfc361a6c7..ed425cf61441 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode { * \note Only defined functions are in the map, while declared functions are not included. */ Map functions; + /*! \brief IRModule's attributes. */ + Map attrs; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..232c70aa93d8 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +60,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index b35bbd0a7df5..1d5d050444f7 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -138,6 +138,17 @@ def current() -> "IRBuilder": """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 946be263a779..b796de8113f3 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,9 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import decl_function, def_function, ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 796d6f3aad04..c5276f8d136e 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,6 +16,10 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict + +from tvm.runtime import Object as tvm_Object + from tvm.ir import BaseFunc, GlobalVar from . import _ffi_api @@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None: The given function implementation """ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14a8..adda17601206 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 13b3e298590f..201c99074f20 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) for stmt in node.body: if isinstance(stmt, doc.FunctionDef): self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/src/ir/module.cc b/src/ir/module.cc index 4d5bebf70894..ba66a6689422 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f20..879db4f3d713 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index addf12928435..92470ec65342 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 5764e90c8dd4..0c34f85246c9 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) { } } +void ModuleAttrs(Map attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); } // namespace ir } // namespace ir_builder diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168ad..1c751d40f2e7 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc(mod->attrs, p->Attr("attrs"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bbc6dd45a83e..52d99550be92 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3725,6 +3725,19 @@ def tir_packed_call(A: T.Buffer(16)): return tvm.tir.transform.LowerTVMBuiltin()(Module) +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in range(16): + B[i] = A[i] + + return Module + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3791,6 +3804,7 @@ def tir_packed_call(A: T.Buffer(16)): if_then_else_var, tvm_shfl_builtins, tvm_struct_set_generated_in_cpp, + ir_module_with_attrs, )