Skip to content

Commit

Permalink
[TVMScript] Support TVMScript template meta-programming over variables (
Browse files Browse the repository at this point in the history
#11097)

This PR supports a simple meta-programming paradigm for TVMScript, which allows users to get access to var definition in the Python environment.
  • Loading branch information
Hzfengsy authored Apr 27, 2022
1 parent c2803f6 commit c09a24d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
11 changes: 9 additions & 2 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class ContextMaintainer:
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
closure_vars: Dict[str, Object] = {}
"""ClosureVars: The closure vars defined in Python interpreter"""

# function context
func_params: List[Var] = []
Expand All @@ -144,12 +146,17 @@ class ContextMaintainer:
root_alloc_buffers: List[Buffer] = []
"""List[Buffer]: The buffers allocated under root block"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
def __init__(
self,
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
closure_vars: Dict[str, Object],
):
# scope context
self.node_stack = []
self.block_info_stack = []
self.loop_stack = {}
self.symbols = []
self.closure_vars = closure_vars
# function context
self.func_params = []
self.func_buffer_map = {}
Expand Down Expand Up @@ -233,7 +240,7 @@ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None
return self.closure_vars.get(name)

def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,21 @@ class TVMScriptParser(Transformer):

# pylint gets confused here with synr.Transformer which doesn't have a
# custom init, so just disable it
def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called
def __init__(
self, base_lineno, tir_namespace, closure_vars
): # pylint: disable=super-init-not-called
self.context = None

self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
self.closure_vars = closure_vars
self.meta = None

def init_function_parsing_env(self):
"""Initialize function parsing environment"""
self.context = ContextMaintainer(self.report_error) # scope emitter
self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter

def init_meta(self, meta_dict):
if meta_dict is not None:
Expand Down Expand Up @@ -709,7 +712,7 @@ def transform_For(self, node):
self.context.enter_scope(nodes=node.body.stmts)
# for scope handler process the scope
arg_list = [
tvm.runtime.convert(arg, span=node.rhs.span)
tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
for arg in self.parse_arg_list(func, node.rhs)
]
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
Expand Down Expand Up @@ -1253,12 +1256,14 @@ def from_source(
"""
if isinstance(input_func, str):
tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix))
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
elif inspect.isfunction(input_func):
_, start_line = inspect.getsourcelines(input_func)
env: Dict[str, Any] = input_func.__globals__
namespace = [key for key in env.keys() if env[key] is tir]
parser = TVMScriptParser(start_line, namespace)
_closure_vars = inspect.getclosurevars(input_func)
closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
parser = TVMScriptParser(start_line, namespace, closure_vars)
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
return result
else:
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_tvmscript_meta_programming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.script import tir as T


def matmul_generator(M: int, N: int, K: int, dtype: str):
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [M, K], dtype=dtype)
B = T.match_buffer(b, [N, K], dtype=dtype)
C = T.match_buffer(c, [M, N], dtype=dtype)

for i, j, k in T.grid(M, N, K):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

return matmul


@T.prim_func
def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float16")
B = T.match_buffer(b, [128, 128], dtype="float16")
C = T.match_buffer(c, [128, 128], dtype="float16")

for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


def test_meta_programming_matmul():
f = matmul_generator(128, 128, 128, "float16")
tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)


if __name__ == "__main__":
test_meta_programming_matmul()

0 comments on commit c09a24d

Please sign in to comment.