-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE #2377
Changes from all commits
bd55406
73015bc
c07eba1
c7e09b9
8d4445f
a92a668
3cdfd8a
cfde2f7
a6e0b96
ea61b8b
5dbc2af
9eb5861
a1c30a7
5b1b016
2eb17ba
616aa65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
if(USE_ANTLR) | ||
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) | ||
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") | ||
file(GLOB_RECURSE ANTLR4 | ||
/usr/local/lib/antlr-*-complete.jar | ||
/usr/local/Cellar/*antlr-*-complete.jar) | ||
|
||
if(DEFINED ANTLR4) | ||
# Get the first element of the list of antlr jars. | ||
# Sort and reverse the list so the item selected is the highest | ||
# version in lib or else in Cellar if no lib installation exists. | ||
list(SORT ANTLR4) | ||
list(REVERSE ANTLR4) | ||
list(GET ANTLR4 0 ANTLR4) | ||
set(RELAY_PARSER_DIR | ||
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) | ||
|
||
|
@@ -14,15 +22,21 @@ if(USE_ANTLR) | |
${RELAY_PARSER_DIR}/py3/RelayParser.py | ||
${RELAY_PARSER_DIR}/py3/RelayLexer.py) | ||
|
||
set(JAVA_HOME $ENV{JAVA_HOME}) | ||
if (NOT DEFINED JAVA_HOME) | ||
# Hack to get system to search for Java itself. | ||
set(JAVA_HOME "/usr") | ||
endif() | ||
|
||
# Generate ANTLR grammar for parsing. | ||
add_custom_command(OUTPUT ${RELAY_PARSER} | ||
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 | ||
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 | ||
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 | ||
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 | ||
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 | ||
WORKING_DIRECTORY ${RELAY_PARSER_DIR}) | ||
|
||
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) | ||
else() | ||
message(FATAL_ERROR "Can't find ANTLR4!") | ||
message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4}) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What could ANTLR4 be set to if it's not defined? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When we were using `EXISTS`, the variable would be set but then still not work, so it is good to just print it out. |
||
endif() | ||
endif(USE_ANTLR) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable | ||
"""The interface of expr function exposed from C++.""" | ||
from tvm._ffi.function import _init_api | ||
|
||
_init_api("relay._base", __name__) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,17 @@ | |
import sys | ||
|
||
from collections import deque | ||
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any | ||
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict | ||
|
||
import tvm | ||
|
||
from . import module | ||
from .base import Span, SourceName | ||
from . import expr | ||
from . import ty | ||
from . import op | ||
|
||
|
||
class ParseError(Exception): | ||
"""Exception type for parse errors.""" | ||
|
||
|
@@ -76,22 +80,46 @@ def lookup(scopes, name): | |
return val | ||
return None | ||
|
||
def spanify(f): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a docstring? |
||
"""A decorator which attaches span information | ||
to the value returned by calling `f`. | ||
|
||
Intended for use with the below AST visiting | ||
methods. The idea is that after we do the work | ||
of constructing the AST we attach Span information. | ||
""" | ||
|
||
def _wrapper(*args, **kwargs): | ||
# Assumes 0th arg is self and gets source_name from object. | ||
sn = args[0].source_name | ||
# Assumes 1st arg is an ANTLR parser context. | ||
ctx = args[1] | ||
ast = f(*args, **kwargs) | ||
line, col = ctx.getSourceInterval() | ||
sp = Span(sn, line, col) | ||
ast.set_span(sp) | ||
return ast | ||
return _wrapper | ||
|
||
# TODO(@jmp): Use https://stackoverflow.com/q/13889941 | ||
# to figure out how to get ANTLR4 to be more unhappy about syntax errors | ||
class ParseTreeToRelayIR(RelayVisitor): | ||
"""Parse Relay text format into Relay IR.""" | ||
|
||
def __init__(self): | ||
# type: () -> None | ||
def __init__(self, source_name): | ||
# type: (str) -> None | ||
self.source_name = source_name | ||
self.module = module.Module({}) # type: module.Module | ||
|
||
# Adding an empty scope allows naked lets without pain. | ||
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] | ||
self.global_var_scope = deque() # type: Scope[expr.GlobalVar] | ||
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] | ||
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var] | ||
self.global_var_scope = deque() # type: Scope[expr.GlobalVar] | ||
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar] | ||
self.graph_expr = [] # type: List[expr.Expr] | ||
|
||
super(ParseTreeToRelayIR, self).__init__() | ||
|
||
|
||
def enter_var_scope(self): | ||
# type: () -> None | ||
"""Enter a new Var scope so it can be popped off later.""" | ||
|
@@ -146,20 +174,25 @@ def visitTerminal(self, node): | |
|
||
node_type = node.getSymbol().type | ||
node_text = node.getText() | ||
name = node_text[1:] | ||
|
||
# variables | ||
if node_type == RelayLexer.GLOBAL_VAR: | ||
return lookup([self.global_var_scope], node_text[1:]) | ||
return lookup(deque([self.global_var_scope]), node_text[1:]) | ||
elif node_type == RelayLexer.LOCAL_VAR: | ||
name = node_text[1:] | ||
# Remove the leading '%' and lookup the name. | ||
var = lookup(self.var_scopes, name) | ||
if var is None: | ||
raise ParseError("Couldn't resolve `{}`.".format(name)) | ||
|
||
return var | ||
elif node_type == RelayLexer.GRAPH_VAR: | ||
try: | ||
return self.graph_expr[int(name)] | ||
except IndexError: | ||
raise ParseError("Couldn't resolve `{}`".format(name)) | ||
|
||
# data types | ||
elif node_type == RelayLexer.INT: | ||
elif node_type == RelayLexer.NAT: | ||
return int(node_text) | ||
elif node_type == RelayLexer.FLOAT: | ||
return float(node_text) | ||
|
@@ -190,7 +223,7 @@ def getType_(self, ctx): | |
return self.visit(ctx) | ||
|
||
def visitProg(self, ctx): | ||
# type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment] | ||
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module] | ||
if ctx.defn(): | ||
self.visit_list(ctx.defn()) | ||
return self.module | ||
|
@@ -219,7 +252,7 @@ def visitScalarFloat(self, ctx): | |
|
||
def visitScalarInt(self, ctx): | ||
# type: (RelayParser.ScalarIntContext) -> expr.Constant | ||
return expr.const(self.visit(ctx.INT())) | ||
return expr.const(self.visit(ctx.NAT())) | ||
|
||
def visitScalarBool(self, ctx): | ||
# type: (RelayParser.ScalarBoolContext) -> expr.Constant | ||
|
@@ -240,7 +273,7 @@ def visitTuple(self, ctx): | |
return expr.Tuple(tup) | ||
|
||
# Currently doesn't support mutable sequencing. | ||
def visitSeq(self, ctx): | ||
def visitLet(self, ctx): | ||
# type: (RelayParser.SeqContext) -> expr.Let | ||
"""Desugar various sequence constructs to Relay Let nodes.""" | ||
if ctx.MUT() is not None: | ||
|
@@ -253,7 +286,7 @@ def visitSeq(self, ctx): | |
else: | ||
local_var = ctx.var().ident().LOCAL_VAR() | ||
if local_var is None: | ||
raise ParseError('Only local ids may be used in `let`s.') | ||
raise ParseError("Only local ids may be used in `let`s.") | ||
ident = local_var.getText()[1:] | ||
type_ = self.getType_(ctx.var().type_()) | ||
|
||
|
@@ -278,12 +311,14 @@ def visitBinOp(self, ctx): | |
|
||
return relay_op(arg0, arg1) | ||
|
||
@spanify | ||
def visitVar(self, ctx): | ||
# type: (RelayParser.VarContext) -> expr.Var | ||
"""Visit a single variable.""" | ||
ident = ctx.ident().LOCAL_VAR() | ||
|
||
if ident is None: | ||
raise ParseError('Only local ids may be used in params.') | ||
raise ParseError("Only local ids may be used in vars.") | ||
|
||
type_ = self.getType_(ctx.type_()) | ||
|
||
|
@@ -293,15 +328,33 @@ def visitVarList(self, ctx): | |
# type: (RelayParser.VarListContext) -> List[expr.Var] | ||
return self.visit_list(ctx.var()) | ||
|
||
# TODO: support a larger class of values than just Relay exprs | ||
def visitAttr(self, ctx): | ||
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr] | ||
return (ctx.CNAME().getText(), self.visit(ctx.expr())) | ||
|
||
def visitAttrList(self, ctx): | ||
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr] | ||
return dict(self.visit_list(ctx.attr())) | ||
|
||
def visitArgList(self, | ||
ctx # type: RelayParser.ArgListContext | ||
): | ||
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]] | ||
var_list = self.visit(ctx.varList()) if ctx.varList() else None | ||
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None | ||
|
||
return (var_list, attr_list) | ||
|
||
def mk_func(self, ctx): | ||
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function | ||
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function | ||
"""Construct a function from either a Func or Defn.""" | ||
|
||
# Enter var scope early to put params in scope. | ||
self.enter_var_scope() | ||
# Capture type params in params. | ||
self.enter_type_param_scope() | ||
var_list = self.visit(ctx.varList()) | ||
var_list, attr_list = self.visit(ctx.argList()) | ||
ret_type = self.getType_(ctx.type_()) | ||
|
||
type_params = list(self.exit_type_param_scope()) | ||
|
@@ -311,22 +364,28 @@ def mk_func(self, ctx): | |
body = self.visit(ctx.body()) | ||
self.exit_var_scope() | ||
|
||
return expr.Function(var_list, body, ret_type, type_params) # type: ignore | ||
attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None | ||
|
||
return expr.Function(var_list, body, ret_type, type_params, attrs) | ||
|
||
@spanify | ||
def visitFunc(self, ctx): | ||
# type: (RelayParser.FuncContext) -> expr.Function | ||
return self.mk_func(ctx) | ||
|
||
# TODO: how to set spans for definitions? | ||
# @spanify | ||
def visitDefn(self, ctx): | ||
# type: (RelayParser.DefnContext) -> None | ||
ident = ctx.ident().GLOBAL_VAR() | ||
if ident is None: | ||
raise ParseError('Only global ids may be used in `def`s.') | ||
raise ParseError("Only global ids may be used in `def`s.") | ||
ident_name = ident.getText()[1:] | ||
ident = self.mk_global_var(ident_name) | ||
|
||
self.module[ident] = self.mk_func(ctx) | ||
|
||
@spanify | ||
def visitCall(self, ctx): | ||
# type: (RelayParser.CallContext) -> expr.Call | ||
visited_exprs = self.visit_list(ctx.expr()) | ||
|
@@ -336,6 +395,7 @@ def visitCall(self, ctx): | |
|
||
return expr.Call(func, args, None, None) | ||
|
||
@spanify | ||
def visitIfElse(self, ctx): | ||
# type: (RelayParser.IfElseContext) -> expr.If | ||
"""Construct a Relay If node. Creates a new scope for each branch.""" | ||
|
@@ -351,6 +411,27 @@ def visitIfElse(self, ctx): | |
|
||
return expr.If(cond, true_branch, false_branch) | ||
|
||
@spanify | ||
def visitGraph(self, ctx): | ||
# type: (RelayParser.GraphContext) -> expr.Expr | ||
"""Visit a graph variable assignment.""" | ||
if ctx.ident().GRAPH_VAR() is None: | ||
raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText())) | ||
graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:]) | ||
|
||
self.enter_var_scope() | ||
value = self.visit(ctx.expr(0)) | ||
self.exit_var_scope() | ||
|
||
if graph_nid != len(self.graph_expr): | ||
raise ParseError( | ||
"Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \ | ||
"but got `%{}`".format(graph_nid)) | ||
self.graph_expr.append(value) | ||
|
||
kont = self.visit(ctx.expr(1)) | ||
return kont | ||
|
||
# Types | ||
|
||
# pylint: disable=unused-argument | ||
|
@@ -428,8 +509,18 @@ def make_parser(data): | |
token_stream = CommonTokenStream(lexer) | ||
return RelayParser(token_stream) | ||
|
||
def fromtext(data): | ||
# type: (str) -> Union[expr.Expr, env.Environment] | ||
__source_name_counter__ = 0 | ||
|
||
def fromtext(data, source_name=None): | ||
# type: (str, str) -> Union[expr.Expr, module.Module] | ||
"""Parse a Relay program.""" | ||
global __source_name_counter__ | ||
|
||
if source_name is None: | ||
source_name = "source_file{0}".format(__source_name_counter__) | ||
|
||
if isinstance(source_name, str): | ||
source_name = SourceName(source_name) | ||
|
||
tree = make_parser(data).prog() | ||
return ParseTreeToRelayIR().visit(tree) | ||
return ParseTreeToRelayIR(source_name).visit(tree) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't this find a list of all matching files? What if there are multiple jars?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed; see below.