From 661f03a8a0ae516daf13a8e205270f08fceb5583 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 12 Jun 2022 11:52:34 -0700 Subject: [PATCH] Stable Doc AST (#44) --- python/tvm/script/parse/__init__.py | 2 +- python/tvm/script/parse/dispatch.py | 5 +- python/tvm/script/parse/doc.py | 198 ++++ python/tvm/script/parse/doc_core.py | 1279 ++++++++++++++++++++++++++ python/tvm/script/parse/entry.py | 10 +- python/tvm/script/parse/evaluator.py | 8 +- python/tvm/script/parse/parser.py | 25 +- python/tvm/script/parse/tir/tir.py | 13 +- 8 files changed, 1510 insertions(+), 30 deletions(-) create mode 100644 python/tvm/script/parse/doc.py create mode 100644 python/tvm/script/parse/doc_core.py diff --git a/python/tvm/script/parse/__init__.py b/python/tvm/script/parse/__init__.py index 0b7f8285205c..8844bbebe0d6 100644 --- a/python/tvm/script/parse/__init__.py +++ b/python/tvm/script/parse/__init__.py @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import dispatch, parser, tir +from . import dispatch, doc, parser, tir from .entry import parse diff --git a/python/tvm/script/parse/dispatch.py b/python/tvm/script/parse/dispatch.py index ee38d3878f57..59332ade3710 100644 --- a/python/tvm/script/parse/dispatch.py +++ b/python/tvm/script/parse/dispatch.py @@ -16,15 +16,16 @@ # under the License. """The dispatcher""" -import ast from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple +from .doc import AST + if TYPE_CHECKING: from .parser import Parser ParseMethod = Callable[ - ["Parser", ast.AST], + ["Parser", AST], None, ] diff --git a/python/tvm/script/parse/doc.py b/python/tvm/script/parse/doc.py new file mode 100644 index 000000000000..edf6225489e5 --- /dev/null +++ b/python/tvm/script/parse/doc.py @@ -0,0 +1,198 @@ +import ast +import inspect +import typing +from collections import defaultdict + +from . import doc_core as doc +from .doc_core import * # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614 + +FnToDoc = typing.Callable[[ast.AST], doc.AST] +FnFromDoc = typing.Callable[[doc.AST], ast.AST] + + +class Entry: + to_doc: typing.Optional[FnToDoc] + from_doc: typing.Optional[FnFromDoc] + + def __init__(self): + self.to_doc = None + self.from_doc = None + + +class Registry: + _inst: typing.Optional["Registry"] = None + table: typing.Dict[str, Entry] + + def __init__(self): + self.table = defaultdict(Entry) + + +def register_to_doc(name: str): + def f(to_doc: FnToDoc): # pylint: disable=redefined-outer-name + reg = Registry._inst # pylint: disable=protected-access + reg.table[name].to_doc = to_doc + + return f + + +def register_from_doc(name: str): + def f(to_doc: FnFromDoc): # pylint: disable=redefined-outer-name + reg = Registry._inst # pylint: disable=protected-access + reg.table[name].from_doc = to_doc + + return f + + +def _is_atomic_type(node): + return ( + node is None + or node in [..., True, False] + or isinstance( + node, + ( + int, + float, + str, + bool, + bytes, + complex, + ), + ) + ) + + +def _get_registry_entry(cls_name, attr): + cls_name = cls_name.split(".")[-1] + reg = Registry._inst # pylint: disable=protected-access + if cls_name in reg.table: + entry = reg.table[cls_name] + return getattr(entry, attr, None) + return None + + +def from_doc(node): + if _is_atomic_type(node): + return node + if isinstance(node, tuple): + return tuple(from_doc(n) for n in node) + if isinstance(node, list): + return [from_doc(n) for n in node] + func = _get_registry_entry(node.__class__.__name__, "from_doc") + if not func: + raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}") + return func(node) + + +def to_doc(node): + if _is_atomic_type(node): + return node + if isinstance(node, tuple): + return tuple(to_doc(n) for n in node) + if isinstance(node, list): + return [to_doc(n) for n in node] + func = _get_registry_entry(node.__class__.__name__, "to_doc") + if not func: + raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}") + return func(node) + + +def _register_default(): + class DefaultTranslator: + def __init__(self, doc_cls, func, fields): + self.doc_cls = doc_cls # getattr(doc, name) + self.func = func + self.fields = fields + + def __call__(self, node): + kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields} + return self.doc_cls(**kv) + + Registry._inst = Registry() # pylint: disable=protected-access + for cls_name in dir(doc): + doc_cls = getattr(doc, cls_name) + if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST): + assert "." not in cls_name + register_to_doc(cls_name)( + DefaultTranslator( + getattr(doc, cls_name), + to_doc, + doc_cls._FIELDS, # pylint: disable=protected-access + ) + ) + register_from_doc(cls_name)( + DefaultTranslator( + getattr(ast, cls_name), + from_doc, + doc_cls._FIELDS, # pylint: disable=protected-access + ) + ) + + +def parse( + source, + filename="", + mode="exec", + *, + type_comments=False, + feature_version=None, +) -> doc.AST: + program = ast.parse( + source=source, + filename=filename, + mode=mode, + type_comments=type_comments, + feature_version=feature_version, + ) + return to_doc(program) + + +class NodeVisitor: + def visit(self, node: doc.AST) -> None: + if isinstance(node, (list, tuple)): + for item in node: + self.visit(item) + return + if not isinstance(node, doc.AST): + return + return getattr( + self, + "visit_" + node.__class__.__name__.split(".")[-1], + self.generic_visit, + )(node) + + def generic_visit(self, node: doc.AST) -> None: + for field in node.__class__._FIELDS: # pylint: disable=protected-access + value = getattr(node, field, None) + if value is None: + pass + elif isinstance(value, (doc.AST, list, tuple)): + self.visit(value) + + +class NodeTransformer: + def visit(self, node: doc.AST) -> doc.AST: + if isinstance(node, list): + return [self.visit(item) for item in node] + if isinstance(node, tuple): + return tuple(self.visit(item) for item in node) + if not isinstance(node, doc.AST): + return node + return getattr( + self, + "visit_" + node.__class__.__name__.split(".")[-1], + self.generic_visit, + )(node) + + def generic_visit(self, node: doc.AST) -> doc.AST: + kv: typing.Dict[str, typing.Any] = {} + for field in node.__class__._FIELDS: # pylint: disable=protected-access + value = getattr(node, field, None) + if value is None: + pass + elif isinstance(value, (doc.AST, list, tuple)): + value = self.visit(value) + kv[field] = value + return node.__class__(**kv) + + +_register_default() diff --git a/python/tvm/script/parse/doc_core.py b/python/tvm/script/parse/doc_core.py new file mode 100644 index 000000000000..cd1485649bc1 --- /dev/null +++ b/python/tvm/script/parse/doc_core.py @@ -0,0 +1,1279 @@ +# pylint: disable=redefined-outer-name,missing-docstring,invalid-name +# pylint: disable=useless-super-delegation,redefined-builtin +# pylint: disable=too-few-public-methods,too-many-arguments +class AST: + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__() + self.lineno = lineno + self.col_offset = col_offset + self.end_lineno = end_lineno + self.end_col_offset = end_col_offset + + +class mod(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Module(mod): + _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.body = body + + +class Interactive(mod): + _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.body = body + + +class Expression(mod): + _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.body = body + + +class stmt(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class FunctionDef(stmt): + _FIELDS = [ + "name", + "args", + "body", + "decorator_list", + "returns", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + name, + args, + body, + decorator_list, + returns, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.name = name + self.args = args + self.body = body + self.decorator_list = decorator_list + self.returns = returns + + +class AsyncFunctionDef(stmt): + _FIELDS = [ + "name", + "args", + "body", + "decorator_list", + "returns", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + name, + args, + body, + decorator_list, + returns, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.name = name + self.args = args + self.body = body + self.decorator_list = decorator_list + self.returns = returns + + +class ClassDef(stmt): + _FIELDS = [ + "name", + "bases", + "keywords", + "body", + "decorator_list", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + name, + bases, + keywords, + body, + decorator_list, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.name = name + self.bases = bases + self.keywords = keywords + self.body = body + self.decorator_list = decorator_list + + +class Return(stmt): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class Delete(stmt): + _FIELDS = ["targets", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, targets, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.targets = targets + + +class Assign(stmt): + _FIELDS = [ + "targets", + "value", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, targets, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.targets = targets + self.value = value + + +class AugAssign(stmt): + _FIELDS = [ + "target", + "op", + "value", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, target, op, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.op = op + self.value = value + + +class AnnAssign(stmt): + _FIELDS = [ + "target", + "annotation", + "value", + "simple", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + target, + annotation, + value, + simple, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.annotation = annotation + self.value = value + self.simple = simple + + +class For(stmt): + _FIELDS = [ + "target", + "iter", + "body", + "orelse", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.iter = iter + self.body = body + self.orelse = orelse + + +class AsyncFor(stmt): + _FIELDS = [ + "target", + "iter", + "body", + "orelse", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, target, iter, body, orelse, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.iter = iter + self.body = body + self.orelse = orelse + + +class While(stmt): + _FIELDS = [ + "test", + "body", + "orelse", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.test = test + self.body = body + self.orelse = orelse + + +class If(stmt): + _FIELDS = [ + "test", + "body", + "orelse", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.test = test + self.body = body + self.orelse = orelse + + +class With(stmt): + _FIELDS = ["items", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.items = items + self.body = body + + +class AsyncWith(stmt): + _FIELDS = ["items", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, items, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.items = items + self.body = body + + +class Raise(stmt): + _FIELDS = ["exc", "cause", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, exc, cause, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.exc = exc + self.cause = cause + + +class Try(stmt): + _FIELDS = [ + "body", + "handlers", + "orelse", + "finalbody", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + body, + handlers, + orelse, + finalbody, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.body = body + self.handlers = handlers + self.orelse = orelse + self.finalbody = finalbody + + +class Assert(stmt): + _FIELDS = ["test", "msg", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, test, msg, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.test = test + self.msg = msg + + +class Import(stmt): + _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.names = names + + +class ImportFrom(stmt): + _FIELDS = [ + "module", + "names", + "level", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, module, names, level, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.module = module + self.names = names + self.level = level + + +class Global(stmt): + _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.names = names + + +class Nonlocal(stmt): + _FIELDS = ["names", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, names, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.names = names + + +class Expr(stmt): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class Pass(stmt): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Break(stmt): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Continue(stmt): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class expr(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class BoolOp(expr): + _FIELDS = ["op", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, op, values, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.op = op + self.values = values + + +class BinOp(expr): + _FIELDS = [ + "left", + "op", + "right", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, left, op, right, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.left = left + self.op = op + self.right = right + + +class UnaryOp(expr): + _FIELDS = ["op", "operand", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, op, operand, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.op = op + self.operand = operand + + +class Lambda(expr): + _FIELDS = ["args", "body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, args, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.args = args + self.body = body + + +class IfExp(expr): + _FIELDS = [ + "test", + "body", + "orelse", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, test, body, orelse, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.test = test + self.body = body + self.orelse = orelse + + +class Dict(expr): + _FIELDS = ["keys", "values", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, keys, values, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.keys = keys + self.values = values + + +class Set(expr): + _FIELDS = ["elts", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, elts, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elts = elts + + +class ListComp(expr): + _FIELDS = [ + "elt", + "generators", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elt = elt + self.generators = generators + + +class SetComp(expr): + _FIELDS = [ + "elt", + "generators", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elt = elt + self.generators = generators + + +class DictComp(expr): + _FIELDS = [ + "key", + "value", + "generators", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, key, value, generators, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.key = key + self.value = value + self.generators = generators + + +class GeneratorExp(expr): + _FIELDS = [ + "elt", + "generators", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, elt, generators, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elt = elt + self.generators = generators + + +class Await(expr): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class Yield(expr): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class YieldFrom(expr): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class Compare(expr): + _FIELDS = [ + "left", + "ops", + "comparators", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, left, ops, comparators, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.left = left + self.ops = ops + self.comparators = comparators + + +class Call(expr): + _FIELDS = [ + "func", + "args", + "keywords", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, func, args, keywords, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.func = func + self.args = args + self.keywords = keywords + + +class FormattedValue(expr): + _FIELDS = [ + "value", + "conversion", + "format_spec", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + value, + conversion, + format_spec, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + self.conversion = conversion + self.format_spec = format_spec + + +class JoinedStr(expr): + _FIELDS = ["values", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, values, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.values = values + + +class Constant(expr): + _FIELDS = [ + "value", + "kind", + "s", + "n", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, value, kind, s, n, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + self.kind = kind + self.s = s + self.n = n + + +class NamedExpr(expr): + _FIELDS = [ + "target", + "value", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, target, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.value = value + + +class Attribute(expr): + _FIELDS = [ + "value", + "attr", + "ctx", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, value, attr, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + self.attr = attr + self.ctx = ctx + + +class slice(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Slice(slice): + _FIELDS = [ + "lower", + "upper", + "step", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, lower, upper, step, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.lower = lower + self.upper = upper + self.step = step + + +class ExtSlice(slice): + _FIELDS = ["dims", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, dims, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.dims = dims + + +class Index(slice): + _FIELDS = ["value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + + +class Subscript(expr): + _FIELDS = [ + "value", + "slice", + "ctx", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, value, slice, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + self.slice = slice + self.ctx = ctx + + +class Starred(expr): + _FIELDS = ["value", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, value, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.value = value + self.ctx = ctx + + +class Name(expr): + _FIELDS = ["id", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, id, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.id = id + self.ctx = ctx + + +class List(expr): + _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elts = elts + self.ctx = ctx + + +class Tuple(expr): + _FIELDS = ["elts", "ctx", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, elts, ctx, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.elts = elts + self.ctx = ctx + + +class expr_context(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class AugLoad(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class AugStore(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Param(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Suite(mod): + _FIELDS = ["body", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.body = body + + +class Del(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Load(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Store(expr_context): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class boolop(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class And(boolop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Or(boolop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class operator(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Add(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class BitAnd(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class BitOr(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class BitXor(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Div(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class FloorDiv(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class LShift(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Mod(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Mult(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class MatMult(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Pow(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class RShift(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Sub(operator): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class unaryop(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Invert(unaryop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Not(unaryop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class UAdd(unaryop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class USub(unaryop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class cmpop(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Eq(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Gt(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class GtE(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class In(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Is(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class IsNot(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class Lt(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class LtE(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class NotEq(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class NotIn(cmpop): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class comprehension(AST): + _FIELDS = [ + "target", + "iter", + "ifs", + "is_async", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + target, + iter, + ifs, + is_async, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.target = target + self.iter = iter + self.ifs = ifs + self.is_async = is_async + + +class excepthandler(AST): + _FIELDS = ["lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + + +class ExceptHandler(excepthandler): + _FIELDS = [ + "type", + "name", + "body", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, type, name, body, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.type = type + self.name = name + self.body = body + + +class arguments(AST): + _FIELDS = [ + "args", + "vararg", + "kwonlyargs", + "kw_defaults", + "kwarg", + "defaults", + "posonlyargs", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + args, + vararg, + kwonlyargs, + kw_defaults, + kwarg, + defaults, + posonlyargs, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.args = args + self.vararg = vararg + self.kwonlyargs = kwonlyargs + self.kw_defaults = kw_defaults + self.kwarg = kwarg + self.defaults = defaults + self.posonlyargs = posonlyargs + + +class arg(AST): + _FIELDS = [ + "arg", + "annotation", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__(self, arg, annotation, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.arg = arg + self.annotation = annotation + + +class keyword(AST): + _FIELDS = ["arg", "value", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, arg, value, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.arg = arg + self.value = value + + +class alias(AST): + _FIELDS = ["name", "asname", "lineno", "col_offset", "end_lineno", "end_col_offset"] + + def __init__(self, name, asname, lineno, col_offset, end_lineno, end_col_offset): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.name = name + self.asname = asname + + +class withitem(AST): + _FIELDS = [ + "context_expr", + "optional_vars", + "lineno", + "col_offset", + "end_lineno", + "end_col_offset", + ] + + def __init__( + self, + context_expr, + optional_vars, + lineno, + col_offset, + end_lineno, + end_col_offset, + ): + super().__init__(lineno, col_offset, end_lineno, end_col_offset) + self.context_expr = context_expr + self.optional_vars = optional_vars diff --git a/python/tvm/script/parse/entry.py b/python/tvm/script/parse/entry.py index 6239487327a0..b4f756b0582c 100644 --- a/python/tvm/script/parse/entry.py +++ b/python/tvm/script/parse/entry.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. """The entry point of TVM parser.""" -import ast import inspect from typing import Any, Dict, Optional, Union from ..builder import Builder +from . import doc from .parser import Parser @@ -30,7 +30,7 @@ class SourceCode: source: str full_source: str - def __init__(self, program: Union[str, ast.AST]): + def __init__(self, program: Union[str, doc.AST]): if isinstance(program, str): self.source_name = "" self.start_line = 1 @@ -65,12 +65,12 @@ def __init__(self, program: Union[str, ast.AST]): src, _ = inspect.findsource(program) # type: ignore self.full_source = "".join(src) - def as_ast(self) -> ast.AST: - return ast.parse(self.source) + def as_ast(self) -> doc.AST: + return doc.parse(self.source) def parse( - program: Union[ast.AST, Any, str], + program: Union[doc.AST, Any, str], extra_vars: Optional[Dict[str, Any]] = None, ): program_ast = SourceCode(program).as_ast() diff --git a/python/tvm/script/parse/evaluator.py b/python/tvm/script/parse/evaluator.py index e4f62b7a81c6..1f02621f282f 100644 --- a/python/tvm/script/parse/evaluator.py +++ b/python/tvm/script/parse/evaluator.py @@ -18,11 +18,14 @@ import ast from typing import Any, Dict, Optional, Union +from . import doc + def eval_expr( - node: Union[ast.expr, ast.Expression], + node: Union[doc.expr, doc.Expression], dict_globals: Optional[Dict[str, Any]], ) -> Any: + node = doc.from_doc(node) if isinstance(node, ast.expr): node = ast.Expression(body=node) assert isinstance(node, ast.Expression) @@ -34,9 +37,10 @@ def eval_expr( def eval_assign( - target: ast.expr, + target: doc.expr, source: Any, ) -> Dict[str, Any]: + target = doc.from_doc(target) assert isinstance(target, ast.expr) RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name rhs_var_name = RHS_VAR_NAME diff --git a/python/tvm/script/parse/parser.py b/python/tvm/script/parse/parser.py index 6101a252acef..e9939672d4bf 100644 --- a/python/tvm/script/parse/parser.py +++ b/python/tvm/script/parse/parser.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. """The core parser""" -import ast from typing import Any, Dict, List, Optional, Union from ..builder import def_ -from . import dispatch +from . import dispatch, doc from .evaluator import eval_assign, eval_expr from .utils import deferred from .var_table import VarTable @@ -33,7 +32,7 @@ def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: return lambda self, node: self.generic_visit(node) -def _handle_function(self: "Parser", node: ast.FunctionDef) -> None: +def _handle_function(self: "Parser", node: doc.FunctionDef) -> None: if not node.decorator_list: self.report_error(node, "Function must be decorated") # TODO: only the last decorator is parsed @@ -47,7 +46,7 @@ def _handle_function(self: "Parser", node: ast.FunctionDef) -> None: self.report_error(node, "The parser does not understand the decorator") -class Parser(ast.NodeVisitor): +class Parser(doc.NodeVisitor): """The TVMScript parser""" dispatch_tokens: List[str] @@ -66,7 +65,7 @@ def pop_token(): def eval_expr( self, - node: Union[ast.Expression, ast.expr], + node: Union[doc.Expression, doc.expr], extra_vars: Optional[Dict[str, Any]] = None, ) -> Any: var_values = self.var_table.get() @@ -77,7 +76,7 @@ def eval_expr( def eval_assign( self, - target: ast.expr, + target: doc.expr, source: Any, ) -> Dict[str, Any]: var_values = eval_assign(target, source) @@ -86,24 +85,24 @@ def eval_assign( self.var_table.add(k, v) return var_values - def report_error(self, node: ast.AST, msg: str) -> None: # pylint: disable=no-self-use + def report_error(self, node: doc.AST, msg: str) -> None: # pylint: disable=no-self-use raise SyntaxError(f"At {node.lineno}:{node.col_offset}: {msg}") - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name + def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name _handle_function(self, node) - def visit_body(self, node: List[ast.stmt]) -> Any: + def visit_body(self, node: List[doc.stmt]) -> Any: for stmt in node: self.visit(stmt) - def visit_arguments(self, node: ast.arguments) -> Any: + def visit_arguments(self, node: doc.arguments) -> Any: _dispatch(self, "arguments")(self, node) - def visit_For(self, node: ast.For) -> Any: # pylint: disable=invalid-name + def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name _dispatch(self, "For")(self, node) - def visit_With(self, node: ast.With) -> Any: # pylint: disable=invalid-name + def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name _dispatch(self, "With")(self, node) - def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name + def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name _dispatch(self, "Assign")(self, node) diff --git a/python/tvm/script/parse/tir/tir.py b/python/tvm/script/parse/tir/tir.py index c219e24076c0..202f51614a9b 100644 --- a/python/tvm/script/parse/tir/tir.py +++ b/python/tvm/script/parse/tir/tir.py @@ -15,17 +15,16 @@ # specific language governing permissions and limitations # under the License. -import ast import contextlib from ...builder import Frame from ...builder import tir as T -from .. import dispatch +from .. import dispatch, doc from ..parser import Parser @dispatch.register(token="tir", type_name="For") -def visit_for(self: Parser, node: ast.For) -> None: +def visit_for(self: Parser, node: doc.For) -> None: for_frame = self.eval_expr(node.iter) if not isinstance(for_frame, T.ForFrame): self.report_error( @@ -40,7 +39,7 @@ def visit_for(self: Parser, node: ast.For) -> None: @dispatch.register(token="tir", type_name="Assign") -def visit_assign(self: Parser, node: ast.Assign) -> None: +def visit_assign(self: Parser, node: doc.Assign) -> None: if len(node.targets) != 1: self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") lhs = node.targets[0] @@ -49,7 +48,7 @@ def visit_assign(self: Parser, node: ast.Assign) -> None: @dispatch.register(token="tir", type_name="With") -def visit_with(self: Parser, node: ast.With) -> None: +def visit_with(self: Parser, node: doc.With) -> None: with contextlib.ExitStack() as stack: stack.enter_context(self.var_table.with_frame()) for item in node.items: @@ -68,7 +67,7 @@ def visit_with(self: Parser, node: ast.With) -> None: @dispatch.register(token="tir", type_name="FunctionDef") -def visit_function_def(self: Parser, node: ast.FunctionDef) -> None: +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: with self.var_table.with_frame(): self.var_table.add("range", T.serial) with T.prim_func(node.name): @@ -79,7 +78,7 @@ def visit_function_def(self: Parser, node: ast.FunctionDef) -> None: @dispatch.register(token="tir", type_name="arguments") -def visit_arguments(self: Parser, node: ast.arguments) -> None: +def visit_arguments(self: Parser, node: doc.arguments) -> None: # TODO: handle different types of arguments: # - vararg: arg | None # - kwonlyargs: list[arg]