Skip to content

Commit

Permalink
Enable error rendering (apache#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 13, 2022
1 parent 6a90537 commit 9f42c5d
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 67 deletions.
52 changes: 52 additions & 0 deletions python/tvm/script/parse/diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from tvm.ir import IRModule, SourceName, Span, diagnostics

from . import doc
from .source import Source


class Diagnostics:

source: Source
ctx: diagnostics.DiagnosticContext

def __init__(self, source: Source):
mod = IRModule()
mod.source_map.add(source.source_name, source.source)
self.source = source
self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())

def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
self.ctx.emit(
diagnostics.Diagnostic(
level=level,
span=Span(
source_name=SourceName(self.source.source_name),
line=node.lineno,
end_line=node.end_lineno,
column=node.col_offset,
end_column=node.end_col_offset,
),
message=message,
)
)

def error(self, node: doc.AST, message: str) -> None:
self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
self.ctx.render()
23 changes: 13 additions & 10 deletions python/tvm/script/parse/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,20 @@ def parse(
source,
filename="<unknown>",
mode="exec",
*,
type_comments=False,
feature_version=None,
) -> doc.AST:
program = ast.parse(
source=source,
filename=filename,
mode=mode,
type_comments=type_comments,
feature_version=feature_version,
)
try:
program = ast.parse(
source=source,
filename=filename,
mode=mode,
feature_version=(3, 8),
)
except:
program = ast.parse(
source=source,
filename=filename,
mode=mode,
)
return to_doc(program)


Expand Down
63 changes: 8 additions & 55 deletions python/tvm/script/parse/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,69 +15,22 @@
# specific language governing permissions and limitations
# under the License.
"""The entry point of TVM parser."""
import inspect
from typing import Any, Union

from ..builder import Builder
from . import doc
from .parser import Parser


class SourceCode:
source_name: str
start_line: int
start_column: int
source: str
full_source: str

def __init__(self, program: Union[str, doc.AST]):
if isinstance(program, str):
self.source_name = "<str>"
self.start_line = 1
self.start_column = 0
self.source = program
self.full_source = program
else:
self.source_name = inspect.getsourcefile(program) # type: ignore
lines, self.start_line = inspect.getsourcelines(program) # type: ignore
if lines:
self.start_column = len(lines[0]) - len(lines[0].lstrip())
else:
self.start_column = 0
if self.start_column and lines:
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
else:
self.source = "".join(lines)
try:
# It will cause a problem when running in Jupyter Notebook.
# `mod` will be <module '__main__'>, which is a built-in module
# and `getsource` will throw a TypeError
mod = inspect.getmodule(program)
if mod:
self.full_source = inspect.getsource(mod)
else:
self.full_source = self.source
except TypeError:
# It's a work around for Jupyter problem.
# Since `findsource` is an internal API of inspect, we just use it
# as a fallback method.
src, _ = inspect.findsource(program) # type: ignore
self.full_source = "".join(src)

def as_ast(self) -> doc.AST:
return doc.parse(self.source)
from .source import Source


def parse(program: Union[doc.AST, Any, str]):
# TODO: `extra_vars` is a hack
from tvm.script.builder import tir as T
from tvm.script.builder import tir as T # pylint: disable=import-outside-toplevel

extra_vars = {"T": T}
program_ast = SourceCode(program).as_ast()
parser = Parser()
extra_vars = { # TODO: `extra_vars` is a hack
"T": T,
}
source = Source(program)
parser = Parser(source)
with Builder() as builder:
with parser.var_table.with_frame():
for k, v in extra_vars.items():
parser.var_table.add(k, v)
parser.visit(program_ast)
parser.parse(vars=extra_vars)
return builder.get()
17 changes: 15 additions & 2 deletions python/tvm/script/parse/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from ..builder import def_
from . import dispatch, doc
from .diagnostics import Diagnostics
from .evaluator import eval_assign, eval_expr
from .source import Source
from .utils import deferred
from .var_table import VarTable

Expand All @@ -42,13 +44,24 @@ def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
class Parser(doc.NodeVisitor):
"""The TVMScript parser"""

diag: Diagnostics
dispatch_tokens: List[str]
var_table: VarTable

def __init__(self) -> None:
def __init__(self, source: Source) -> None:
self.diag = Diagnostics(source)
self.dispatch_tokens = ["default"]
self.var_table = VarTable()

def parse(self, vars: Optional[Dict[str, Any]] = None) -> Any:
if vars is None:
vars = {}
with self.var_table.with_frame():
for k, v in vars.items():
self.var_table.add(k, v)
node = self.diag.source.as_ast()
self.visit(node)

def with_dispatch_token(self, token: str):
def pop_token():
self.dispatch_tokens.pop()
Expand Down Expand Up @@ -79,7 +92,7 @@ def eval_assign(
return var_values

def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use
raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}")
self.diag.error(node, msg)

def visit(self, node: doc.AST) -> None:
if isinstance(node, (list, tuple)):
Expand Down
65 changes: 65 additions & 0 deletions python/tvm/script/parse/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
from typing import Union

from . import doc


class Source:
source_name: str
start_line: int
start_column: int
source: str
full_source: str

def __init__(self, program: Union[str, doc.AST]):
if isinstance(program, str):
self.source_name = "<str>"
self.start_line = 1
self.start_column = 0
self.source = program
self.full_source = program
else:
self.source_name = inspect.getsourcefile(program) # type: ignore
lines, self.start_line = inspect.getsourcelines(program) # type: ignore
if lines:
self.start_column = len(lines[0]) - len(lines[0].lstrip())
else:
self.start_column = 0
if self.start_column and lines:
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
else:
self.source = "".join(lines)
try:
# It will cause a problem when running in Jupyter Notebook.
# `mod` will be <module '__main__'>, which is a built-in module
# and `getsource` will throw a TypeError
mod = inspect.getmodule(program)
if mod:
self.full_source = inspect.getsource(mod)
else:
self.full_source = self.source
except TypeError:
# It's a work around for Jupyter problem.
# Since `findsource` is an internal API of inspect, we just use it
# as a fallback method.
src, _ = inspect.findsource(program) # type: ignore
self.full_source = "".join(src)

def as_ast(self) -> doc.AST:
return doc.parse(self.source)
13 changes: 13 additions & 0 deletions tests/python/tvmscript/test_parse_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect

import pytest
import tvm
from tvm.script.builder import ir as I
from tvm.script.builder import tir as T

Expand Down Expand Up @@ -67,8 +69,19 @@ def f(A: T.int32, B: T.int64, C: T.handle) -> None:
assert f.params[2].dtype == "handle"


def test_parse_report_error():
with pytest.raises(tvm.error.DiagnosticError):

@T.prim_func
def elementwise() -> None:
for (*vvv,) in T.grid(128, 128, 128, 128, 128, 128, 128):
with T.block("inner_block"):
vj = T.axis.S(128, vvv[10] + 20)


if __name__ == "__main__":
test_parse_elementwise()
test_parse_skip()
test_parse_class()
test_parse_atomic()
test_parse_report_error()

0 comments on commit 9f42c5d

Please sign in to comment.