Skip to content

Commit

Permalink
annotate _parser, add spanify comments, fix parser tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joshpoll committed Jan 15, 2019
1 parent 4c0c7fd commit 2603606
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 10 additions & 6 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ def lookup(scopes, name):
return None

def spanify(f):
""" Adds span information to the output of f. """
def _wrapper(*args, **kwargs):
# 0th arg assumed to be self. Gets source name from object.
sn = args[0].source_name
# 1st arg is assumed to be a parser context
ctx = args[1]
ast = f(*args, **kwargs)
# get line and col information from ANTLR parser context
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
ast.set_span(sp)
Expand All @@ -96,14 +100,14 @@ class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""

def __init__(self, source_name):
# type: () -> None
self.source_name = source_name
self.module = module.Module({}) # type: module.Module
# type: (str) -> None
self.source_name = source_name # type: str
self.module = module.Module({}) # type: module.Module

# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = []
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = [] # type: List[expr.Expr]

super(ParseTreeToRelayIR, self).__init__()

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from typing import Union
from functools import wraps
if enabled():
from tvm.relay._parser import ParseError
raises_parse_error = raises(ParseError)
raises_parse_error = raises(tvm._ffi.base.TVMError)
else:
raises_parse_error = lambda x: x

Expand Down

0 comments on commit 2603606

Please sign in to comment.