diff --git a/python/tvm/script/parse/diagnostics.py b/python/tvm/script/parse/diagnostics.py new file mode 100644 index 000000000000..0bcd2a86cf94 --- /dev/null +++ b/python/tvm/script/parse/diagnostics.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm.ir import IRModule, SourceName, Span, diagnostics + +from . import doc +from .source import Source + + +class Diagnostics: + + source: Source + ctx: diagnostics.DiagnosticContext + + def __init__(self, source: Source): + mod = IRModule() + mod.source_map.add(source.source_name, source.source) + self.source = source + self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer()) + + def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None: + self.ctx.emit( + diagnostics.Diagnostic( + level=level, + span=Span( + source_name=SourceName(self.source.source_name), + line=node.lineno, + end_line=node.end_lineno, + column=node.col_offset, + end_column=node.end_col_offset, + ), + message=message, + ) + ) + + def error(self, node: doc.AST, message: str) -> None: + self._emit(node, message, diagnostics.DiagnosticLevel.ERROR) + self.ctx.render() diff --git a/python/tvm/script/parse/doc.py b/python/tvm/script/parse/doc.py index edf6225489e5..c3576b2aa675 100644 --- a/python/tvm/script/parse/doc.py +++ b/python/tvm/script/parse/doc.py @@ -132,17 +132,20 @@ def parse( source, filename="", mode="exec", - *, - type_comments=False, - feature_version=None, ) -> doc.AST: - program = ast.parse( - source=source, - filename=filename, - mode=mode, - type_comments=type_comments, - feature_version=feature_version, - ) + try: + program = ast.parse( + source=source, + filename=filename, + mode=mode, + feature_version=(3, 8), + ) + except: + program = ast.parse( + source=source, + filename=filename, + mode=mode, + ) return to_doc(program) diff --git a/python/tvm/script/parse/entry.py b/python/tvm/script/parse/entry.py index 1fefd02b07ea..1aed9b3270e6 100644 --- a/python/tvm/script/parse/entry.py +++ b/python/tvm/script/parse/entry.py @@ -15,69 +15,22 @@ # specific language governing permissions and limitations # under the License. """The entry point of TVM parser.""" -import inspect from typing import Any, Union from ..builder import Builder from . import doc from .parser import Parser - - -class SourceCode: - source_name: str - start_line: int - start_column: int - source: str - full_source: str - - def __init__(self, program: Union[str, doc.AST]): - if isinstance(program, str): - self.source_name = "" - self.start_line = 1 - self.start_column = 0 - self.source = program - self.full_source = program - else: - self.source_name = inspect.getsourcefile(program) # type: ignore - lines, self.start_line = inspect.getsourcelines(program) # type: ignore - if lines: - self.start_column = len(lines[0]) - len(lines[0].lstrip()) - else: - self.start_column = 0 - if self.start_column and lines: - self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) - else: - self.source = "".join(lines) - try: - # It will cause a problem when running in Jupyter Notebook. - # `mod` will be , which is a built-in module - # and `getsource` will throw a TypeError - mod = inspect.getmodule(program) - if mod: - self.full_source = inspect.getsource(mod) - else: - self.full_source = self.source - except TypeError: - # It's a work around for Jupyter problem. - # Since `findsource` is an internal API of inspect, we just use it - # as a fallback method. - src, _ = inspect.findsource(program) # type: ignore - self.full_source = "".join(src) - - def as_ast(self) -> doc.AST: - return doc.parse(self.source) +from .source import Source def parse(program: Union[doc.AST, Any, str]): - # TODO: `extra_vars` is a hack - from tvm.script.builder import tir as T + from tvm.script.builder import tir as T # pylint: disable=import-outside-toplevel - extra_vars = {"T": T} - program_ast = SourceCode(program).as_ast() - parser = Parser() + extra_vars = { # TODO: `extra_vars` is a hack + "T": T, + } + source = Source(program) + parser = Parser(source) with Builder() as builder: - with parser.var_table.with_frame(): - for k, v in extra_vars.items(): - parser.var_table.add(k, v) - parser.visit(program_ast) + parser.parse(vars=extra_vars) return builder.get() diff --git a/python/tvm/script/parse/parser.py b/python/tvm/script/parse/parser.py index f5384ceb4832..bf6743e08627 100644 --- a/python/tvm/script/parse/parser.py +++ b/python/tvm/script/parse/parser.py @@ -19,7 +19,9 @@ from ..builder import def_ from . import dispatch, doc +from .diagnostics import Diagnostics from .evaluator import eval_assign, eval_expr +from .source import Source from .utils import deferred from .var_table import VarTable @@ -42,13 +44,24 @@ def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: class Parser(doc.NodeVisitor): """The TVMScript parser""" + diag: Diagnostics dispatch_tokens: List[str] var_table: VarTable - def __init__(self) -> None: + def __init__(self, source: Source) -> None: + self.diag = Diagnostics(source) self.dispatch_tokens = ["default"] self.var_table = VarTable() + def parse(self, vars: Optional[Dict[str, Any]] = None) -> Any: + if vars is None: + vars = {} + with self.var_table.with_frame(): + for k, v in vars.items(): + self.var_table.add(k, v) + node = self.diag.source.as_ast() + self.visit(node) + def with_dispatch_token(self, token: str): def pop_token(): self.dispatch_tokens.pop() @@ -79,7 +92,7 @@ def eval_assign( return var_values def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use - raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}") + self.diag.error(node, msg) def visit(self, node: doc.AST) -> None: if isinstance(node, (list, tuple)): diff --git a/python/tvm/script/parse/source.py b/python/tvm/script/parse/source.py new file mode 100644 index 000000000000..8dd71fd41e9b --- /dev/null +++ b/python/tvm/script/parse/source.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import inspect +from typing import Union + +from . import doc + + +class Source: + source_name: str + start_line: int + start_column: int + source: str + full_source: str + + def __init__(self, program: Union[str, doc.AST]): + if isinstance(program, str): + self.source_name = "" + self.start_line = 1 + self.start_column = 0 + self.source = program + self.full_source = program + else: + self.source_name = inspect.getsourcefile(program) # type: ignore + lines, self.start_line = inspect.getsourcelines(program) # type: ignore + if lines: + self.start_column = len(lines[0]) - len(lines[0].lstrip()) + else: + self.start_column = 0 + if self.start_column and lines: + self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) + else: + self.source = "".join(lines) + try: + # It will cause a problem when running in Jupyter Notebook. + # `mod` will be , which is a built-in module + # and `getsource` will throw a TypeError + mod = inspect.getmodule(program) + if mod: + self.full_source = inspect.getsource(mod) + else: + self.full_source = self.source + except TypeError: + # It's a work around for Jupyter problem. + # Since `findsource` is an internal API of inspect, we just use it + # as a fallback method. + src, _ = inspect.findsource(program) # type: ignore + self.full_source = "".join(src) + + def as_ast(self) -> doc.AST: + return doc.parse(self.source) diff --git a/tests/python/tvmscript/test_parse_basic.py b/tests/python/tvmscript/test_parse_basic.py index 280fb6d72db0..8fa8e2ce824b 100644 --- a/tests/python/tvmscript/test_parse_basic.py +++ b/tests/python/tvmscript/test_parse_basic.py @@ -1,5 +1,7 @@ import inspect +import pytest +import tvm from tvm.script.builder import ir as I from tvm.script.builder import tir as T @@ -67,8 +69,19 @@ def f(A: T.int32, B: T.int64, C: T.handle) -> None: assert f.params[2].dtype == "handle" +def test_parse_report_error(): + with pytest.raises(tvm.error.DiagnosticError): + + @T.prim_func + def elementwise() -> None: + for (*vvv,) in T.grid(128, 128, 128, 128, 128, 128, 128): + with T.block("inner_block"): + vj = T.axis.S(128, vvv[10] + 20) + + if __name__ == "__main__": test_parse_elementwise() test_parse_skip() test_parse_class() test_parse_atomic() + test_parse_report_error()