forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
589 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the Licens. | ||
"""The parser""" | ||
from . import dispatch, parser, tir | ||
from .entry import parse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""The dispatcher""" | ||
|
||
import ast | ||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple | ||
|
||
if TYPE_CHECKING: | ||
from .parser import Parser | ||
|
||
|
||
ParseMethod = Callable[ | ||
["Parser", ast.AST], | ||
None, | ||
] | ||
|
||
|
||
class DispatchTable: | ||
"""Dispatch table for parse methods""" | ||
|
||
_instance: Optional["DispatchTable"] = None | ||
table: Dict[Tuple[str, str], ParseMethod] | ||
|
||
def __init__(self): | ||
self.table = {} | ||
|
||
|
||
DispatchTable._instance = DispatchTable() # pylint: disable=protected-access | ||
|
||
|
||
def register( | ||
token: str, | ||
type_name: str, | ||
): | ||
"""Register a method for a dispatch token and type name""" | ||
|
||
def f(method: ParseMethod): | ||
DispatchTable._instance.table[ # pylint: disable=protected-access | ||
(token, type_name) | ||
] = method | ||
|
||
return f | ||
|
||
|
||
def get( | ||
token: str, | ||
type_name: str, | ||
default: Optional[ParseMethod] = None, | ||
) -> Optional[ParseMethod]: | ||
return DispatchTable._instance.table.get( # pylint: disable=protected-access | ||
(token, type_name), | ||
default, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""The entry point of TVM parser.""" | ||
import ast | ||
import inspect | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from ..builder import Builder | ||
from .parser import Parser | ||
|
||
|
||
class SourceCode: | ||
source_name: str | ||
start_line: int | ||
start_column: int | ||
source: str | ||
full_source: str | ||
|
||
def __init__(self, program: Union[str, ast.AST]): | ||
if isinstance(program, str): | ||
self.source_name = "<str>" | ||
self.start_line = 1 | ||
self.start_column = 0 | ||
self.source = program | ||
self.full_source = program | ||
else: | ||
self.source_name = inspect.getsourcefile(program) # type: ignore | ||
lines, self.start_line = inspect.getsourcelines(program) # type: ignore | ||
|
||
if lines: | ||
self.start_column = len(lines[0]) - len(lines[0].lstrip()) | ||
else: | ||
self.start_column = 0 | ||
if self.start_column and lines: | ||
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) | ||
else: | ||
self.source = "" | ||
try: | ||
# It will cause a problem when running in Jupyter Notebook. | ||
# `mod` will be <module '__main__'>, which is a built-in module | ||
# and `getsource` will throw a TypeError | ||
mod = inspect.getmodule(program) | ||
if mod: | ||
self.full_source = inspect.getsource(mod) | ||
else: | ||
self.full_source = self.source | ||
except TypeError: | ||
# It's a work around for Jupyter problem. | ||
# Since `findsource` is an internal API of inspect, we just use it | ||
# as a fallback method. | ||
src, _ = inspect.findsource(program) # type: ignore | ||
self.full_source = "".join(src) | ||
|
||
def as_ast(self) -> ast.AST: | ||
return ast.parse(self.source) | ||
|
||
|
||
def parse( | ||
program: Union[ast.AST, Any, str], | ||
extra_vars: Optional[Dict[str, Any]] = None, | ||
): | ||
program_ast = SourceCode(program).as_ast() | ||
parser = Parser() | ||
with Builder() as builder: | ||
with parser.var_table.with_frame(): | ||
if extra_vars: | ||
for k, v in extra_vars.items(): | ||
parser.var_table.add(k, v) | ||
parser.visit(program_ast) | ||
return builder.get() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""AST Evaluation""" | ||
import ast | ||
from typing import Any, Dict, Optional, Union | ||
|
||
|
||
def eval_expr( | ||
node: Union[ast.expr, ast.Expression], | ||
dict_globals: Optional[Dict[str, Any]], | ||
) -> Any: | ||
if isinstance(node, ast.expr): | ||
node = ast.Expression(body=node) | ||
assert isinstance(node, ast.Expression) | ||
if dict_globals is None: | ||
dict_globals = {} | ||
node = ast.fix_missing_locations(node) | ||
exe = compile(node, filename="<ast>", mode="eval") | ||
return eval(exe, dict_globals) # pylint: disable=eval-used | ||
|
||
|
||
def eval_assign( | ||
target: ast.expr, | ||
source: Any, | ||
) -> Dict[str, Any]: | ||
assert isinstance(target, ast.expr) | ||
RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name | ||
rhs_var_name = RHS_VAR_NAME | ||
dict_locals = {rhs_var_name: source} | ||
mod = ast.fix_missing_locations( | ||
ast.Module( | ||
body=[ | ||
ast.Assign( | ||
targets=[target], | ||
value=ast.Name( | ||
id=rhs_var_name, | ||
ctx=ast.Load(), | ||
), | ||
) | ||
], | ||
type_ignores=[], | ||
) | ||
) | ||
exe = compile(mod, filename="<ast>", mode="exec") | ||
exec(exe, {}, dict_locals) # pylint: disable=exec-used | ||
del dict_locals[rhs_var_name] | ||
return dict_locals |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""The core parser""" | ||
import ast | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from ..builder import def_ | ||
from . import dispatch | ||
from .evaluator import eval_assign, eval_expr | ||
from .utils import deferred | ||
from .var_table import VarTable | ||
|
||
|
||
def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: | ||
for token in [self.dispatch_tokens[-1], "default"]: | ||
func = dispatch.get(token=token, type_name=type_name, default=None) | ||
if func is not None: | ||
return func | ||
return lambda self, node: self.generic_visit(node) | ||
|
||
|
||
def _handle_function(self: "Parser", node: ast.FunctionDef) -> None: | ||
if not node.decorator_list: | ||
self.report_error(node, "Function must be decorated") | ||
# TODO: only the last decorator is parsed | ||
decorator = self.eval_expr(node.decorator_list[-1]) | ||
if hasattr(decorator, "dispatch_token"): | ||
token = decorator.dispatch_token | ||
func = dispatch.get(token=token, type_name="FunctionDef", default=None) | ||
if func is not None: | ||
func(self, node) | ||
return | ||
self.report_error(node, "The parser does not understand the decorator") | ||
|
||
|
||
class Parser(ast.NodeVisitor): | ||
"""The TVMScript parser""" | ||
|
||
dispatch_tokens: List[str] | ||
var_table: VarTable | ||
|
||
def __init__(self) -> None: | ||
self.dispatch_tokens = ["default"] | ||
self.var_table = VarTable() | ||
|
||
def with_dispatch_token(self, token: str): | ||
def pop_token(): | ||
self.dispatch_tokens.pop() | ||
|
||
self.dispatch_tokens.append(token) | ||
return deferred(pop_token) | ||
|
||
def eval_expr( | ||
self, | ||
node: Union[ast.Expression, ast.expr], | ||
extra_vars: Optional[Dict[str, Any]] = None, | ||
) -> Any: | ||
var_values = self.var_table.get() | ||
if extra_vars is not None: | ||
for k, v in extra_vars.items(): | ||
var_values[k] = v | ||
return eval_expr(node, var_values) | ||
|
||
def eval_assign( | ||
self, | ||
target: ast.expr, | ||
source: Any, | ||
) -> Dict[str, Any]: | ||
var_values = eval_assign(target, source) | ||
for k, v in var_values.items(): | ||
def_(k, v) | ||
self.var_table.add(k, v) | ||
return var_values | ||
|
||
def report_error(self, node: ast.AST, msg: str) -> None: # pylint: disable=no-self-use | ||
raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}") | ||
|
||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name | ||
_handle_function(self, node) | ||
|
||
def visit_body(self, node: List[ast.stmt]) -> Any: | ||
for stmt in node: | ||
self.visit(stmt) | ||
|
||
def visit_arguments(self, node: ast.arguments) -> Any: | ||
_dispatch(self, "arguments")(self, node) | ||
|
||
def visit_For(self, node: ast.For) -> Any: # pylint: disable=invalid-name | ||
_dispatch(self, "For")(self, node) | ||
|
||
def visit_With(self, node: ast.With) -> Any: # pylint: disable=invalid-name | ||
_dispatch(self, "With")(self, node) | ||
|
||
def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name | ||
_dispatch(self, "Assign")(self, node) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
from . import tir |
Oops, something went wrong.