From 11c13ace0b5cef71f50193248ecaac7e845ee25e Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 8 Feb 2023 22:31:47 +0800 Subject: [PATCH] [TVMScript] IRModule TVMScript Parser. This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. This commit contains the non-relax portions from https://github.com/apache/tvm/pull/13932. Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Yuchen Jin Co-authored-by: Steven S. Lyubomirsky Co-authored-by: Yong Wu --- include/tvm/script/ir_builder/ir/frame.h | 11 +++-- include/tvm/script/ir_builder/ir/ir.h | 17 +++++++ python/tvm/script/ir_builder/base.py | 6 ++- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 ++++++++++++++++++ python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 ++++++++++++++------ python/tvm/script/parser/ir/parser.py | 4 ++ python/tvm/script/parser/tir/entry.py | 4 +- python/tvm/script/parser/tir/parser.py | 26 ++++++++++ src/script/ir_builder/ir/frame.cc | 12 +++-- src/script/ir_builder/ir/ir.cc | 32 ++++++++++++- src/script/ir_builder/ir/utils.h | 49 +++++++++++++++++++ src/script/ir_builder/tir/frame.cc | 15 ++++-- src/script/ir_builder/tir/utils.h | 2 +- 16 files changed, 245 insertions(+), 34 deletions(-) create mode 100644 src/script/ir_builder/ir/utils.h diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccffc8..dacfc361a6c7 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c2f..49bdcf60e6fb 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \param func_signature A Function w/o body, which used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c72..b35bbd0a7df5 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737ad..946be263a779 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463cb2..796d6f3aad04 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + + func_signature: Optional[BaseFunc] + A Function w/o body, which used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae5034780..2767a97f6096 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c33106..075aedd89146 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ def _visit(self, node: doc.AST) -> Any: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index fdccabcd235d..837b7cce5d5e 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -260,6 +264,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -389,6 +404,8 @@ def report_error( # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -458,30 +475,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d284..13b3e298590f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c83..649f817411f0 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ def __getitem__(self, keys) -> Buffer: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..63171f672289 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922dff..addf12928435 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f0c..5764e90c8dd4 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,40 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 000000000000..58d5e53f7032 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40dd..dd8d3c2ed3f3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1fe..f3b547532cfd 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (Optional frame = IRBuilder::Current()->FindFrame()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). "