diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 972e5845fcb9..f7f16855c752 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -121,6 +121,8 @@ class ContextMaintainer: """Dict[Var, Range]: The dict from loop var to its domain outside the block""" symbols: List[Dict[str, Union[Var, Buffer]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" + closure_vars: Dict[str, Object] = {} + """ClosureVars: The closure vars defined in Python interpreter""" # function context func_params: List[Var] = [] @@ -144,12 +146,17 @@ class ContextMaintainer: root_alloc_buffers: List[Buffer] = [] """List[Buffer]: The buffers allocated under root block""" - def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): + def __init__( + self, + _report_error: Callable[[str, Union[Span, synr.ast.Span]], None], + closure_vars: Dict[str, Object], + ): # scope context self.node_stack = [] self.block_info_stack = [] self.loop_stack = {} self.symbols = [] + self.closure_vars = closure_vars # function context self.func_params = [] self.func_buffer_map = {} @@ -233,7 +240,7 @@ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: for symbols in reversed(self.symbols): if name in symbols: return symbols[name] - return None + return self.closure_vars.get(name) def report_error(self, message: str, span: Union[Span, synr.ast.Span]): self._report_error(message, span) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index b01ad383c36d..13b283bc0c40 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -158,18 +158,21 @@ class TVMScriptParser(Transformer): # pylint gets confused here with synr.Transformer which doesn't have a # custom init, so just disable it - def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called + def __init__( + self, base_lineno, tir_namespace, closure_vars + ): # pylint: disable=super-init-not-called self.context = None self.base_lineno = base_lineno self.current_lineno = 0 self.current_col_offset = 0 self.tir_namespace = tir_namespace + self.closure_vars = closure_vars self.meta = None def init_function_parsing_env(self): """Initialize function parsing environment""" - self.context = ContextMaintainer(self.report_error) # scope emitter + self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter def init_meta(self, meta_dict): if meta_dict is not None: @@ -709,7 +712,7 @@ def transform_For(self, node): self.context.enter_scope(nodes=node.body.stmts) # for scope handler process the scope arg_list = [ - tvm.runtime.convert(arg, span=node.rhs.span) + tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span)) for arg in self.parse_arg_list(func, node.rhs) ] func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) @@ -1253,12 +1256,14 @@ def from_source( """ if isinstance(input_func, str): tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix - return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix)) + return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {})) elif inspect.isfunction(input_func): _, start_line = inspect.getsourcelines(input_func) env: Dict[str, Any] = input_func.__globals__ namespace = [key for key in env.keys() if env[key] is tir] - parser = TVMScriptParser(start_line, namespace) + _closure_vars = inspect.getclosurevars(input_func) + closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals} + parser = TVMScriptParser(start_line, namespace, closure_vars) result = to_ast(input_func, TVMDiagnosticCtx(), parser) return result else: diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py new file mode 100644 index 000000000000..2473c0c84564 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script import tir as T + + +def matmul_generator(M: int, N: int, K: int, dtype: str): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return matmul + + +@T.prim_func +def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float16") + B = T.match_buffer(b, [128, 128], dtype="float16") + C = T.match_buffer(c, [128, 128], dtype="float16") + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_meta_programming_matmul(): + f = matmul_generator(128, 128, 128, "float16") + tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16) + + +if __name__ == "__main__": + test_meta_programming_matmul()