diff --git a/python/tvm/script/builder/__init__.py b/python/tvm/script/builder/__init__.py index 999bfb1b6930..087b3955afe7 100644 --- a/python/tvm/script/builder/__init__.py +++ b/python/tvm/script/builder/__init__.py @@ -16,7 +16,5 @@ # under the License. # pylint: disable=unused-import """Namespace for the TVMScript Builder API.""" - - from .builder import Builder, def_, def_many from .frame import Frame, IRModuleFrame diff --git a/python/tvm/script/builder/tir/axis.py b/python/tvm/script/builder/tir/axis.py index 0371aa5bd802..7be7cd42aba2 100644 --- a/python/tvm/script/builder/tir/axis.py +++ b/python/tvm/script/builder/tir/axis.py @@ -36,3 +36,7 @@ def reduce(dom, binding, dtype="int32") -> IterVar: def remap(kinds, bindings, dtype="int32") -> IterVar: return _ffi_api.AxisRemap(kinds, bindings, dtype) # pylint: disable=no-member # type: ignore + + +S = spatial +R = reduce diff --git a/python/tvm/script/builder/tir/prim_func_frame.py b/python/tvm/script/builder/tir/prim_func_frame.py index 370fe361552d..525be3b66c2c 100644 --- a/python/tvm/script/builder/tir/prim_func_frame.py +++ b/python/tvm/script/builder/tir/prim_func_frame.py @@ -21,6 +21,7 @@ from tvm.tir.buffer import Buffer from tvm.tir.expr import Var +from ..builder import Builder from . import _ffi_api from .base import TIRFrame @@ -36,3 +37,6 @@ def prim_func(name) -> PrimFuncFrame: def arg(name, obj) -> Union[Var, Buffer]: return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore + + +setattr(prim_func, "dispatch_token", "tir") diff --git a/python/tvm/script/parse/__init__.py b/python/tvm/script/parse/__init__.py new file mode 100644 index 000000000000..0b7f8285205c --- /dev/null +++ b/python/tvm/script/parse/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the Licens. +"""The parser""" +from . import dispatch, parser, tir +from .entry import parse diff --git a/python/tvm/script/parse/dispatch.py b/python/tvm/script/parse/dispatch.py new file mode 100644 index 000000000000..ee38d3878f57 --- /dev/null +++ b/python/tvm/script/parse/dispatch.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The dispatcher""" + +import ast +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple + +if TYPE_CHECKING: + from .parser import Parser + + +ParseMethod = Callable[ + ["Parser", ast.AST], + None, +] + + +class DispatchTable: + """Dispatch table for parse methods""" + + _instance: Optional["DispatchTable"] = None + table: Dict[Tuple[str, str], ParseMethod] + + def __init__(self): + self.table = {} + + +DispatchTable._instance = DispatchTable() # pylint: disable=protected-access + + +def register( + token: str, + type_name: str, +): + """Register a method for a dispatch token and type name""" + + def f(method: ParseMethod): + DispatchTable._instance.table[ # pylint: disable=protected-access + (token, type_name) + ] = method + + return f + + +def get( + token: str, + type_name: str, + default: Optional[ParseMethod] = None, +) -> Optional[ParseMethod]: + return DispatchTable._instance.table.get( # pylint: disable=protected-access + (token, type_name), + default, + ) diff --git a/python/tvm/script/parse/entry.py b/python/tvm/script/parse/entry.py new file mode 100644 index 000000000000..6239487327a0 --- /dev/null +++ b/python/tvm/script/parse/entry.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The entry point of TVM parser.""" +import ast +import inspect +from typing import Any, Dict, Optional, Union + +from ..builder import Builder +from .parser import Parser + + +class SourceCode: + source_name: str + start_line: int + start_column: int + source: str + full_source: str + + def __init__(self, program: Union[str, ast.AST]): + if isinstance(program, str): + self.source_name = "" + self.start_line = 1 + self.start_column = 0 + self.source = program + self.full_source = program + else: + self.source_name = inspect.getsourcefile(program) # type: ignore + lines, self.start_line = inspect.getsourcelines(program) # type: ignore + + if lines: + self.start_column = len(lines[0]) - len(lines[0].lstrip()) + else: + self.start_column = 0 + if self.start_column and lines: + self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) + else: + self.source = "" + try: + # It will cause a problem when running in Jupyter Notebook. + # `mod` will be , which is a built-in module + # and `getsource` will throw a TypeError + mod = inspect.getmodule(program) + if mod: + self.full_source = inspect.getsource(mod) + else: + self.full_source = self.source + except TypeError: + # It's a work around for Jupyter problem. + # Since `findsource` is an internal API of inspect, we just use it + # as a fallback method. + src, _ = inspect.findsource(program) # type: ignore + self.full_source = "".join(src) + + def as_ast(self) -> ast.AST: + return ast.parse(self.source) + + +def parse( + program: Union[ast.AST, Any, str], + extra_vars: Optional[Dict[str, Any]] = None, +): + program_ast = SourceCode(program).as_ast() + parser = Parser() + with Builder() as builder: + with parser.var_table.with_frame(): + if extra_vars: + for k, v in extra_vars.items(): + parser.var_table.add(k, v) + parser.visit(program_ast) + return builder.get() diff --git a/python/tvm/script/parse/evaluator.py b/python/tvm/script/parse/evaluator.py new file mode 100644 index 000000000000..e4f62b7a81c6 --- /dev/null +++ b/python/tvm/script/parse/evaluator.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AST Evaluation""" +import ast +from typing import Any, Dict, Optional, Union + + +def eval_expr( + node: Union[ast.expr, ast.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + if isinstance(node, ast.expr): + node = ast.Expression(body=node) + assert isinstance(node, ast.Expression) + if dict_globals is None: + dict_globals = {} + node = ast.fix_missing_locations(node) + exe = compile(node, filename="", mode="eval") + return eval(exe, dict_globals) # pylint: disable=eval-used + + +def eval_assign( + target: ast.expr, + source: Any, +) -> Dict[str, Any]: + assert isinstance(target, ast.expr) + RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name + rhs_var_name = RHS_VAR_NAME + dict_locals = {rhs_var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name( + id=rhs_var_name, + ctx=ast.Load(), + ), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, dict_locals) # pylint: disable=exec-used + del dict_locals[rhs_var_name] + return dict_locals diff --git a/python/tvm/script/parse/parser.py b/python/tvm/script/parse/parser.py new file mode 100644 index 000000000000..6101a252acef --- /dev/null +++ b/python/tvm/script/parse/parser.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The core parser""" +import ast +from typing import Any, Dict, List, Optional, Union + +from ..builder import def_ +from . import dispatch +from .evaluator import eval_assign, eval_expr +from .utils import deferred +from .var_table import VarTable + + +def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return func + return lambda self, node: self.generic_visit(node) + + +def _handle_function(self: "Parser", node: ast.FunctionDef) -> None: + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if hasattr(decorator, "dispatch_token"): + token = decorator.dispatch_token + func = dispatch.get(token=token, type_name="FunctionDef", default=None) + if func is not None: + func(self, node) + return + self.report_error(node, "The parser does not understand the decorator") + + +class Parser(ast.NodeVisitor): + """The TVMScript parser""" + + dispatch_tokens: List[str] + var_table: VarTable + + def __init__(self) -> None: + self.dispatch_tokens = ["default"] + self.var_table = VarTable() + + def with_dispatch_token(self, token: str): + def pop_token(): + self.dispatch_tokens.pop() + + self.dispatch_tokens.append(token) + return deferred(pop_token) + + def eval_expr( + self, + node: Union[ast.Expression, ast.expr], + extra_vars: Optional[Dict[str, Any]] = None, + ) -> Any: + var_values = self.var_table.get() + if extra_vars is not None: + for k, v in extra_vars.items(): + var_values[k] = v + return eval_expr(node, var_values) + + def eval_assign( + self, + target: ast.expr, + source: Any, + ) -> Dict[str, Any]: + var_values = eval_assign(target, source) + for k, v in var_values.items(): + def_(k, v) + self.var_table.add(k, v) + return var_values + + def report_error(self, node: ast.AST, msg: str) -> None: # pylint: disable=no-self-use + raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}") + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name + _handle_function(self, node) + + def visit_body(self, node: List[ast.stmt]) -> Any: + for stmt in node: + self.visit(stmt) + + def visit_arguments(self, node: ast.arguments) -> Any: + _dispatch(self, "arguments")(self, node) + + def visit_For(self, node: ast.For) -> Any: # pylint: disable=invalid-name + _dispatch(self, "For")(self, node) + + def visit_With(self, node: ast.With) -> Any: # pylint: disable=invalid-name + _dispatch(self, "With")(self, node) + + def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name + _dispatch(self, "Assign")(self, node) diff --git a/python/tvm/script/parse/tir/__init__.py b/python/tvm/script/parse/tir/__init__.py new file mode 100644 index 000000000000..ea2e43c60248 --- /dev/null +++ b/python/tvm/script/parse/tir/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from . import tir diff --git a/python/tvm/script/parse/tir/tir.py b/python/tvm/script/parse/tir/tir.py new file mode 100644 index 000000000000..c219e24076c0 --- /dev/null +++ b/python/tvm/script/parse/tir/tir.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ast +import contextlib + +from ...builder import Frame +from ...builder import tir as T +from .. import dispatch +from ..parser import Parser + + +@dispatch.register(token="tir", type_name="For") +def visit_for(self: Parser, node: ast.For) -> None: + for_frame = self.eval_expr(node.iter) + if not isinstance(for_frame, T.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with for_frame as iters: + self.eval_assign(target=node.target, source=iters) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="Assign") +def visit_assign(self: Parser, node: ast.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign(target=lhs, source=rhs) + + +@dispatch.register(token="tir", type_name="With") +def visit_with(self: Parser, node: ast.With) -> None: + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, Frame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." + ) + rhs = stack.enter_context(frame) + if item.optional_vars is not None: + self.eval_assign( + target=item.optional_vars, + source=rhs, + ) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="FunctionDef") +def visit_function_def(self: Parser, node: ast.FunctionDef) -> None: + with self.var_table.with_frame(): + self.var_table.add("range", T.serial) + with T.prim_func(node.name): + with self.with_dispatch_token("tir"): + # TODO: define the GlobalVar, handle the return value + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="arguments") +def visit_arguments(self: Parser, node: ast.arguments) -> None: + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param = T.arg(arg.arg, self.eval_expr(arg.annotation)) + self.var_table.add(arg.arg, param) diff --git a/python/tvm/script/parse/utils.py b/python/tvm/script/parse/utils.py new file mode 100644 index 000000000000..c839f58f5433 --- /dev/null +++ b/python/tvm/script/parse/utils.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from contextlib import contextmanager +from typing import Callable + + +def deferred(f: Callable[[], None]): + @contextmanager + def context(): + try: + yield + finally: + f() + + return context() diff --git a/python/tvm/script/parse/var_table.py b/python/tvm/script/parse/var_table.py new file mode 100644 index 000000000000..cfa271975d4f --- /dev/null +++ b/python/tvm/script/parse/var_table.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The symbol table of variable values""" + +from collections import defaultdict +from typing import Any, Callable, Dict, List, Set + +from .utils import deferred + + +class VarTableFrame: + vars: Set[str] + + def __init__(self): + self.vars = set() + + def add(self, var: str): + if var in self.vars: + raise ValueError(f"Variable {var} already defined in current scope") + self.vars.add(var) + + def pop_all(self, fn_pop: Callable[[str], None]): + for var in self.vars: + fn_pop(var) + self.vars.clear() + + +class VarTable: + + frames: List[VarTableFrame] + name2value: Dict[str, List[Any]] + + def __init__(self): + self.frames = [] + self.name2value = defaultdict(list) + + def with_frame(self): + def pop_frame(): + frame = self.frames.pop() + frame.pop_all(lambda name: self.name2value[name].pop()) + + self.frames.append(VarTableFrame()) + return deferred(pop_frame) + + def add(self, var: str, value: Any): + self.frames[-1].add(var) + self.name2value[var].append(value) + + def get(self) -> Dict[str, Any]: + return {key: values[-1] for key, values in self.name2value.items() if values} diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index b9c5b9848608..c2d4312e9978 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -18,6 +18,7 @@ */ #include "./builder.h" +#include #include namespace tvm { @@ -71,6 +72,17 @@ void Namer::Name(ObjectRef node, String name) { f(node, name); } +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using namespace tvm::runtime; + ArrayNode* array = const_cast(node.as()); + ICHECK(array); + int n = array->size(); + for (int i = 0; i < n; ++i) { + Namer::Name(array->at(i), name + std::to_string(i)); + } + }); + namespace details { ObjectRef DefImpl(String name, ObjectRef obj) { diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py new file mode 100644 index 000000000000..2dac332feccc --- /dev/null +++ b/tests/python/tvmscript/test_parse_basic.py @@ -0,0 +1,25 @@ +from tvm.script.builder import tir as T +from tvm.script.parse import parse + +elementwise = """ +@T.prim_func +def elementwise( + A: T.Buffer(shape=(128, 128, 128), dtype="float32"), + B: T.Buffer(shape=(128, 128, 128), dtype="float32"), +) -> None: + for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128): + with T.block("inner_block"): + # vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + vi = T.axis.S(128, i + 1) + vj = T.axis.S(128, j + 20) + vk = T.axis.R(128, k - i) +""" + + +def main(): + result = parse(elementwise, extra_vars={"T": T}) + print(result.script()) + + +if __name__ == "__main__": + main()