Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Support TVMScript template meta-programming over variables #11097

Merged
merged 1 commit into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class ContextMaintainer:
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
closure_vars: Dict[str, Object] = {}
"""ClosureVars: The closure vars defined in Python interpreter"""

# function context
func_params: List[Var] = []
Expand All @@ -144,12 +146,17 @@ class ContextMaintainer:
root_alloc_buffers: List[Buffer] = []
"""List[Buffer]: The buffers allocated under root block"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
def __init__(
self,
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
closure_vars: Dict[str, Object],
):
# scope context
self.node_stack = []
self.block_info_stack = []
self.loop_stack = {}
self.symbols = []
self.closure_vars = closure_vars
# function context
self.func_params = []
self.func_buffer_map = {}
Expand Down Expand Up @@ -233,7 +240,7 @@ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None
return self.closure_vars.get(name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm running into some errors when a closure variable isn't already a PrimExpr. In the example below, this will cause an error message saying that it cannot add together a PrimExpr and an int.

offset = 1

@T.prim_func
def func(A: T.Buffer[(1,), "int32"], B: T.Buffer[(1,), "int32"]):
    B[0] = A[0] + offset

Whenever we pull a variable out of the closure, can we run it through tvm.runtime.convert? That way, any expression type supported by the FFI would be converted to a TIR-supported format.

if name in self.closure_vars:
    return tvm.runtime.convert(self.closure_vars[name])

Copy link
Contributor

@Lunderberg Lunderberg Apr 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it looks like this recommended change also causes some of the passing tests in https://gist.github.com/Lunderberg/dd38f82810e4e06c0834087d4a96bda9 to fail, such as using a meta-parameter as a loop iteration bound. But that's more than I can debug on a Friday afternoon.


def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)
Expand Down
15 changes: 10 additions & 5 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,21 @@ class TVMScriptParser(Transformer):

# pylint gets confused here with synr.Transformer which doesn't have a
# custom init, so just disable it
def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called
def __init__(
self, base_lineno, tir_namespace, closure_vars
): # pylint: disable=super-init-not-called
self.context = None

self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
self.closure_vars = closure_vars
self.meta = None

def init_function_parsing_env(self):
"""Initialize function parsing environment"""
self.context = ContextMaintainer(self.report_error) # scope emitter
self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter

def init_meta(self, meta_dict):
if meta_dict is not None:
Expand Down Expand Up @@ -708,7 +711,7 @@ def transform_For(self, node):
self.context.enter_scope(nodes=node.body.stmts)
# for scope handler process the scope
arg_list = [
tvm.runtime.convert(arg, span=node.rhs.span)
tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
for arg in self.parse_arg_list(func, node.rhs)
]
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
Expand Down Expand Up @@ -1252,12 +1255,14 @@ def from_source(
"""
if isinstance(input_func, str):
tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix))
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
elif inspect.isfunction(input_func):
_, start_line = inspect.getsourcelines(input_func)
env: Dict[str, Any] = input_func.__globals__
namespace = [key for key in env.keys() if env[key] is tir]
parser = TVMScriptParser(start_line, namespace)
_closure_vars = inspect.getclosurevars(input_func)
closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
parser = TVMScriptParser(start_line, namespace, closure_vars)
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
return result
else:
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_tvmscript_meta_programming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.script import tir as T


def matmul_generator(M: int, N: int, K: int, dtype: str):
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [M, K], dtype=dtype)
B = T.match_buffer(b, [N, K], dtype=dtype)
C = T.match_buffer(c, [M, N], dtype=dtype)

for i, j, k in T.grid(M, N, K):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

return matmul


@T.prim_func
def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float16")
B = T.match_buffer(b, [128, 128], dtype="float16")
C = T.match_buffer(c, [128, 128], dtype="float16")

for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


def test_meta_programming_matmul():
f = matmul_generator(128, 128, 128, "float16")
tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)


if __name__ == "__main__":
test_meta_programming_matmul()