Skip to content

Commit

Permalink
[Relay][Parser] Improve Relay parser and pretty printing, including C…
Browse files Browse the repository at this point in the history
…MAKE (apache#2377)
  • Loading branch information
jroesch authored and Anthony-Mai committed Jan 20, 2019
1 parent 597a684 commit 0680b68
Show file tree
Hide file tree
Showing 9 changed files with 400 additions and 219 deletions.
24 changes: 19 additions & 5 deletions cmake/modules/ANTLR.cmake
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
# Collect every ANTLR "complete" jar found under /usr/local/lib and under
# /usr/local/Cellar. ANTLR4 receives the (possibly empty) list of matches.
file(GLOB_RECURSE ANTLR4
     /usr/local/lib/antlr-*-complete.jar
     /usr/local/Cellar/*antlr-*-complete.jar)

# file(GLOB_RECURSE) defines its output variable even when nothing matches.
# Drop an empty result so that the `if(DEFINED ANTLR4)` test that follows can
# actually report a missing ANTLR, instead of failing later inside list(GET)
# on an empty list.
if(NOT ANTLR4)
  unset(ANTLR4)
endif()

if(DEFINED ANTLR4)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list(SORT ANTLR4)
list(REVERSE ANTLR4)
list(GET ANTLR4 0 ANTLR4)
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)

Expand All @@ -14,15 +22,21 @@ if(USE_ANTLR)
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)

# Pick the directory whose bin/java is used to run the ANTLR jar, starting
# from the JAVA_HOME environment variable as seen at configure time.
# NOTE(review): the fallback below appears to rely on the unquoted expansion
# leaving JAVA_HOME undefined when the environment variable is unset or
# empty - confirm against the set() documentation.
set(JAVA_HOME $ENV{JAVA_HOME})
if (NOT DEFINED JAVA_HOME)
# Hack to get system to search for Java itself.
# With "/usr", the generator commands end up invoking /usr/bin/java.
set(JAVA_HOME "/usr")
endif()

# Generate ANTLR grammar for parsing.
# Generate the Python 2 and Python 3 Relay parsers from Relay.g4 with ANTLR.
# Each target language is generated exactly once, always through the resolved
# ${JAVA_HOME} (never $ENV{JAVA_HOME} directly, which would become "/bin/java"
# when the environment variable is not set). The rule re-runs whenever the
# grammar file changes. VERBATIM keeps argument escaping identical across
# platforms and generators.
add_custom_command(OUTPUT ${RELAY_PARSER}
  COMMAND "${JAVA_HOME}/bin/java" -jar "${ANTLR4}" -visitor -no-listener -Dlanguage=Python2 "${RELAY_PARSER_DIR}/Relay.g4" -o "${RELAY_PARSER_DIR}/py2"
  COMMAND "${JAVA_HOME}/bin/java" -jar "${ANTLR4}" -visitor -no-listener -Dlanguage=Python3 "${RELAY_PARSER_DIR}/Relay.g4" -o "${RELAY_PARSER_DIR}/py3"
  DEPENDS "${RELAY_PARSER_DIR}/Relay.g4"
  WORKING_DIRECTORY "${RELAY_PARSER_DIR}"
  VERBATIM)

add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4!")
message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4})
endif()
endif(USE_ANTLR)
4 changes: 3 additions & 1 deletion include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const;
inline const SourceNameNode* operator->() const {
return static_cast<SourceNameNode*>(this->node_.get());
}

/*!
* \brief Get an SourceName for a given operator name.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of base functions exposed from C++."""
from tvm._ffi.function import _init_api

# NOTE(review): presumably binds the functions registered under the
# "relay._base" prefix as attributes of this module - confirm against
# tvm._ffi.function._init_api.
_init_api("relay._base", __name__)
135 changes: 113 additions & 22 deletions python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
import sys

from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict

import tvm

from . import module
from .base import Span, SourceName
from . import expr
from . import ty
from . import op


class ParseError(Exception):
"""Exception type for parse errors."""

Expand Down Expand Up @@ -76,22 +80,46 @@ def lookup(scopes, name):
return val
return None

def spanify(f):
    """Decorate an AST-visiting method so that its result carries a Span.

    The wrapped method runs as usual; afterwards a ``Span`` built from the
    visitor's ``source_name`` and the position of the parse context is
    attached to the returned node through ``set_span``.

    Parameters
    ----------
    f : Callable
        A visitor method whose 0th argument is ``self`` (it must provide
        ``source_name``) and whose 1st argument is an ANTLR parser context.
        The value it returns must provide ``set_span``.

    Returns
    -------
    Callable
        A wrapper with the same call signature, name and docstring as ``f``
        that returns the span-annotated node.
    """
    # Imported locally so the decorator needs no extra module-level import.
    import functools

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        # Assumes 0th arg is self and gets source_name from object.
        sn = args[0].source_name
        # Assumes 1st arg is an ANTLR parser context.
        ctx = args[1]
        ast = f(*args, **kwargs)

        # getSourceInterval() yields a pair of token *indices*, not a position
        # in the text; the line and column live on the context's first token.
        start = getattr(ctx, "start", None)
        if start is not None:
            line, col = start.line, start.column
        else:
            # No start token available: keep the previous behaviour.
            line, col = ctx.getSourceInterval()

        sp = Span(sn, line, col)
        ast.set_span(sp)
        return ast
    return _wrapper

# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""

def __init__(self):
# type: () -> None
def __init__(self, source_name):
# type: (str) -> None
self.source_name = source_name
self.module = module.Module({}) # type: module.Module

# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = [] # type: List[expr.Expr]

super(ParseTreeToRelayIR, self).__init__()


def enter_var_scope(self):
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
Expand Down Expand Up @@ -146,20 +174,25 @@ def visitTerminal(self, node):

node_type = node.getSymbol().type
node_text = node.getText()
name = node_text[1:]

# variables
if node_type == RelayLexer.GLOBAL_VAR:
return lookup([self.global_var_scope], node_text[1:])
return lookup(deque([self.global_var_scope]), node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:]
# Remove the leading '%' and lookup the name.
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))

return var
elif node_type == RelayLexer.GRAPH_VAR:
try:
return self.graph_expr[int(name)]
except IndexError:
raise ParseError("Couldn't resolve `{}`".format(name))

# data types
elif node_type == RelayLexer.INT:
elif node_type == RelayLexer.NAT:
return int(node_text)
elif node_type == RelayLexer.FLOAT:
return float(node_text)
Expand Down Expand Up @@ -190,7 +223,7 @@ def getType_(self, ctx):
return self.visit(ctx)

def visitProg(self, ctx):
# type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment]
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
return self.module
Expand Down Expand Up @@ -219,7 +252,7 @@ def visitScalarFloat(self, ctx):

def visitScalarInt(self, ctx):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return expr.const(self.visit(ctx.INT()))
return expr.const(self.visit(ctx.NAT()))

def visitScalarBool(self, ctx):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
Expand All @@ -240,7 +273,7 @@ def visitTuple(self, ctx):
return expr.Tuple(tup)

# Currently doesn't support mutable sequencing.
def visitSeq(self, ctx):
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
Expand All @@ -253,7 +286,7 @@ def visitSeq(self, ctx):
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError('Only local ids may be used in `let`s.')
raise ParseError("Only local ids may be used in `let`s.")
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())

Expand All @@ -278,12 +311,14 @@ def visitBinOp(self, ctx):

return relay_op(arg0, arg1)

@spanify
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident = ctx.ident().LOCAL_VAR()

if ident is None:
raise ParseError('Only local ids may be used in params.')
raise ParseError("Only local ids may be used in vars.")

type_ = self.getType_(ctx.type_())

Expand All @@ -293,15 +328,33 @@ def visitVarList(self, ctx):
# type: (RelayParser.VarListContext) -> List[expr.Var]
return self.visit_list(ctx.var())

# TODO: support a larger class of values than just Relay exprs
def visitAttr(self, ctx):
    # type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
    """Translate one attribute node into a ``(name, value)`` pair.

    The name is the text of the node's CNAME token; the value is whatever
    visiting the node's expression produces.
    """
    attr_name = ctx.CNAME().getText()
    attr_value = self.visit(ctx.expr())
    return (attr_name, attr_value)

def visitAttrList(self, ctx):
    # type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
    """Collect the visited attributes of ``ctx`` into a name -> value dict."""
    pairs = self.visit_list(ctx.attr())
    return dict(pairs)

def visitArgList(self,
                 ctx  # type: RelayParser.ArgListContext
                ):
    # type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
    """Visit an argument list and return ``(var_list, attr_list)``.

    Each component is the result of visiting the corresponding child of
    ``ctx``, or ``None`` when that child is absent.
    """
    var_list = None
    if ctx.varList():
        var_list = self.visit(ctx.varList())

    attr_list = None
    if ctx.attrList():
        attr_list = self.visit(ctx.attrList())

    return (var_list, attr_list)

def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""

# Enter var scope early to put params in scope.
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
var_list = self.visit(ctx.varList())
var_list, attr_list = self.visit(ctx.argList())
ret_type = self.getType_(ctx.type_())

type_params = list(self.exit_type_param_scope())
Expand All @@ -311,22 +364,28 @@ def mk_func(self, ctx):
body = self.visit(ctx.body())
self.exit_var_scope()

return expr.Function(var_list, body, ret_type, type_params) # type: ignore
attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None

return expr.Function(var_list, body, ret_type, type_params, attrs)

@spanify
def visitFunc(self, ctx):
    # type: (RelayParser.FuncContext) -> expr.Function
    """Build the function for a ``func`` parse node by delegating to ``mk_func``."""
    func = self.mk_func(ctx)
    return func

# TODO: how to set spans for definitions?
# @spanify
def visitDefn(self, ctx):
# type: (RelayParser.DefnContext) -> None
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError('Only global ids may be used in `def`s.')
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)

self.module[ident] = self.mk_func(ctx)

@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())
Expand All @@ -336,6 +395,7 @@ def visitCall(self, ctx):

return expr.Call(func, args, None, None)

@spanify
def visitIfElse(self, ctx):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
Expand All @@ -351,6 +411,27 @@ def visitIfElse(self, ctx):

return expr.If(cond, true_branch, false_branch)

@spanify
def visitGraph(self, ctx):
    # type: (RelayParser.GraphContext) -> expr.Expr
    """Visit a graph variable assignment.

    The bound expression (``ctx.expr(0)``) is visited inside its own var
    scope and appended to ``self.graph_expr``; the continuation
    (``ctx.expr(1)``) is then visited and its result returned.

    Raises
    ------
    ParseError
        If the identifier is not a graph variable, or if its number is not
        the next free index of ``self.graph_expr``.
    """
    graph_var = ctx.ident().GRAPH_VAR()
    if graph_var is None:
        raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
    # Strip the one-character sigil in front of the node number.
    graph_nid = int(graph_var.getText()[1:])

    self.enter_var_scope()
    value = self.visit(ctx.expr(0))
    self.exit_var_scope()

    # Graph variables must be introduced in order: %0, %1, %2, ...
    expected_nid = len(self.graph_expr)
    if graph_nid != expected_nid:
        raise ParseError(
            "Expected new graph variable to be `%{}`, but got `%{}`".format(
                expected_nid, graph_nid))
    self.graph_expr.append(value)

    return self.visit(ctx.expr(1))

# Types

# pylint: disable=unused-argument
Expand Down Expand Up @@ -428,8 +509,18 @@ def make_parser(data):
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)

# Number of default source names handed out so far; keeps them distinct.
__source_name_counter__ = 0

def fromtext(data, source_name=None):
    # type: (str, Optional[Union[str, SourceName]]) -> Union[expr.Expr, module.Module]
    """Parse a Relay program.

    Parameters
    ----------
    data : str
        The program text handed to ``make_parser``.

    source_name : Optional[Union[str, SourceName]]
        Name recorded for the program. A plain string is wrapped in a
        ``SourceName``. When omitted, a fresh ``source_file<N>`` name is
        generated.

    Returns
    -------
    Union[expr.Expr, module.Module]
        Whatever ``ParseTreeToRelayIR`` produces for the parsed program.
    """
    global __source_name_counter__

    if source_name is None:
        source_name = "source_file{0}".format(__source_name_counter__)
        # Advance the counter so the next anonymous program gets a new name.
        __source_name_counter__ += 1

    if isinstance(source_name, str):
        source_name = SourceName(source_name)

    tree = make_parser(data).prog()
    return ParseTreeToRelayIR(source_name).visit(tree)
10 changes: 10 additions & 0 deletions python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
from . import _base

NodeBase = NodeBase

Expand Down Expand Up @@ -63,6 +64,9 @@ def astext(self, show_meta_data=True, annotate=None):
"""
return _expr.RelayPrint(self, show_meta_data, annotate)

def set_span(self, span):
    """Attach ``span`` to this node.

    Delegates to ``_base.set_span``, passing this node and ``span``.
    """
    _base.set_span(self, span)


@register_relay_node
class Span(RelayNode):
Expand All @@ -71,6 +75,12 @@ class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)

@register_relay_node
class SourceName(RelayNode):
    """An identifier for a source location.

    Parameters
    ----------
    name : str
        The name forwarded to the ``_make.SourceName`` constructor.
    """

    def __init__(self, name):
        self.__init_handle_by_constructor__(_make.SourceName, name)

@register_relay_node
class Id(NodeBase):
Expand Down
Loading

0 comments on commit 0680b68

Please sign in to comment.